1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
10 #include "mlir/Dialect/PDL/IR/PDL.h"
11 #include "mlir/Dialect/PDL/IR/PDLOps.h"
12 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Verifier.h"
16 #include "mlir/Parser/Parser.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Nodes.h"
19 #include "mlir/Tools/PDLL/AST/Types.h"
20 #include "mlir/Tools/PDLL/ODS/Context.h"
21 #include "mlir/Tools/PDLL/ODS/Operation.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::pdll;
28 
29 //===----------------------------------------------------------------------===//
30 // CodeGen
31 //===----------------------------------------------------------------------===//
32 
namespace {
/// Lowers a PDLL AST module into PDL dialect operations nested within a
/// builtin `ModuleOp`. Each PatternDecl becomes a `pdl.pattern`; calls to
/// PDLL-defined constraints/rewrites are inlined at the call site, while calls
/// to native ones become `pdl.apply_native_constraint` /
/// `pdl.apply_native_rewrite` operations.
class CodeGen {
public:
  /// Create a generator that builds IR in `mlirContext`. The PDL dialect is
  /// loaded eagerly so that PDL operations can be created.
  CodeGen(MLIRContext *mlirContext, const ast::Context &context,
          const llvm::SourceMgr &sourceMgr)
      : builder(mlirContext), odsContext(context.getODSContext()),
        sourceMgr(sourceMgr) {
    // Make sure that the PDL dialect is loaded.
    mlirContext->loadDialect<pdl::PDLDialect>();
  }

  /// Generate a new ModuleOp containing the PDL IR for every top-level decl
  /// of `module`.
  OwningOpRef<ModuleOp> generate(const ast::Module &module);

private:
  /// Generate an MLIR location from the given source location.
  Location genLoc(llvm::SMLoc loc);
  Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }

  /// Generate an MLIR type from the given source type.
  Type genType(ast::Type type);

  /// Generate MLIR for the given AST node.
  void gen(const ast::Node *node);

  //===--------------------------------------------------------------------===//
  // Statements
  //===--------------------------------------------------------------------===//

  // Erase/replace/rewrite statements emitted directly within a pdl.pattern
  // are nested under a newly created pdl.rewrite.
  void genImpl(const ast::CompoundStmt *stmt);
  void genImpl(const ast::EraseStmt *stmt);
  void genImpl(const ast::LetStmt *stmt);
  void genImpl(const ast::ReplaceStmt *stmt);
  void genImpl(const ast::RewriteStmt *stmt);
  void genImpl(const ast::ReturnStmt *stmt);

  //===--------------------------------------------------------------------===//
  // Decls
  //===--------------------------------------------------------------------===//

  // User constraint/rewrite declarations emit nothing by themselves; IR is
  // only produced at their call sites.
  void genImpl(const ast::UserConstraintDecl *decl);
  void genImpl(const ast::UserRewriteDecl *decl);
  void genImpl(const ast::PatternDecl *decl);

  /// Generate the set of MLIR values defined for the given variable decl, and
  /// apply any attached constraints.
  SmallVector<Value> genVar(const ast::VariableDecl *varDecl);

