1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLOps.h"
12 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Verifier.h"
16 #include "mlir/Parser.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Nodes.h"
19 #include "mlir/Tools/PDLL/AST/Types.h"
20 #include "llvm/ADT/ScopedHashTable.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace mlir;
25 using namespace mlir::pdll;
26 
27 //===----------------------------------------------------------------------===//
28 // CodeGen
29 //===----------------------------------------------------------------------===//
30 
namespace {
/// Lowers a parsed PDLL AST module into an MLIR module made of PDL dialect
/// operations. `generate` walks the top-level decls of the module and the
/// `gen`/`genImpl`/`genExprImpl` family recursively dispatches on the kind of
/// each AST node.
class CodeGen {
public:
  /// Create a generator that builds IR in `mlirContext`; `sourceMgr` is used
  /// to translate AST source locations into MLIR locations.
  /// NOTE(review): `context` is accepted but neither used nor stored by this
  /// constructor.
  CodeGen(MLIRContext *mlirContext, const ast::Context &context,
          const llvm::SourceMgr &sourceMgr)
      : builder(mlirContext), sourceMgr(sourceMgr) {
    // Make sure that the PDL dialect is loaded.
    mlirContext->loadDialect<pdl::PDLDialect>();
  }

  /// Create a new MLIR module and generate every top-level decl of `module`
  /// into its body.
  OwningOpRef<ModuleOp> generate(const ast::Module &module);

private:
  /// Generate an MLIR location from the given source location.
  Location genLoc(llvm::SMLoc loc);
  /// Generate an MLIR location from the start of the given source range.
  Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }

  /// Generate an MLIR type from the given source type.
  Type genType(ast::Type type);

  /// Generate MLIR for the given AST node.
  void gen(const ast::Node *node);

  //===--------------------------------------------------------------------===//
  // Statements
  //===--------------------------------------------------------------------===//

  // Per-statement lowering hooks, dispatched to from `gen`.
  void genImpl(const ast::CompoundStmt *stmt);
  void genImpl(const ast::EraseStmt *stmt);
  void genImpl(const ast::LetStmt *stmt);
  void genImpl(const ast::ReplaceStmt *stmt);
  void genImpl(const ast::RewriteStmt *stmt);
  void genImpl(const ast::ReturnStmt *stmt);

  //===--------------------------------------------------------------------===//
  // Decls
  //===--------------------------------------------------------------------===//

  // Per-declaration lowering hooks, dispatched to from `gen`.
  void genImpl(const ast::UserConstraintDecl *decl);
  void genImpl(const ast::UserRewriteDecl *decl);
  void genImpl(const ast::PatternDecl *decl);

  /// Generate the set of MLIR values defined for the given variable decl, and
  /// apply any attached constraints.
  SmallVector<Value> genVar(const ast::VariableDecl *varDecl);

