1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
10 #include "mlir/AsmParser/AsmParser.h"
11 #include "mlir/Dialect/PDL/IR/PDL.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
17 #include "mlir/Tools/PDLL/AST/Context.h"
18 #include "mlir/Tools/PDLL/AST/Nodes.h"
19 #include "mlir/Tools/PDLL/AST/Types.h"
20 #include "mlir/Tools/PDLL/ODS/Context.h"
21 #include "mlir/Tools/PDLL/ODS/Operation.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25
26 using namespace mlir;
27 using namespace mlir::pdll;
28
29 //===----------------------------------------------------------------------===//
30 // CodeGen
31 //===----------------------------------------------------------------------===//
32
33 namespace {
34 class CodeGen {
35 public:
CodeGen(MLIRContext * mlirContext,const ast::Context & context,const llvm::SourceMgr & sourceMgr)36 CodeGen(MLIRContext *mlirContext, const ast::Context &context,
37 const llvm::SourceMgr &sourceMgr)
38 : builder(mlirContext), odsContext(context.getODSContext()),
39 sourceMgr(sourceMgr) {
40 // Make sure that the PDL dialect is loaded.
41 mlirContext->loadDialect<pdl::PDLDialect>();
42 }
43
44 OwningOpRef<ModuleOp> generate(const ast::Module &module);
45
46 private:
47 /// Generate an MLIR location from the given source location.
48 Location genLoc(llvm::SMLoc loc);
genLoc(llvm::SMRange loc)49 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }
50
51 /// Generate an MLIR type from the given source type.
52 Type genType(ast::Type type);
53
54 /// Generate MLIR for the given AST node.
55 void gen(const ast::Node *node);
56
57 //===--------------------------------------------------------------------===//
58 // Statements
59 //===--------------------------------------------------------------------===//
60
61 void genImpl(const ast::CompoundStmt *stmt);
62 void genImpl(const ast::EraseStmt *stmt);
63 void genImpl(const ast::LetStmt *stmt);
64 void genImpl(const ast::ReplaceStmt *stmt);
65 void genImpl(const ast::RewriteStmt *stmt);
66 void genImpl(const ast::ReturnStmt *stmt);
67
68 //===--------------------------------------------------------------------===//
69 // Decls
70 //===--------------------------------------------------------------------===//
71
72 void genImpl(const ast::UserConstraintDecl *decl);
73 void genImpl(const ast::UserRewriteDecl *decl);
74 void genImpl(const ast::PatternDecl *decl);
75
76 /// Generate the set of MLIR values defined for the given variable decl, and
77 /// apply any attached constraints.
78 SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
79
80 /// Generate the value for a variable that does not have an initializer
81 /// expression, i.e. create the PDL value based on the type/constraints of the
82 /// variable.
83 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
84
85 /// Apply the constraints of the given variable to `values`, which correspond
86 /// to the MLIR values of the variable.
87 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
88
89 //===--------------------------------------------------------------------===//
90 // Expressions
91 //===--------------------------------------------------------------------===//
92
93 Value genSingleExpr(const ast::Expr *expr);
94 SmallVector<Value> genExpr(const ast::Expr *expr);
95 Value genExprImpl(const ast::AttributeExpr *expr);
96 SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
97 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
98 Value genExprImpl(const ast::MemberAccessExpr *expr);
99 Value genExprImpl(const ast::OperationExpr *expr);
100 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
101 Value genExprImpl(const ast::TypeExpr *expr);
102
103 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
104 Location loc, ValueRange inputs);
105 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
106 Location loc, ValueRange inputs);
107 template <typename PDLOpT, typename T>
108 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
109 ValueRange inputs);
110
111 //===--------------------------------------------------------------------===//
112 // Fields
113 //===--------------------------------------------------------------------===//
114
115 /// The MLIR builder used for building the resultant IR.
116 OpBuilder builder;
117
118 /// A map from variable declarations to the MLIR equivalent.
119 using VariableMapTy =
120 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
121 VariableMapTy variables;
122
123 /// A reference to the ODS context.
124 const ods::Context &odsContext;
125
126 /// The source manager of the PDLL ast.
127 const llvm::SourceMgr &sourceMgr;
128 };
129 } // namespace
130
generate(const ast::Module & module)131 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
132 OwningOpRef<ModuleOp> mlirModule =
133 builder.create<ModuleOp>(genLoc(module.getLoc()));
134 builder.setInsertionPointToStart(mlirModule->getBody());
135
136 // Generate code for each of the decls within the module.
137 for (const ast::Decl *decl : module.getChildren())
138 gen(decl);
139
140 return mlirModule;
141 }
142
genLoc(llvm::SMLoc loc)143 Location CodeGen::genLoc(llvm::SMLoc loc) {
144 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
145
146 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
147 // use it here.
148 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
149 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
150 unsigned column =
151 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
152 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
153
154 return FileLineColLoc::get(builder.getContext(),
155 buffer->getBufferIdentifier(), lineNo, column);
156 }
157
genType(ast::Type type)158 Type CodeGen::genType(ast::Type type) {
159 return TypeSwitch<ast::Type, Type>(type)
160 .Case([&](ast::AttributeType astType) -> Type {
161 return builder.getType<pdl::AttributeType>();
162 })
163 .Case([&](ast::OperationType astType) -> Type {
164 return builder.getType<pdl::OperationType>();
165 })
166 .Case([&](ast::TypeType astType) -> Type {
167 return builder.getType<pdl::TypeType>();
168 })
169 .Case([&](ast::ValueType astType) -> Type {
170 return builder.getType<pdl::ValueType>();
171 })
172 .Case([&](ast::RangeType astType) -> Type {
173 return pdl::RangeType::get(genType(astType.getElementType()));
174 });
175 }
176
gen(const ast::Node * node)177 void CodeGen::gen(const ast::Node *node) {
178 TypeSwitch<const ast::Node *>(node)
179 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
180 const ast::ReplaceStmt, const ast::RewriteStmt,
181 const ast::ReturnStmt, const ast::UserConstraintDecl,
182 const ast::UserRewriteDecl, const ast::PatternDecl>(
183 [&](auto derivedNode) { this->genImpl(derivedNode); })
184 .Case([&](const ast::Expr *expr) { genExpr(expr); });
185 }
186
187 //===----------------------------------------------------------------------===//
188 // CodeGen: Statements
189 //===----------------------------------------------------------------------===//
190
genImpl(const ast::CompoundStmt * stmt)191 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
192 VariableMapTy::ScopeTy varScope(variables);
193 for (const ast::Stmt *childStmt : stmt->getChildren())
194 gen(childStmt);
195 }
196
197 /// If the given builder is nested under a PDL PatternOp, build a rewrite
198 /// operation and update the builder to nest under it. This is necessary for
199 /// PDLL operation rewrite statements that are directly nested within a Pattern.
checkAndNestUnderRewriteOp(OpBuilder & builder,Value rootExpr,Location loc)200 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
201 Location loc) {
202 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
203 pdl::RewriteOp rewrite =
204 builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
205 /*externalArgs=*/ValueRange());
206 builder.createBlock(&rewrite.body());
207 }
208 }
209
genImpl(const ast::EraseStmt * stmt)210 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
211 OpBuilder::InsertionGuard insertGuard(builder);
212 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
213 Location loc = genLoc(stmt->getLoc());
214
215 // Make sure we are nested in a RewriteOp.
216 OpBuilder::InsertionGuard guard(builder);
217 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
218 builder.create<pdl::EraseOp>(loc, rootExpr);
219 }
220
genImpl(const ast::LetStmt * stmt)221 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
222
genImpl(const ast::ReplaceStmt * stmt)223 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
224 OpBuilder::InsertionGuard insertGuard(builder);
225 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
226 Location loc = genLoc(stmt->getLoc());
227
228 // Make sure we are nested in a RewriteOp.
229 OpBuilder::InsertionGuard guard(builder);
230 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
231
232 SmallVector<Value> replValues;
233 for (ast::Expr *replExpr : stmt->getReplExprs())
234 replValues.push_back(genSingleExpr(replExpr));
235
236 // Check to see if the statement has a replacement operation, or a range of
237 // replacement values.
238 bool usesReplOperation =
239 replValues.size() == 1 &&
240 replValues.front().getType().isa<pdl::OperationType>();
241 builder.create<pdl::ReplaceOp>(
242 loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
243 usesReplOperation ? ValueRange() : ValueRange(replValues));
244 }
245
genImpl(const ast::RewriteStmt * stmt)246 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
247 OpBuilder::InsertionGuard insertGuard(builder);
248 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
249
250 // Make sure we are nested in a RewriteOp.
251 OpBuilder::InsertionGuard guard(builder);
252 checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
253 gen(stmt->getRewriteBody());
254 }
255
genImpl(const ast::ReturnStmt * stmt)256 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
257 // ReturnStmt generation is handled by the respective constraint or rewrite
258 // parent node.
259 }
260
261 //===----------------------------------------------------------------------===//
262 // CodeGen: Decls
263 //===----------------------------------------------------------------------===//
264
genImpl(const ast::UserConstraintDecl * decl)265 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
266 // All PDLL constraints get inlined when called, and the main native
267 // constraint declarations doesn't require any MLIR to be generated, only uses
268 // of it do.
269 }
270
genImpl(const ast::UserRewriteDecl * decl)271 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
272 // All PDLL rewrites get inlined when called, and the main native
273 // rewrite declarations doesn't require any MLIR to be generated, only uses
274 // of it do.
275 }
276
genImpl(const ast::PatternDecl * decl)277 void CodeGen::genImpl(const ast::PatternDecl *decl) {
278 const ast::Name *name = decl->getName();
279
280 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
281 // here.
282 pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
283 genLoc(decl->getLoc()), decl->getBenefit(),
284 name ? Optional<StringRef>(name->getName()) : Optional<StringRef>());
285
286 OpBuilder::InsertionGuard savedInsertPoint(builder);
287 builder.setInsertionPointToStart(pattern.getBody());
288 gen(decl->getBody());
289 }
290
genVar(const ast::VariableDecl * varDecl)291 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
292 auto it = variables.begin(varDecl);
293 if (it != variables.end())
294 return *it;
295
296 // If the variable has an initial value, use that as the base value.
297 // Otherwise, generate a value using the constraint list.
298 SmallVector<Value> values;
299 if (const ast::Expr *initExpr = varDecl->getInitExpr())
300 values = genExpr(initExpr);
301 else
302 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
303
304 // Apply the constraints of the values of the variable.
305 applyVarConstraints(varDecl, values);
306
307 variables.insert(varDecl, values);
308 return values;
309 }
310
genNonInitializerVar(const ast::VariableDecl * varDecl,Location loc)311 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
312 Location loc) {
313 // A functor used to generate expressions nested
314 auto getTypeConstraint = [&]() -> Value {
315 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
316 Value typeValue =
317 TypeSwitch<const ast::Node *, Value>(constraint.constraint)
318 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
319 ast::ValueRangeConstraintDecl>(
320 [&, this](auto *cst) -> Value {
321 if (auto *typeConstraintExpr = cst->getTypeExpr())
322 return this->genSingleExpr(typeConstraintExpr);
323 return Value();
324 })
325 .Default(Value());
326 if (typeValue)
327 return typeValue;
328 }
329 return Value();
330 };
331
332 // Generate a value based on the type of the variable.
333 ast::Type type = varDecl->getType();
334 Type mlirType = genType(type);
335 if (type.isa<ast::ValueType>())
336 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
337 if (type.isa<ast::TypeType>())
338 return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
339 if (type.isa<ast::AttributeType>())
340 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
341 if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
342 Value operands = builder.create<pdl::OperandsOp>(
343 loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
344 /*type=*/Value());
345 Value results = builder.create<pdl::TypesOp>(
346 loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
347 /*types=*/ArrayAttr());
348 return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
349 llvm::None, ValueRange(), results);
350 }
351
352 if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
353 ast::Type eleTy = rangeTy.getElementType();
354 if (eleTy.isa<ast::ValueType>())
355 return builder.create<pdl::OperandsOp>(loc, mlirType,
356 getTypeConstraint());
357 if (eleTy.isa<ast::TypeType>())
358 return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
359 }
360
361 llvm_unreachable("invalid non-initialized variable type");
362 }
363
applyVarConstraints(const ast::VariableDecl * varDecl,ValueRange values)364 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
365 ValueRange values) {
366 // Generate calls to any user constraints that were attached via the
367 // constraint list.
368 for (const ast::ConstraintRef &ref : varDecl->getConstraints())
369 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
370 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
371 }
372
373 //===----------------------------------------------------------------------===//
374 // CodeGen: Expressions
375 //===----------------------------------------------------------------------===//
376
genSingleExpr(const ast::Expr * expr)377 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
378 return TypeSwitch<const ast::Expr *, Value>(expr)
379 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
380 const ast::OperationExpr, const ast::TypeExpr>(
381 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
382 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
383 [&](auto derivedNode) {
384 SmallVector<Value> results = this->genExprImpl(derivedNode);
385 assert(results.size() == 1 && "expected single expression result");
386 return results[0];
387 });
388 }
389
genExpr(const ast::Expr * expr)390 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
391 return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
392 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
393 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
394 .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
395 return {genSingleExpr(expr)};
396 });
397 }
398
genExprImpl(const ast::AttributeExpr * expr)399 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
400 Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
401 assert(attr && "invalid MLIR attribute data");
402 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
403 }
404
genExprImpl(const ast::CallExpr * expr)405 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
406 Location loc = genLoc(expr->getLoc());
407 SmallVector<Value> arguments;
408 for (const ast::Expr *arg : expr->getArguments())
409 arguments.push_back(genSingleExpr(arg));
410
411 // Resolve the callable expression of this call.
412 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
413 assert(callableExpr && "unhandled CallExpr callable");
414
415 // Generate the PDL based on the type of callable.
416 const ast::Decl *callable = callableExpr->getDecl();
417 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
418 return genConstraintCall(decl, loc, arguments);
419 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
420 return genRewriteCall(decl, loc, arguments);
421 llvm_unreachable("unhandled CallExpr callable");
422 }
423
genExprImpl(const ast::DeclRefExpr * expr)424 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
425 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
426 return genVar(varDecl);
427 llvm_unreachable("unknown decl reference expression");
428 }
429
genExprImpl(const ast::MemberAccessExpr * expr)430 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
431 Location loc = genLoc(expr->getLoc());
432 StringRef name = expr->getMemberName();
433 SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
434 ast::Type parentType = expr->getParentExpr()->getType();
435
436 // Handle operation based member access.
437 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
438 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
439 Type mlirType = genType(expr->getType());
440 if (mlirType.isa<pdl::ValueType>())
441 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
442 builder.getI32IntegerAttr(0));
443 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
444 }
445
446 const ods::Operation *odsOp = opType.getODSOperation();
447 if (!odsOp) {
448 assert(llvm::isDigit(name[0]) &&
449 "unregistered op only allows numeric indexing");
450 unsigned resultIndex;
451 name.getAsInteger(/*Radix=*/10, resultIndex);
452 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
453 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
454 parentExprs[0], index);
455 }
456
457 // Find the result with the member name or by index.
458 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
459 unsigned resultIndex = results.size();
460 if (llvm::isDigit(name[0])) {
461 name.getAsInteger(/*Radix=*/10, resultIndex);
462 } else {
463 auto findFn = [&](const ods::OperandOrResult &result) {
464 return result.getName() == name;
465 };
466 resultIndex = llvm::find_if(results, findFn) - results.begin();
467 }
468 assert(resultIndex < results.size() && "invalid result index");
469
470 // Generate the result access.
471 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
472 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
473 parentExprs[0], index);
474 }
475
476 // Handle tuple based member access.
477 if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
478 auto elementNames = tupleType.getElementNames();
479
480 // The index is either a numeric index, or a name.
481 unsigned index = 0;
482 if (llvm::isDigit(name[0]))
483 name.getAsInteger(/*Radix=*/10, index);
484 else
485 index = llvm::find(elementNames, name) - elementNames.begin();
486
487 assert(index < parentExprs.size() && "invalid result index");
488 return parentExprs[index];
489 }
490
491 llvm_unreachable("unhandled member access expression");
492 }
493
genExprImpl(const ast::OperationExpr * expr)494 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
495 Location loc = genLoc(expr->getLoc());
496 Optional<StringRef> opName = expr->getName();
497
498 // Operands.
499 SmallVector<Value> operands;
500 for (const ast::Expr *operand : expr->getOperands())
501 operands.push_back(genSingleExpr(operand));
502
503 // Attributes.
504 SmallVector<StringRef> attrNames;
505 SmallVector<Value> attrValues;
506 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
507 attrNames.push_back(attr->getName().getName());
508 attrValues.push_back(genSingleExpr(attr->getValue()));
509 }
510
511 // Results.
512 SmallVector<Value> results;
513 for (const ast::Expr *result : expr->getResultTypes())
514 results.push_back(genSingleExpr(result));
515
516 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
517 attrValues, results);
518 }
519
genExprImpl(const ast::TupleExpr * expr)520 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
521 SmallVector<Value> elements;
522 for (const ast::Expr *element : expr->getElements())
523 elements.push_back(genSingleExpr(element));
524 return elements;
525 }
526
genExprImpl(const ast::TypeExpr * expr)527 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
528 Type type = parseType(expr->getValue(), builder.getContext());
529 assert(type && "invalid MLIR type data");
530 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
531 builder.getType<pdl::TypeType>(),
532 TypeAttr::get(type));
533 }
534
535 SmallVector<Value>
genConstraintCall(const ast::UserConstraintDecl * decl,Location loc,ValueRange inputs)536 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
537 ValueRange inputs) {
538 // Apply any constraints defined on the arguments to the input values.
539 for (auto it : llvm::zip(decl->getInputs(), inputs))
540 applyVarConstraints(std::get<0>(it), std::get<1>(it));
541
542 // Generate the constraint call.
543 SmallVector<Value> results =
544 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
545 inputs);
546
547 // Apply any constraints defined on the results of the constraint.
548 for (auto it : llvm::zip(decl->getResults(), results))
549 applyVarConstraints(std::get<0>(it), std::get<1>(it));
550 return results;
551 }
552
genRewriteCall(const ast::UserRewriteDecl * decl,Location loc,ValueRange inputs)553 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
554 Location loc, ValueRange inputs) {
555 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
556 inputs);
557 }
558
559 template <typename PDLOpT, typename T>
genConstraintOrRewriteCall(const T * decl,Location loc,ValueRange inputs)560 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
561 Location loc,
562 ValueRange inputs) {
563 const ast::CompoundStmt *cstBody = decl->getBody();
564
565 // If the decl doesn't have a statement body, it is a native decl.
566 if (!cstBody) {
567 ast::Type declResultType = decl->getResultType();
568 SmallVector<Type> resultTypes;
569 if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
570 for (ast::Type type : tupleType.getElementTypes())
571 resultTypes.push_back(genType(type));
572 } else {
573 resultTypes.push_back(genType(declResultType));
574 }
575 Operation *pdlOp = builder.create<PDLOpT>(
576 loc, resultTypes, decl->getName().getName(), inputs);
577 return pdlOp->getResults();
578 }
579
580 // Otherwise, this is a PDLL decl.
581 VariableMapTy::ScopeTy varScope(variables);
582
583 // Map the inputs of the call to the decl arguments.
584 // Note: This is only valid because we do not support recursion, meaning
585 // we don't need to worry about conflicting mappings here.
586 for (auto it : llvm::zip(inputs, decl->getInputs()))
587 variables.insert(std::get<1>(it), {std::get<0>(it)});
588
589 // Visit the body of the call as normal.
590 gen(cstBody);
591
592 // If the decl has no results, there is nothing to do.
593 if (cstBody->getChildren().empty())
594 return SmallVector<Value>();
595 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
596 if (!returnStmt)
597 return SmallVector<Value>();
598
599 // Otherwise, grab the results from the return statement.
600 return genExpr(returnStmt->getResultExpr());
601 }
602
603 //===----------------------------------------------------------------------===//
604 // MLIRGen
605 //===----------------------------------------------------------------------===//
606
codegenPDLLToMLIR(MLIRContext * mlirContext,const ast::Context & context,const llvm::SourceMgr & sourceMgr,const ast::Module & module)607 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
608 MLIRContext *mlirContext, const ast::Context &context,
609 const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
610 CodeGen codegen(mlirContext, context, sourceMgr);
611 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
612 if (failed(verify(*mlirModule)))
613 return nullptr;
614 return mlirModule;
615 }
616