1 //===- MLIRGen.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" 10 #include "mlir/AsmParser/AsmParser.h" 11 #include "mlir/Dialect/PDL/IR/PDL.h" 12 #include "mlir/Dialect/PDL/IR/PDLOps.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Verifier.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Nodes.h" 19 #include "mlir/Tools/PDLL/AST/Types.h" 20 #include "mlir/Tools/PDLL/ODS/Context.h" 21 #include "mlir/Tools/PDLL/ODS/Operation.h" 22 #include "llvm/ADT/ScopedHashTable.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 26 using namespace mlir; 27 using namespace mlir::pdll; 28 29 //===----------------------------------------------------------------------===// 30 // CodeGen 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 class CodeGen { 35 public: 36 CodeGen(MLIRContext *mlirContext, const ast::Context &context, 37 const llvm::SourceMgr &sourceMgr) 38 : builder(mlirContext), odsContext(context.getODSContext()), 39 sourceMgr(sourceMgr) { 40 // Make sure that the PDL dialect is loaded. 41 mlirContext->loadDialect<pdl::PDLDialect>(); 42 } 43 44 OwningOpRef<ModuleOp> generate(const ast::Module &module); 45 46 private: 47 /// Generate an MLIR location from the given source location. 48 Location genLoc(llvm::SMLoc loc); 49 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); } 50 51 /// Generate an MLIR type from the given source type. 52 Type genType(ast::Type type); 53 54 /// Generate MLIR for the given AST node. 55 void gen(const ast::Node *node); 56 57 //===--------------------------------------------------------------------===// 58 // Statements 59 //===--------------------------------------------------------------------===// 60 61 void genImpl(const ast::CompoundStmt *stmt); 62 void genImpl(const ast::EraseStmt *stmt); 63 void genImpl(const ast::LetStmt *stmt); 64 void genImpl(const ast::ReplaceStmt *stmt); 65 void genImpl(const ast::RewriteStmt *stmt); 66 void genImpl(const ast::ReturnStmt *stmt); 67 68 //===--------------------------------------------------------------------===// 69 // Decls 70 //===--------------------------------------------------------------------===// 71 72 void genImpl(const ast::UserConstraintDecl *decl); 73 void genImpl(const ast::UserRewriteDecl *decl); 74 void genImpl(const ast::PatternDecl *decl); 75 76 /// Generate the set of MLIR values defined for the given variable decl, and 77 /// apply any attached constraints. 78 SmallVector<Value> genVar(const ast::VariableDecl *varDecl); 79 80 /// Generate the value for a variable that does not have an initializer 81 /// expression, i.e. create the PDL value based on the type/constraints of the 82 /// variable. 83 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc); 84 85 /// Apply the constraints of the given variable to `values`, which correspond 86 /// to the MLIR values of the variable. 87 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values); 88 89 //===--------------------------------------------------------------------===// 90 // Expressions 91 //===--------------------------------------------------------------------===// 92 93 Value genSingleExpr(const ast::Expr *expr); 94 SmallVector<Value> genExpr(const ast::Expr *expr); 95 Value genExprImpl(const ast::AttributeExpr *expr); 96 SmallVector<Value> genExprImpl(const ast::CallExpr *expr); 97 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr); 98 Value genExprImpl(const ast::MemberAccessExpr *expr); 99 Value genExprImpl(const ast::OperationExpr *expr); 100 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr); 101 Value genExprImpl(const ast::TypeExpr *expr); 102 103 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl, 104 Location loc, ValueRange inputs); 105 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl, 106 Location loc, ValueRange inputs); 107 template <typename PDLOpT, typename T> 108 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc, 109 ValueRange inputs); 110 111 //===--------------------------------------------------------------------===// 112 // Fields 113 //===--------------------------------------------------------------------===// 114 115 /// The MLIR builder used for building the resultant IR. 116 OpBuilder builder; 117 118 /// A map from variable declarations to the MLIR equivalent. 119 using VariableMapTy = 120 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>; 121 VariableMapTy variables; 122 123 /// A reference to the ODS context. 124 const ods::Context &odsContext; 125 126 /// The source manager of the PDLL ast. 127 const llvm::SourceMgr &sourceMgr; 128 }; 129 } // namespace 130 131 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) { 132 OwningOpRef<ModuleOp> mlirModule = 133 builder.create<ModuleOp>(genLoc(module.getLoc())); 134 builder.setInsertionPointToStart(mlirModule->getBody()); 135 136 // Generate code for each of the decls within the module. 137 for (const ast::Decl *decl : module.getChildren()) 138 gen(decl); 139 140 return mlirModule; 141 } 142 143 Location CodeGen::genLoc(llvm::SMLoc loc) { 144 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc); 145 146 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can 147 // use it here. 148 auto &bufferInfo = sourceMgr.getBufferInfo(fileID); 149 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer()); 150 unsigned column = 151 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1; 152 auto *buffer = sourceMgr.getMemoryBuffer(fileID); 153 154 return FileLineColLoc::get(builder.getContext(), 155 buffer->getBufferIdentifier(), lineNo, column); 156 } 157 158 Type CodeGen::genType(ast::Type type) { 159 return TypeSwitch<ast::Type, Type>(type) 160 .Case([&](ast::AttributeType astType) -> Type { 161 return builder.getType<pdl::AttributeType>(); 162 }) 163 .Case([&](ast::OperationType astType) -> Type { 164 return builder.getType<pdl::OperationType>(); 165 }) 166 .Case([&](ast::TypeType astType) -> Type { 167 return builder.getType<pdl::TypeType>(); 168 }) 169 .Case([&](ast::ValueType astType) -> Type { 170 return builder.getType<pdl::ValueType>(); 171 }) 172 .Case([&](ast::RangeType astType) -> Type { 173 return pdl::RangeType::get(genType(astType.getElementType())); 174 }); 175 } 176 177 void CodeGen::gen(const ast::Node *node) { 178 TypeSwitch<const ast::Node *>(node) 179 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt, 180 const ast::ReplaceStmt, const ast::RewriteStmt, 181 const ast::ReturnStmt, const ast::UserConstraintDecl, 182 const ast::UserRewriteDecl, const ast::PatternDecl>( 183 [&](auto derivedNode) { this->genImpl(derivedNode); }) 184 .Case([&](const ast::Expr *expr) { genExpr(expr); }); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // CodeGen: Statements 189 //===----------------------------------------------------------------------===// 190 191 void CodeGen::genImpl(const ast::CompoundStmt *stmt) { 192 VariableMapTy::ScopeTy varScope(variables); 193 for (const ast::Stmt *childStmt : stmt->getChildren()) 194 gen(childStmt); 195 } 196 197 /// If the given builder is nested under a PDL PatternOp, build a rewrite 198 /// operation and update the builder to nest under it. This is necessary for 199 /// PDLL operation rewrite statements that are directly nested within a Pattern. 200 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, 201 Location loc) { 202 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) { 203 pdl::RewriteOp rewrite = 204 builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(), 205 /*externalArgs=*/ValueRange()); 206 builder.createBlock(&rewrite.body()); 207 } 208 } 209 210 void CodeGen::genImpl(const ast::EraseStmt *stmt) { 211 OpBuilder::InsertionGuard insertGuard(builder); 212 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 213 Location loc = genLoc(stmt->getLoc()); 214 215 // Make sure we are nested in a RewriteOp. 216 OpBuilder::InsertionGuard guard(builder); 217 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 218 builder.create<pdl::EraseOp>(loc, rootExpr); 219 } 220 221 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); } 222 223 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { 224 OpBuilder::InsertionGuard insertGuard(builder); 225 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 226 Location loc = genLoc(stmt->getLoc()); 227 228 // Make sure we are nested in a RewriteOp. 229 OpBuilder::InsertionGuard guard(builder); 230 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 231 232 SmallVector<Value> replValues; 233 for (ast::Expr *replExpr : stmt->getReplExprs()) 234 replValues.push_back(genSingleExpr(replExpr)); 235 236 // Check to see if the statement has a replacement operation, or a range of 237 // replacement values. 238 bool usesReplOperation = 239 replValues.size() == 1 && 240 replValues.front().getType().isa<pdl::OperationType>(); 241 builder.create<pdl::ReplaceOp>( 242 loc, rootExpr, usesReplOperation ? replValues[0] : Value(), 243 usesReplOperation ? ValueRange() : ValueRange(replValues)); 244 } 245 246 void CodeGen::genImpl(const ast::RewriteStmt *stmt) { 247 OpBuilder::InsertionGuard insertGuard(builder); 248 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 249 250 // Make sure we are nested in a RewriteOp. 251 OpBuilder::InsertionGuard guard(builder); 252 checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc())); 253 gen(stmt->getRewriteBody()); 254 } 255 256 void CodeGen::genImpl(const ast::ReturnStmt *stmt) { 257 // ReturnStmt generation is handled by the respective constraint or rewrite 258 // parent node. 259 } 260 261 //===----------------------------------------------------------------------===// 262 // CodeGen: Decls 263 //===----------------------------------------------------------------------===// 264 265 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { 266 // All PDLL constraints get inlined when called, and the main native 267 // constraint declarations doesn't require any MLIR to be generated, only uses 268 // of it do. 269 } 270 271 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { 272 // All PDLL rewrites get inlined when called, and the main native 273 // rewrite declarations doesn't require any MLIR to be generated, only uses 274 // of it do. 275 } 276 277 void CodeGen::genImpl(const ast::PatternDecl *decl) { 278 const ast::Name *name = decl->getName(); 279 280 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it 281 // here. 282 pdl::PatternOp pattern = builder.create<pdl::PatternOp>( 283 genLoc(decl->getLoc()), decl->getBenefit(), 284 name ? Optional<StringRef>(name->getName()) : Optional<StringRef>()); 285 286 OpBuilder::InsertionGuard savedInsertPoint(builder); 287 builder.setInsertionPointToStart(pattern.getBody()); 288 gen(decl->getBody()); 289 } 290 291 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) { 292 auto it = variables.begin(varDecl); 293 if (it != variables.end()) 294 return *it; 295 296 // If the variable has an initial value, use that as the base value. 297 // Otherwise, generate a value using the constraint list. 298 SmallVector<Value> values; 299 if (const ast::Expr *initExpr = varDecl->getInitExpr()) 300 values = genExpr(initExpr); 301 else 302 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc()))); 303 304 // Apply the constraints of the values of the variable. 305 applyVarConstraints(varDecl, values); 306 307 variables.insert(varDecl, values); 308 return values; 309 } 310 311 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, 312 Location loc) { 313 // A functor used to generate expressions nested 314 auto getTypeConstraint = [&]() -> Value { 315 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { 316 Value typeValue = 317 TypeSwitch<const ast::Node *, Value>(constraint.constraint) 318 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 319 ast::ValueRangeConstraintDecl>( 320 [&, this](auto *cst) -> Value { 321 if (auto *typeConstraintExpr = cst->getTypeExpr()) 322 return this->genSingleExpr(typeConstraintExpr); 323 return Value(); 324 }) 325 .Default(Value()); 326 if (typeValue) 327 return typeValue; 328 } 329 return Value(); 330 }; 331 332 // Generate a value based on the type of the variable. 333 ast::Type type = varDecl->getType(); 334 Type mlirType = genType(type); 335 if (type.isa<ast::ValueType>()) 336 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint()); 337 if (type.isa<ast::TypeType>()) 338 return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr()); 339 if (type.isa<ast::AttributeType>()) 340 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint()); 341 if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) { 342 Value operands = builder.create<pdl::OperandsOp>( 343 loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()), 344 /*type=*/Value()); 345 Value results = builder.create<pdl::TypesOp>( 346 loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()), 347 /*types=*/ArrayAttr()); 348 return builder.create<pdl::OperationOp>(loc, opType.getName(), operands, 349 llvm::None, ValueRange(), results); 350 } 351 352 if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) { 353 ast::Type eleTy = rangeTy.getElementType(); 354 if (eleTy.isa<ast::ValueType>()) 355 return builder.create<pdl::OperandsOp>(loc, mlirType, 356 getTypeConstraint()); 357 if (eleTy.isa<ast::TypeType>()) 358 return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr()); 359 } 360 361 llvm_unreachable("invalid non-initialized variable type"); 362 } 363 364 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, 365 ValueRange values) { 366 // Generate calls to any user constraints that were attached via the 367 // constraint list. 368 for (const ast::ConstraintRef &ref : varDecl->getConstraints()) 369 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint)) 370 genConstraintCall(userCst, genLoc(ref.referenceLoc), values); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // CodeGen: Expressions 375 //===----------------------------------------------------------------------===// 376 377 Value CodeGen::genSingleExpr(const ast::Expr *expr) { 378 return TypeSwitch<const ast::Expr *, Value>(expr) 379 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr, 380 const ast::OperationExpr, const ast::TypeExpr>( 381 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 382 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 383 [&](auto derivedNode) { 384 SmallVector<Value> results = this->genExprImpl(derivedNode); 385 assert(results.size() == 1 && "expected single expression result"); 386 return results[0]; 387 }); 388 } 389 390 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) { 391 return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr) 392 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 393 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 394 .Default([&](const ast::Expr *expr) -> SmallVector<Value> { 395 return {genSingleExpr(expr)}; 396 }); 397 } 398 399 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { 400 Attribute attr = parseAttribute(expr->getValue(), builder.getContext()); 401 assert(attr && "invalid MLIR attribute data"); 402 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr); 403 } 404 405 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) { 406 Location loc = genLoc(expr->getLoc()); 407 SmallVector<Value> arguments; 408 for (const ast::Expr *arg : expr->getArguments()) 409 arguments.push_back(genSingleExpr(arg)); 410 411 // Resolve the callable expression of this call. 412 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr()); 413 assert(callableExpr && "unhandled CallExpr callable"); 414 415 // Generate the PDL based on the type of callable. 416 const ast::Decl *callable = callableExpr->getDecl(); 417 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable)) 418 return genConstraintCall(decl, loc, arguments); 419 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable)) 420 return genRewriteCall(decl, loc, arguments); 421 llvm_unreachable("unhandled CallExpr callable"); 422 } 423 424 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { 425 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl())) 426 return genVar(varDecl); 427 llvm_unreachable("unknown decl reference expression"); 428 } 429 430 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { 431 Location loc = genLoc(expr->getLoc()); 432 StringRef name = expr->getMemberName(); 433 SmallVector<Value> parentExprs = genExpr(expr->getParentExpr()); 434 ast::Type parentType = expr->getParentExpr()->getType(); 435 436 // Handle operation based member access. 437 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 438 if (isa<ast::AllResultsMemberAccessExpr>(expr)) { 439 Type mlirType = genType(expr->getType()); 440 if (mlirType.isa<pdl::ValueType>()) 441 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0], 442 builder.getI32IntegerAttr(0)); 443 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); 444 } 445 446 const ods::Operation *odsOp = opType.getODSOperation(); 447 if (!odsOp) { 448 assert(llvm::isDigit(name[0]) && 449 "unregistered op only allows numeric indexing"); 450 unsigned resultIndex; 451 name.getAsInteger(/*Radix=*/10, resultIndex); 452 IntegerAttr index = builder.getI32IntegerAttr(resultIndex); 453 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()), 454 parentExprs[0], index); 455 } 456 457 // Find the result with the member name or by index. 458 ArrayRef<ods::OperandOrResult> results = odsOp->getResults(); 459 unsigned resultIndex = results.size(); 460 if (llvm::isDigit(name[0])) { 461 name.getAsInteger(/*Radix=*/10, resultIndex); 462 } else { 463 auto findFn = [&](const ods::OperandOrResult &result) { 464 return result.getName() == name; 465 }; 466 resultIndex = llvm::find_if(results, findFn) - results.begin(); 467 } 468 assert(resultIndex < results.size() && "invalid result index"); 469 470 // Generate the result access. 471 IntegerAttr index = builder.getI32IntegerAttr(resultIndex); 472 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()), 473 parentExprs[0], index); 474 } 475 476 // Handle tuple based member access. 477 if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 478 auto elementNames = tupleType.getElementNames(); 479 480 // The index is either a numeric index, or a name. 481 unsigned index = 0; 482 if (llvm::isDigit(name[0])) 483 name.getAsInteger(/*Radix=*/10, index); 484 else 485 index = llvm::find(elementNames, name) - elementNames.begin(); 486 487 assert(index < parentExprs.size() && "invalid result index"); 488 return parentExprs[index]; 489 } 490 491 llvm_unreachable("unhandled member access expression"); 492 } 493 494 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { 495 Location loc = genLoc(expr->getLoc()); 496 Optional<StringRef> opName = expr->getName(); 497 498 // Operands. 499 SmallVector<Value> operands; 500 for (const ast::Expr *operand : expr->getOperands()) 501 operands.push_back(genSingleExpr(operand)); 502 503 // Attributes. 504 SmallVector<StringRef> attrNames; 505 SmallVector<Value> attrValues; 506 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { 507 attrNames.push_back(attr->getName().getName()); 508 attrValues.push_back(genSingleExpr(attr->getValue())); 509 } 510 511 // Results. 512 SmallVector<Value> results; 513 for (const ast::Expr *result : expr->getResultTypes()) 514 results.push_back(genSingleExpr(result)); 515 516 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames, 517 attrValues, results); 518 } 519 520 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) { 521 SmallVector<Value> elements; 522 for (const ast::Expr *element : expr->getElements()) 523 elements.push_back(genSingleExpr(element)); 524 return elements; 525 } 526 527 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { 528 Type type = parseType(expr->getValue(), builder.getContext()); 529 assert(type && "invalid MLIR type data"); 530 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()), 531 builder.getType<pdl::TypeType>(), 532 TypeAttr::get(type)); 533 } 534 535 SmallVector<Value> 536 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, 537 ValueRange inputs) { 538 // Apply any constraints defined on the arguments to the input values. 539 for (auto it : llvm::zip(decl->getInputs(), inputs)) 540 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 541 542 // Generate the constraint call. 543 SmallVector<Value> results = 544 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc, 545 inputs); 546 547 // Apply any constraints defined on the results of the constraint. 548 for (auto it : llvm::zip(decl->getResults(), results)) 549 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 550 return results; 551 } 552 553 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, 554 Location loc, ValueRange inputs) { 555 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc, 556 inputs); 557 } 558 559 template <typename PDLOpT, typename T> 560 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl, 561 Location loc, 562 ValueRange inputs) { 563 const ast::CompoundStmt *cstBody = decl->getBody(); 564 565 // If the decl doesn't have a statement body, it is a native decl. 566 if (!cstBody) { 567 ast::Type declResultType = decl->getResultType(); 568 SmallVector<Type> resultTypes; 569 if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) { 570 for (ast::Type type : tupleType.getElementTypes()) 571 resultTypes.push_back(genType(type)); 572 } else { 573 resultTypes.push_back(genType(declResultType)); 574 } 575 Operation *pdlOp = builder.create<PDLOpT>( 576 loc, resultTypes, decl->getName().getName(), inputs); 577 return pdlOp->getResults(); 578 } 579 580 // Otherwise, this is a PDLL decl. 581 VariableMapTy::ScopeTy varScope(variables); 582 583 // Map the inputs of the call to the decl arguments. 584 // Note: This is only valid because we do not support recursion, meaning 585 // we don't need to worry about conflicting mappings here. 586 for (auto it : llvm::zip(inputs, decl->getInputs())) 587 variables.insert(std::get<1>(it), {std::get<0>(it)}); 588 589 // Visit the body of the call as normal. 590 gen(cstBody); 591 592 // If the decl has no results, there is nothing to do. 593 if (cstBody->getChildren().empty()) 594 return SmallVector<Value>(); 595 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back()); 596 if (!returnStmt) 597 return SmallVector<Value>(); 598 599 // Otherwise, grab the results from the return statement. 600 return genExpr(returnStmt->getResultExpr()); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // MLIRGen 605 //===----------------------------------------------------------------------===// 606 607 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR( 608 MLIRContext *mlirContext, const ast::Context &context, 609 const llvm::SourceMgr &sourceMgr, const ast::Module &module) { 610 CodeGen codegen(mlirContext, context, sourceMgr); 611 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module); 612 if (failed(verify(*mlirModule))) 613 return nullptr; 614 return mlirModule; 615 } 616