  /// Generate the value for a variable that does not have an initializer
  /// expression, i.e. create the PDL value based on the type/constraints of the
  /// variable.
  Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);

  /// Apply the constraints of the given variable to `values`, which correspond
  /// to the MLIR values of the variable.
  void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);

  //===--------------------------------------------------------------------===//
  // Expressions
  //===--------------------------------------------------------------------===//

  // `genSingleExpr` requires an expression to produce exactly one value, while
  // `genExpr` returns every value it produces (e.g. one per tuple element).
  // The `genExprImpl` overloads hold the per-expression-kind lowering.
  Value genSingleExpr(const ast::Expr *expr);
  SmallVector<Value> genExpr(const ast::Expr *expr);
  Value genExprImpl(const ast::AttributeExpr *expr);
  SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
  SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
  Value genExprImpl(const ast::MemberAccessExpr *expr);
  Value genExprImpl(const ast::OperationExpr *expr);
  SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
  Value genExprImpl(const ast::TypeExpr *expr);

  // Expand a call to a user constraint/rewrite with the given input values.
  // `genConstraintOrRewriteCall` is the shared implementation: body-less
  // (native) decls become a single `PDLOpT` op, PDLL decls are inlined.
  SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
                                       Location loc, ValueRange inputs);
  SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
                                    Location loc, ValueRange inputs);
  template <typename PDLOpT, typename T>
  SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
                                                ValueRange inputs);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The MLIR builder used for building the resultant IR.
  OpBuilder builder;

  /// A map from variable declarations to the MLIR equivalent. The table is
  /// scoped: compound statements and inlined calls open a scope, and the
  /// bindings made within it are dropped when that scope ends.
  using VariableMapTy =
      llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
  VariableMapTy variables;

  /// The source manager of the PDLL ast.
  const llvm::SourceMgr &sourceMgr;
};
} // namespace
124 
125 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
126   OwningOpRef<ModuleOp> mlirModule =
127       builder.create<ModuleOp>(genLoc(module.getLoc()));
128   builder.setInsertionPointToStart(mlirModule->getBody());
129 
130   // Generate code for each of the decls within the module.
131   for (const ast::Decl *decl : module.getChildren())
132     gen(decl);
133 
134   return mlirModule;
135 }
136 
137 Location CodeGen::genLoc(llvm::SMLoc loc) {
138   unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
139 
140   // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
141   //       use it here.
142   auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
143   unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
144   unsigned column =
145       (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
146   auto *buffer = sourceMgr.getMemoryBuffer(fileID);
147 
148   return FileLineColLoc::get(builder.getContext(),
149                              buffer->getBufferIdentifier(), lineNo, column);
150 }
151 
152 Type CodeGen::genType(ast::Type type) {
153   return TypeSwitch<ast::Type, Type>(type)
154       .Case([&](ast::AttributeType astType) -> Type {
155         return builder.getType<pdl::AttributeType>();
156       })
157       .Case([&](ast::OperationType astType) -> Type {
158         return builder.getType<pdl::OperationType>();
159       })
160       .Case([&](ast::TypeType astType) -> Type {
161         return builder.getType<pdl::TypeType>();
162       })
163       .Case([&](ast::ValueType astType) -> Type {
164         return builder.getType<pdl::ValueType>();
165       })
166       .Case([&](ast::RangeType astType) -> Type {
167         return pdl::RangeType::get(genType(astType.getElementType()));
168       });
169 }
170 
171 void CodeGen::gen(const ast::Node *node) {
172   TypeSwitch<const ast::Node *>(node)
173       .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
174             const ast::ReplaceStmt, const ast::RewriteStmt,
175             const ast::ReturnStmt, const ast::UserConstraintDecl,
176             const ast::UserRewriteDecl, const ast::PatternDecl>(
177           [&](auto derivedNode) { this->genImpl(derivedNode); })
178       .Case([&](const ast::Expr *expr) { genExpr(expr); });
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // CodeGen: Statements
183 //===----------------------------------------------------------------------===//
184 
185 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
186   VariableMapTy::ScopeTy varScope(variables);
187   for (const ast::Stmt *childStmt : stmt->getChildren())
188     gen(childStmt);
189 }
190 
191 /// If the given builder is nested under a PDL PatternOp, build a rewrite
192 /// operation and update the builder to nest under it. This is necessary for
193 /// PDLL operation rewrite statements that are directly nested within a Pattern.
194 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
195                                        Location loc) {
196   if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
197     pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>(
198         loc, rootExpr, /*name=*/StringAttr(),
199         /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr());
200     builder.createBlock(&rewrite.body());
201   }
202 }
203 
204 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
205   OpBuilder::InsertionGuard insertGuard(builder);
206   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
207   Location loc = genLoc(stmt->getLoc());
208 
209   // Make sure we are nested in a RewriteOp.
210   OpBuilder::InsertionGuard guard(builder);
211   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
212   builder.create<pdl::EraseOp>(loc, rootExpr);
213 }
214 
215 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
216 
217 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
218   OpBuilder::InsertionGuard insertGuard(builder);
219   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
220   Location loc = genLoc(stmt->getLoc());
221 
222   // Make sure we are nested in a RewriteOp.
223   OpBuilder::InsertionGuard guard(builder);
224   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
225 
226   SmallVector<Value> replValues;
227   for (ast::Expr *replExpr : stmt->getReplExprs())
228     replValues.push_back(genSingleExpr(replExpr));
229 
230   // Check to see if the statement has a replacement operation, or a range of
231   // replacement values.
232   bool usesReplOperation =
233       replValues.size() == 1 &&
234       replValues.front().getType().isa<pdl::OperationType>();
235   builder.create<pdl::ReplaceOp>(
236       loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
237       usesReplOperation ? ValueRange() : ValueRange(replValues));
238 }
239 
240 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
241   OpBuilder::InsertionGuard insertGuard(builder);
242   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
243 
244   // Make sure we are nested in a RewriteOp.
245   OpBuilder::InsertionGuard guard(builder);
246   checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
247   gen(stmt->getRewriteBody());
248 }
249 
250 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
251   // ReturnStmt generation is handled by the respective constraint or rewrite
252   // parent node.
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // CodeGen: Decls
257 //===----------------------------------------------------------------------===//
258 
259 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
260   // All PDLL constraints get inlined when called, and the main native
261   // constraint declarations doesn't require any MLIR to be generated, only uses
262   // of it do.
263 }
264 
265 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
266   // All PDLL rewrites get inlined when called, and the main native
267   // rewrite declarations doesn't require any MLIR to be generated, only uses
268   // of it do.
269 }
270 
271 void CodeGen::genImpl(const ast::PatternDecl *decl) {
272   const ast::Name *name = decl->getName();
273 
274   // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
275   // here.
276   pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
277       genLoc(decl->getLoc()), decl->getBenefit(),
278       name ? Optional<StringRef>(name->getName()) : Optional<StringRef>());
279 
280   OpBuilder::InsertionGuard savedInsertPoint(builder);
281   builder.setInsertionPointToStart(pattern.getBody());
282   gen(decl->getBody());
283 }
284 
285 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
286   auto it = variables.begin(varDecl);
287   if (it != variables.end())
288     return *it;
289 
290   // If the variable has an initial value, use that as the base value.
291   // Otherwise, generate a value using the constraint list.
292   SmallVector<Value> values;
293   if (const ast::Expr *initExpr = varDecl->getInitExpr())
294     values = genExpr(initExpr);
295   else
296     values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
297 
298   // Apply the constraints of the values of the variable.
299   applyVarConstraints(varDecl, values);
300 
301   variables.insert(varDecl, values);
302   return values;
303 }
304 
/// Materialize a PDL value for a variable declared without an initializer.
/// The PDL op that is created is selected from the variable's AST type. For
/// value, attribute, and value-range variables, the first type expression found
/// in the variable's constraint list (if any) additionally constrains the op.
Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
                                    Location loc) {
  // Returns the value generated for the first type expression attached to an
  // Attr/Value/ValueRange constraint of `varDecl`, or a null Value if none of
  // its constraints carries one. Calling it emits the IR for that expression
  // at the current insertion point, so it is only invoked (below) in the
  // branches that actually use a type constraint.
  auto getTypeConstraint = [&]() -> Value {
    for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
      Value typeValue =
          TypeSwitch<const ast::Node *, Value>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&, this](auto *cst) -> Value {
                if (auto *typeConstraintExpr = cst->getTypeExpr())
                  return genSingleExpr(typeConstraintExpr);
                return Value();
              })
              .Default(Value());
      if (typeValue)
        return typeValue;
    }
    return Value();
  };

  // Generate a value based on the type of the variable.
  ast::Type type = varDecl->getType();
  Type mlirType = genType(type);
  // Value -> pdl.operand, optionally constrained by a type value.
  if (type.isa<ast::ValueType>())
    return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
  // Type -> an unconstrained pdl.type (no constant type attribute).
  if (type.isa<ast::TypeType>())
    return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
  // Attribute -> pdl.attribute, optionally constrained by a type value.
  if (type.isa<ast::AttributeType>())
    return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
  // Operation -> pdl.operation using the name carried by the AST type, with
  // unconstrained operand and result-type ranges and no attributes.
  if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
    Value operands = builder.create<pdl::OperandsOp>(
        loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
        /*type=*/Value());
    Value results = builder.create<pdl::TypesOp>(
        loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
        /*types=*/ArrayAttr());
    return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
                                            llvm::None, ValueRange(), results);
  }

  // Ranges: value ranges -> pdl.operands (optionally type-constrained), type
  // ranges -> an unconstrained pdl.types.
  if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
    ast::Type eleTy = rangeTy.getElementType();
    if (eleTy.isa<ast::ValueType>())
      return builder.create<pdl::OperandsOp>(loc, mlirType,
                                             getTypeConstraint());
    if (eleTy.isa<ast::TypeType>())
      return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
  }

  // Any other variable type reaching this point is treated as unreachable.
  llvm_unreachable("invalid non-initialized variable type");
}
356 
357 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
358                                   ValueRange values) {
359   // Generate calls to any user constraints that were attached via the
360   // constraint list.
361   for (const ast::ConstraintRef &ref : varDecl->getConstraints())
362     if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
363       genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // CodeGen: Expressions
368 //===----------------------------------------------------------------------===//
369 
370 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
371   return TypeSwitch<const ast::Expr *, Value>(expr)
372       .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
373             const ast::OperationExpr, const ast::TypeExpr>(
374           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
375       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
376           [&](auto derivedNode) {
377             SmallVector<Value> results = this->genExprImpl(derivedNode);
378             assert(results.size() == 1 && "expected single expression result");
379             return results[0];
380           });
381 }
382 
383 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
384   return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
385       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
386           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
387       .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
388         return {genSingleExpr(expr)};
389       });
390 }
391 
392 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
393   Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
394   assert(attr && "invalid MLIR attribute data");
395   return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
396 }
397 
398 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
399   Location loc = genLoc(expr->getLoc());
400   SmallVector<Value> arguments;
401   for (const ast::Expr *arg : expr->getArguments())
402     arguments.push_back(genSingleExpr(arg));
403 
404   // Resolve the callable expression of this call.
405   auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
406   assert(callableExpr && "unhandled CallExpr callable");
407 
408   // Generate the PDL based on the type of callable.
409   const ast::Decl *callable = callableExpr->getDecl();
410   if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
411     return genConstraintCall(decl, loc, arguments);
412   if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
413     return genRewriteCall(decl, loc, arguments);
414   llvm_unreachable("unhandled CallExpr callable");
415 }
416 
417 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
418   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
419     return genVar(varDecl);
420   llvm_unreachable("unknown decl reference expression");
421 }
422 
423 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
424   Location loc = genLoc(expr->getLoc());
425   StringRef name = expr->getMemberName();
426   SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
427   ast::Type parentType = expr->getParentExpr()->getType();
428 
429   // Handle operation based member access.
430   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
431     if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
432       Type mlirType = genType(expr->getType());
433       if (mlirType.isa<pdl::ValueType>())
434         return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
435                                              builder.getI32IntegerAttr(0));
436       return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
437     }
438     llvm_unreachable("unhandled operation member access expression");
439   }
440 
441   // Handle tuple based member access.
442   if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
443     auto elementNames = tupleType.getElementNames();
444 
445     // The index is either a numeric index, or a name.
446     unsigned index = 0;
447     if (llvm::isDigit(name[0]))
448       name.getAsInteger(/*Radix=*/10, index);
449     else
450       index = llvm::find(elementNames, name) - elementNames.begin();
451 
452     assert(index < parentExprs.size() && "invalid result index");
453     return parentExprs[index];
454   }
455 
456   llvm_unreachable("unhandled member access expression");
457 }
458 
459 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
460   Location loc = genLoc(expr->getLoc());
461   Optional<StringRef> opName = expr->getName();
462 
463   // Operands.
464   SmallVector<Value> operands;
465   for (const ast::Expr *operand : expr->getOperands())
466     operands.push_back(genSingleExpr(operand));
467 
468   // Attributes.
469   SmallVector<StringRef> attrNames;
470   SmallVector<Value> attrValues;
471   for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
472     attrNames.push_back(attr->getName().getName());
473     attrValues.push_back(genSingleExpr(attr->getValue()));
474   }
475 
476   // Results.
477   SmallVector<Value> results;
478   for (const ast::Expr *result : expr->getResultTypes())
479     results.push_back(genSingleExpr(result));
480 
481   return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
482                                           attrValues, results);
483 }
484 
485 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
486   SmallVector<Value> elements;
487   for (const ast::Expr *element : expr->getElements())
488     elements.push_back(genSingleExpr(element));
489   return elements;
490 }
491 
492 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
493   Type type = parseType(expr->getValue(), builder.getContext());
494   assert(type && "invalid MLIR type data");
495   return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
496                                      builder.getType<pdl::TypeType>(),
497                                      TypeAttr::get(type));
498 }
499 
500 SmallVector<Value>
501 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
502                            ValueRange inputs) {
503   // Apply any constraints defined on the arguments to the input values.
504   for (auto it : llvm::zip(decl->getInputs(), inputs))
505     applyVarConstraints(std::get<0>(it), std::get<1>(it));
506 
507   // Generate the constraint call.
508   SmallVector<Value> results =
509       genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
510                                                                inputs);
511 
512   // Apply any constraints defined on the results of the constraint.
513   for (auto it : llvm::zip(decl->getResults(), results))
514     applyVarConstraints(std::get<0>(it), std::get<1>(it));
515   return results;
516 }
517 
518 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
519                                            Location loc, ValueRange inputs) {
520   return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
521                                                                inputs);
522 }
523 
524 template <typename PDLOpT, typename T>
525 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
526                                                        Location loc,
527                                                        ValueRange inputs) {
528   const ast::CompoundStmt *cstBody = decl->getBody();
529 
530   // If the decl doesn't have a statement body, it is a native decl.
531   if (!cstBody) {
532     ast::Type declResultType = decl->getResultType();
533     SmallVector<Type> resultTypes;
534     if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
535       for (ast::Type type : tupleType.getElementTypes())
536         resultTypes.push_back(genType(type));
537     } else {
538       resultTypes.push_back(genType(declResultType));
539     }
540 
541     // FIXME: We currently do not have a modeling for the "constant params"
542     // support PDL provides. We should either figure out a modeling for this, or
543     // refactor the support within PDL to be something a bit more reasonable for
544     // what we need as a frontend.
545     Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes,
546                                               decl->getName().getName(), inputs,
547                                               /*params=*/ArrayAttr());
548     return pdlOp->getResults();
549   }
550 
551   // Otherwise, this is a PDLL decl.
552   VariableMapTy::ScopeTy varScope(variables);
553 
554   // Map the inputs of the call to the decl arguments.
555   // Note: This is only valid because we do not support recursion, meaning
556   // we don't need to worry about conflicting mappings here.
557   for (auto it : llvm::zip(inputs, decl->getInputs()))
558     variables.insert(std::get<1>(it), {std::get<0>(it)});
559 
560   // Visit the body of the call as normal.
561   gen(cstBody);
562 
563   // If the decl has no results, there is nothing to do.
564   if (cstBody->getChildren().empty())
565     return SmallVector<Value>();
566   auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
567   if (!returnStmt)
568     return SmallVector<Value>();
569 
570   // Otherwise, grab the results from the return statement.
571   return genExpr(returnStmt->getResultExpr());
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // MLIRGen
576 //===----------------------------------------------------------------------===//
577 
578 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
579     MLIRContext *mlirContext, const ast::Context &context,
580     const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
581   CodeGen codegen(mlirContext, context, sourceMgr);
582   OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
583   if (failed(verify(*mlirModule)))
584     return nullptr;
585   return mlirModule;
586 }
587