  /// Generate the value for a variable that does not have an initializer
  /// expression, i.e. create the PDL value based on the type/constraints of the
  /// variable.
  Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);

  /// Apply the constraints of the given variable to `values`, which correspond
  /// to the MLIR values of the variable.
  void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);

  //===--------------------------------------------------------------------===//
  // Expressions
  //===--------------------------------------------------------------------===//

  /// Generate the single value of `expr`. Call, decl-ref, and tuple
  /// expressions are asserted to produce exactly one value.
  Value genSingleExpr(const ast::Expr *expr);
  /// Generate all of the values produced by `expr`.
  SmallVector<Value> genExpr(const ast::Expr *expr);
  // Per-expression-kind implementations used by genSingleExpr/genExpr.
  Value genExprImpl(const ast::AttributeExpr *expr);
  SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
  SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
  Value genExprImpl(const ast::MemberAccessExpr *expr);
  Value genExprImpl(const ast::OperationExpr *expr);
  SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
  Value genExprImpl(const ast::TypeExpr *expr);

  /// Generate a call to a user constraint, applying the constraints attached
  /// to its argument and result variables.
  SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
                                       Location loc, ValueRange inputs);
  /// Generate a call to a user rewrite.
  SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
                                    Location loc, ValueRange inputs);
  /// Shared implementation of constraint/rewrite calls: a decl without a body
  /// becomes a single `PDLOpT` operation, while a PDLL decl is inlined by
  /// generating its body with the decl arguments bound to `inputs`.
  template <typename PDLOpT, typename T>
  SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
                                                ValueRange inputs);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The MLIR builder used for building the resultant IR.
  OpBuilder builder;

  /// A map from variable declarations to the MLIR equivalent. Entries are
  /// scoped: compound statements and inlined calls push a scope, and the
  /// entries inserted within it are dropped when it ends.
  using VariableMapTy =
      llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
  VariableMapTy variables;

  /// A reference to the ODS context.
  const ods::Context &odsContext;

  /// The source manager of the PDLL ast.
  const llvm::SourceMgr &sourceMgr;
};
} // namespace
130 
131 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
132   OwningOpRef<ModuleOp> mlirModule =
133       builder.create<ModuleOp>(genLoc(module.getLoc()));
134   builder.setInsertionPointToStart(mlirModule->getBody());
135 
136   // Generate code for each of the decls within the module.
137   for (const ast::Decl *decl : module.getChildren())
138     gen(decl);
139 
140   return mlirModule;
141 }
142 
143 Location CodeGen::genLoc(llvm::SMLoc loc) {
144   unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
145 
146   // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
147   //       use it here.
148   auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
149   unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
150   unsigned column =
151       (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
152   auto *buffer = sourceMgr.getMemoryBuffer(fileID);
153 
154   return FileLineColLoc::get(builder.getContext(),
155                              buffer->getBufferIdentifier(), lineNo, column);
156 }
157 
158 Type CodeGen::genType(ast::Type type) {
159   return TypeSwitch<ast::Type, Type>(type)
160       .Case([&](ast::AttributeType astType) -> Type {
161         return builder.getType<pdl::AttributeType>();
162       })
163       .Case([&](ast::OperationType astType) -> Type {
164         return builder.getType<pdl::OperationType>();
165       })
166       .Case([&](ast::TypeType astType) -> Type {
167         return builder.getType<pdl::TypeType>();
168       })
169       .Case([&](ast::ValueType astType) -> Type {
170         return builder.getType<pdl::ValueType>();
171       })
172       .Case([&](ast::RangeType astType) -> Type {
173         return pdl::RangeType::get(genType(astType.getElementType()));
174       });
175 }
176 
177 void CodeGen::gen(const ast::Node *node) {
178   TypeSwitch<const ast::Node *>(node)
179       .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
180             const ast::ReplaceStmt, const ast::RewriteStmt,
181             const ast::ReturnStmt, const ast::UserConstraintDecl,
182             const ast::UserRewriteDecl, const ast::PatternDecl>(
183           [&](auto derivedNode) { this->genImpl(derivedNode); })
184       .Case([&](const ast::Expr *expr) { genExpr(expr); });
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // CodeGen: Statements
189 //===----------------------------------------------------------------------===//
190 
191 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
192   VariableMapTy::ScopeTy varScope(variables);
193   for (const ast::Stmt *childStmt : stmt->getChildren())
194     gen(childStmt);
195 }
196 
197 /// If the given builder is nested under a PDL PatternOp, build a rewrite
198 /// operation and update the builder to nest under it. This is necessary for
199 /// PDLL operation rewrite statements that are directly nested within a Pattern.
200 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
201                                        Location loc) {
202   if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
203     pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>(
204         loc, rootExpr, /*name=*/StringAttr(),
205         /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr());
206     builder.createBlock(&rewrite.body());
207   }
208 }
209 
210 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
211   OpBuilder::InsertionGuard insertGuard(builder);
212   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
213   Location loc = genLoc(stmt->getLoc());
214 
215   // Make sure we are nested in a RewriteOp.
216   OpBuilder::InsertionGuard guard(builder);
217   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
218   builder.create<pdl::EraseOp>(loc, rootExpr);
219 }
220 
221 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
222 
223 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
224   OpBuilder::InsertionGuard insertGuard(builder);
225   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
226   Location loc = genLoc(stmt->getLoc());
227 
228   // Make sure we are nested in a RewriteOp.
229   OpBuilder::InsertionGuard guard(builder);
230   checkAndNestUnderRewriteOp(builder, rootExpr, loc);
231 
232   SmallVector<Value> replValues;
233   for (ast::Expr *replExpr : stmt->getReplExprs())
234     replValues.push_back(genSingleExpr(replExpr));
235 
236   // Check to see if the statement has a replacement operation, or a range of
237   // replacement values.
238   bool usesReplOperation =
239       replValues.size() == 1 &&
240       replValues.front().getType().isa<pdl::OperationType>();
241   builder.create<pdl::ReplaceOp>(
242       loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
243       usesReplOperation ? ValueRange() : ValueRange(replValues));
244 }
245 
246 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
247   OpBuilder::InsertionGuard insertGuard(builder);
248   Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
249 
250   // Make sure we are nested in a RewriteOp.
251   OpBuilder::InsertionGuard guard(builder);
252   checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
253   gen(stmt->getRewriteBody());
254 }
255 
256 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
257   // ReturnStmt generation is handled by the respective constraint or rewrite
258   // parent node.
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // CodeGen: Decls
263 //===----------------------------------------------------------------------===//
264 
265 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
266   // All PDLL constraints get inlined when called, and the main native
267   // constraint declarations doesn't require any MLIR to be generated, only uses
268   // of it do.
269 }
270 
271 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
272   // All PDLL rewrites get inlined when called, and the main native
273   // rewrite declarations doesn't require any MLIR to be generated, only uses
274   // of it do.
275 }
276 
277 void CodeGen::genImpl(const ast::PatternDecl *decl) {
278   const ast::Name *name = decl->getName();
279 
280   // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
281   // here.
282   pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
283       genLoc(decl->getLoc()), decl->getBenefit(),
284       name ? Optional<StringRef>(name->getName()) : Optional<StringRef>());
285 
286   OpBuilder::InsertionGuard savedInsertPoint(builder);
287   builder.setInsertionPointToStart(pattern.getBody());
288   gen(decl->getBody());
289 }
290 
291 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
292   auto it = variables.begin(varDecl);
293   if (it != variables.end())
294     return *it;
295 
296   // If the variable has an initial value, use that as the base value.
297   // Otherwise, generate a value using the constraint list.
298   SmallVector<Value> values;
299   if (const ast::Expr *initExpr = varDecl->getInitExpr())
300     values = genExpr(initExpr);
301   else
302     values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
303 
304   // Apply the constraints of the values of the variable.
305   applyVarConstraints(varDecl, values);
306 
307   variables.insert(varDecl, values);
308   return values;
309 }
310 
311 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
312                                     Location loc) {
313   // A functor used to generate expressions nested
314   auto getTypeConstraint = [&]() -> Value {
315     for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
316       Value typeValue =
317           TypeSwitch<const ast::Node *, Value>(constraint.constraint)
318               .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
319                     ast::ValueRangeConstraintDecl>([&, this](auto *cst) -> Value {
320                 if (auto *typeConstraintExpr = cst->getTypeExpr())
321                   return this->genSingleExpr(typeConstraintExpr);
322                 return Value();
323               })
324               .Default(Value());
325       if (typeValue)
326         return typeValue;
327     }
328     return Value();
329   };
330 
331   // Generate a value based on the type of the variable.
332   ast::Type type = varDecl->getType();
333   Type mlirType = genType(type);
334   if (type.isa<ast::ValueType>())
335     return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
336   if (type.isa<ast::TypeType>())
337     return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
338   if (type.isa<ast::AttributeType>())
339     return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
340   if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
341     Value operands = builder.create<pdl::OperandsOp>(
342         loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
343         /*type=*/Value());
344     Value results = builder.create<pdl::TypesOp>(
345         loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
346         /*types=*/ArrayAttr());
347     return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
348                                             llvm::None, ValueRange(), results);
349   }
350 
351   if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
352     ast::Type eleTy = rangeTy.getElementType();
353     if (eleTy.isa<ast::ValueType>())
354       return builder.create<pdl::OperandsOp>(loc, mlirType,
355                                              getTypeConstraint());
356     if (eleTy.isa<ast::TypeType>())
357       return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
358   }
359 
360   llvm_unreachable("invalid non-initialized variable type");
361 }
362 
363 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
364                                   ValueRange values) {
365   // Generate calls to any user constraints that were attached via the
366   // constraint list.
367   for (const ast::ConstraintRef &ref : varDecl->getConstraints())
368     if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
369       genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // CodeGen: Expressions
374 //===----------------------------------------------------------------------===//
375 
376 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
377   return TypeSwitch<const ast::Expr *, Value>(expr)
378       .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
379             const ast::OperationExpr, const ast::TypeExpr>(
380           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
381       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
382           [&](auto derivedNode) {
383             SmallVector<Value> results = this->genExprImpl(derivedNode);
384             assert(results.size() == 1 && "expected single expression result");
385             return results[0];
386           });
387 }
388 
389 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
390   return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
391       .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
392           [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
393       .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
394         return {genSingleExpr(expr)};
395       });
396 }
397 
398 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
399   Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
400   assert(attr && "invalid MLIR attribute data");
401   return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
402 }
403 
404 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
405   Location loc = genLoc(expr->getLoc());
406   SmallVector<Value> arguments;
407   for (const ast::Expr *arg : expr->getArguments())
408     arguments.push_back(genSingleExpr(arg));
409 
410   // Resolve the callable expression of this call.
411   auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
412   assert(callableExpr && "unhandled CallExpr callable");
413 
414   // Generate the PDL based on the type of callable.
415   const ast::Decl *callable = callableExpr->getDecl();
416   if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
417     return genConstraintCall(decl, loc, arguments);
418   if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
419     return genRewriteCall(decl, loc, arguments);
420   llvm_unreachable("unhandled CallExpr callable");
421 }
422 
423 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
424   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
425     return genVar(varDecl);
426   llvm_unreachable("unknown decl reference expression");
427 }
428 
429 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
430   Location loc = genLoc(expr->getLoc());
431   StringRef name = expr->getMemberName();
432   SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
433   ast::Type parentType = expr->getParentExpr()->getType();
434 
435   // Handle operation based member access.
436   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
437     if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
438       Type mlirType = genType(expr->getType());
439       if (mlirType.isa<pdl::ValueType>())
440         return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
441                                              builder.getI32IntegerAttr(0));
442       return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
443     }
444 
445     assert(opType.getName() && "expected valid operation name");
446     const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName());
447     assert(odsOp && "expected valid ODS operation information");
448 
449     // Find the result with the member name or by index.
450     ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
451     unsigned resultIndex = results.size();
452     if (llvm::isDigit(name[0])) {
453       name.getAsInteger(/*Radix=*/10, resultIndex);
454     } else {
455       auto findFn = [&](const ods::OperandOrResult &result) {
456         return result.getName() == name;
457       };
458       resultIndex = llvm::find_if(results, findFn) - results.begin();
459     }
460     assert(resultIndex < results.size() && "invalid result index");
461 
462     // Generate the result access.
463     IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
464     return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
465                                           parentExprs[0], index);
466   }
467 
468   // Handle tuple based member access.
469   if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
470     auto elementNames = tupleType.getElementNames();
471 
472     // The index is either a numeric index, or a name.
473     unsigned index = 0;
474     if (llvm::isDigit(name[0]))
475       name.getAsInteger(/*Radix=*/10, index);
476     else
477       index = llvm::find(elementNames, name) - elementNames.begin();
478 
479     assert(index < parentExprs.size() && "invalid result index");
480     return parentExprs[index];
481   }
482 
483   llvm_unreachable("unhandled member access expression");
484 }
485 
486 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
487   Location loc = genLoc(expr->getLoc());
488   Optional<StringRef> opName = expr->getName();
489 
490   // Operands.
491   SmallVector<Value> operands;
492   for (const ast::Expr *operand : expr->getOperands())
493     operands.push_back(genSingleExpr(operand));
494 
495   // Attributes.
496   SmallVector<StringRef> attrNames;
497   SmallVector<Value> attrValues;
498   for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
499     attrNames.push_back(attr->getName().getName());
500     attrValues.push_back(genSingleExpr(attr->getValue()));
501   }
502 
503   // Results.
504   SmallVector<Value> results;
505   for (const ast::Expr *result : expr->getResultTypes())
506     results.push_back(genSingleExpr(result));
507 
508   return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
509                                           attrValues, results);
510 }
511 
512 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
513   SmallVector<Value> elements;
514   for (const ast::Expr *element : expr->getElements())
515     elements.push_back(genSingleExpr(element));
516   return elements;
517 }
518 
519 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
520   Type type = parseType(expr->getValue(), builder.getContext());
521   assert(type && "invalid MLIR type data");
522   return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
523                                      builder.getType<pdl::TypeType>(),
524                                      TypeAttr::get(type));
525 }
526 
527 SmallVector<Value>
528 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
529                            ValueRange inputs) {
530   // Apply any constraints defined on the arguments to the input values.
531   for (auto it : llvm::zip(decl->getInputs(), inputs))
532     applyVarConstraints(std::get<0>(it), std::get<1>(it));
533 
534   // Generate the constraint call.
535   SmallVector<Value> results =
536       genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
537                                                                inputs);
538 
539   // Apply any constraints defined on the results of the constraint.
540   for (auto it : llvm::zip(decl->getResults(), results))
541     applyVarConstraints(std::get<0>(it), std::get<1>(it));
542   return results;
543 }
544 
545 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
546                                            Location loc, ValueRange inputs) {
547   return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
548                                                                inputs);
549 }
550 
551 template <typename PDLOpT, typename T>
552 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
553                                                        Location loc,
554                                                        ValueRange inputs) {
555   const ast::CompoundStmt *cstBody = decl->getBody();
556 
557   // If the decl doesn't have a statement body, it is a native decl.
558   if (!cstBody) {
559     ast::Type declResultType = decl->getResultType();
560     SmallVector<Type> resultTypes;
561     if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
562       for (ast::Type type : tupleType.getElementTypes())
563         resultTypes.push_back(genType(type));
564     } else {
565       resultTypes.push_back(genType(declResultType));
566     }
567 
568     // FIXME: We currently do not have a modeling for the "constant params"
569     // support PDL provides. We should either figure out a modeling for this, or
570     // refactor the support within PDL to be something a bit more reasonable for
571     // what we need as a frontend.
572     Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes,
573                                               decl->getName().getName(), inputs,
574                                               /*params=*/ArrayAttr());
575     return pdlOp->getResults();
576   }
577 
578   // Otherwise, this is a PDLL decl.
579   VariableMapTy::ScopeTy varScope(variables);
580 
581   // Map the inputs of the call to the decl arguments.
582   // Note: This is only valid because we do not support recursion, meaning
583   // we don't need to worry about conflicting mappings here.
584   for (auto it : llvm::zip(inputs, decl->getInputs()))
585     variables.insert(std::get<1>(it), {std::get<0>(it)});
586 
587   // Visit the body of the call as normal.
588   gen(cstBody);
589 
590   // If the decl has no results, there is nothing to do.
591   if (cstBody->getChildren().empty())
592     return SmallVector<Value>();
593   auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
594   if (!returnStmt)
595     return SmallVector<Value>();
596 
597   // Otherwise, grab the results from the return statement.
598   return genExpr(returnStmt->getResultExpr());
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // MLIRGen
603 //===----------------------------------------------------------------------===//
604 
605 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
606     MLIRContext *mlirContext, const ast::Context &context,
607     const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
608   CodeGen codegen(mlirContext, context, sourceMgr);
609   OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
610   if (failed(verify(*mlirModule)))
611     return nullptr;
612   return mlirModule;
613 }
614