1 //===- MLIRGen.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" 10 #include "mlir/Dialect/PDL/IR/PDL.h" 11 #include "mlir/Dialect/PDL/IR/PDLOps.h" 12 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Verifier.h" 16 #include "mlir/Parser.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Nodes.h" 19 #include "mlir/Tools/PDLL/AST/Types.h" 20 #include "llvm/ADT/ScopedHashTable.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 24 using namespace mlir; 25 using namespace mlir::pdll; 26 27 //===----------------------------------------------------------------------===// 28 // CodeGen 29 //===----------------------------------------------------------------------===// 30 31 namespace { 32 class CodeGen { 33 public: 34 CodeGen(MLIRContext *mlirContext, const ast::Context &context, 35 const llvm::SourceMgr &sourceMgr) 36 : builder(mlirContext), sourceMgr(sourceMgr) { 37 // Make sure that the PDL dialect is loaded. 38 mlirContext->loadDialect<pdl::PDLDialect>(); 39 } 40 41 OwningOpRef<ModuleOp> generate(const ast::Module &module); 42 43 private: 44 /// Generate an MLIR location from the given source location. 45 Location genLoc(llvm::SMLoc loc); 46 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); } 47 48 /// Generate an MLIR type from the given source type. 49 Type genType(ast::Type type); 50 51 /// Generate MLIR for the given AST node. 52 void gen(const ast::Node *node); 53 54 //===--------------------------------------------------------------------===// 55 // Statements 56 //===--------------------------------------------------------------------===// 57 58 void genImpl(const ast::CompoundStmt *stmt); 59 void genImpl(const ast::EraseStmt *stmt); 60 void genImpl(const ast::LetStmt *stmt); 61 void genImpl(const ast::ReplaceStmt *stmt); 62 void genImpl(const ast::RewriteStmt *stmt); 63 void genImpl(const ast::ReturnStmt *stmt); 64 65 //===--------------------------------------------------------------------===// 66 // Decls 67 //===--------------------------------------------------------------------===// 68 69 void genImpl(const ast::UserConstraintDecl *decl); 70 void genImpl(const ast::UserRewriteDecl *decl); 71 void genImpl(const ast::PatternDecl *decl); 72 73 /// Generate the set of MLIR values defined for the given variable decl, and 74 /// apply any attached constraints. 75 SmallVector<Value> genVar(const ast::VariableDecl *varDecl); 76 77 /// Generate the value for a variable that does not have an initializer 78 /// expression, i.e. create the PDL value based on the type/constraints of the 79 /// variable. 80 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc); 81 82 /// Apply the constraints of the given variable to `values`, which correspond 83 /// to the MLIR values of the variable. 84 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values); 85 86 //===--------------------------------------------------------------------===// 87 // Expressions 88 //===--------------------------------------------------------------------===// 89 90 Value genSingleExpr(const ast::Expr *expr); 91 SmallVector<Value> genExpr(const ast::Expr *expr); 92 Value genExprImpl(const ast::AttributeExpr *expr); 93 SmallVector<Value> genExprImpl(const ast::CallExpr *expr); 94 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr); 95 Value genExprImpl(const ast::MemberAccessExpr *expr); 96 Value genExprImpl(const ast::OperationExpr *expr); 97 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr); 98 Value genExprImpl(const ast::TypeExpr *expr); 99 100 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl, 101 Location loc, ValueRange inputs); 102 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl, 103 Location loc, ValueRange inputs); 104 template <typename PDLOpT, typename T> 105 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc, 106 ValueRange inputs); 107 108 //===--------------------------------------------------------------------===// 109 // Fields 110 //===--------------------------------------------------------------------===// 111 112 /// The MLIR builder used for building the resultant IR. 113 OpBuilder builder; 114 115 /// A map from variable declarations to the MLIR equivalent. 116 using VariableMapTy = 117 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>; 118 VariableMapTy variables; 119 120 /// The source manager of the PDLL ast. 121 const llvm::SourceMgr &sourceMgr; 122 }; 123 } // namespace 124 125 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) { 126 OwningOpRef<ModuleOp> mlirModule = 127 builder.create<ModuleOp>(genLoc(module.getLoc())); 128 builder.setInsertionPointToStart(mlirModule->getBody()); 129 130 // Generate code for each of the decls within the module. 131 for (const ast::Decl *decl : module.getChildren()) 132 gen(decl); 133 134 return mlirModule; 135 } 136 137 Location CodeGen::genLoc(llvm::SMLoc loc) { 138 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc); 139 140 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can 141 // use it here. 142 auto &bufferInfo = sourceMgr.getBufferInfo(fileID); 143 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer()); 144 unsigned column = 145 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1; 146 auto *buffer = sourceMgr.getMemoryBuffer(fileID); 147 148 return FileLineColLoc::get(builder.getContext(), 149 buffer->getBufferIdentifier(), lineNo, column); 150 } 151 152 Type CodeGen::genType(ast::Type type) { 153 return TypeSwitch<ast::Type, Type>(type) 154 .Case([&](ast::AttributeType astType) -> Type { 155 return builder.getType<pdl::AttributeType>(); 156 }) 157 .Case([&](ast::OperationType astType) -> Type { 158 return builder.getType<pdl::OperationType>(); 159 }) 160 .Case([&](ast::TypeType astType) -> Type { 161 return builder.getType<pdl::TypeType>(); 162 }) 163 .Case([&](ast::ValueType astType) -> Type { 164 return builder.getType<pdl::ValueType>(); 165 }) 166 .Case([&](ast::RangeType astType) -> Type { 167 return pdl::RangeType::get(genType(astType.getElementType())); 168 }); 169 } 170 171 void CodeGen::gen(const ast::Node *node) { 172 TypeSwitch<const ast::Node *>(node) 173 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt, 174 const ast::ReplaceStmt, const ast::RewriteStmt, 175 const ast::ReturnStmt, const ast::UserConstraintDecl, 176 const ast::UserRewriteDecl, const ast::PatternDecl>( 177 [&](auto derivedNode) { this->genImpl(derivedNode); }) 178 .Case([&](const ast::Expr *expr) { genExpr(expr); }); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // CodeGen: Statements 183 //===----------------------------------------------------------------------===// 184 185 void CodeGen::genImpl(const ast::CompoundStmt *stmt) { 186 VariableMapTy::ScopeTy varScope(variables); 187 for (const ast::Stmt *childStmt : stmt->getChildren()) 188 gen(childStmt); 189 } 190 191 /// If the given builder is nested under a PDL PatternOp, build a rewrite 192 /// operation and update the builder to nest under it. This is necessary for 193 /// PDLL operation rewrite statements that are directly nested within a Pattern. 194 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, 195 Location loc) { 196 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) { 197 pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>( 198 loc, rootExpr, /*name=*/StringAttr(), 199 /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr()); 200 builder.createBlock(&rewrite.body()); 201 } 202 } 203 204 void CodeGen::genImpl(const ast::EraseStmt *stmt) { 205 OpBuilder::InsertionGuard insertGuard(builder); 206 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 207 Location loc = genLoc(stmt->getLoc()); 208 209 // Make sure we are nested in a RewriteOp. 210 OpBuilder::InsertionGuard guard(builder); 211 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 212 builder.create<pdl::EraseOp>(loc, rootExpr); 213 } 214 215 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); } 216 217 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { 218 OpBuilder::InsertionGuard insertGuard(builder); 219 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 220 Location loc = genLoc(stmt->getLoc()); 221 222 // Make sure we are nested in a RewriteOp. 223 OpBuilder::InsertionGuard guard(builder); 224 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 225 226 SmallVector<Value> replValues; 227 for (ast::Expr *replExpr : stmt->getReplExprs()) 228 replValues.push_back(genSingleExpr(replExpr)); 229 230 // Check to see if the statement has a replacement operation, or a range of 231 // replacement values. 232 bool usesReplOperation = 233 replValues.size() == 1 && 234 replValues.front().getType().isa<pdl::OperationType>(); 235 builder.create<pdl::ReplaceOp>( 236 loc, rootExpr, usesReplOperation ? replValues[0] : Value(), 237 usesReplOperation ? ValueRange() : ValueRange(replValues)); 238 } 239 240 void CodeGen::genImpl(const ast::RewriteStmt *stmt) { 241 OpBuilder::InsertionGuard insertGuard(builder); 242 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 243 244 // Make sure we are nested in a RewriteOp. 245 OpBuilder::InsertionGuard guard(builder); 246 checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc())); 247 gen(stmt->getRewriteBody()); 248 } 249 250 void CodeGen::genImpl(const ast::ReturnStmt *stmt) { 251 // ReturnStmt generation is handled by the respective constraint or rewrite 252 // parent node. 253 } 254 255 //===----------------------------------------------------------------------===// 256 // CodeGen: Decls 257 //===----------------------------------------------------------------------===// 258 259 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { 260 // All PDLL constraints get inlined when called, and the main native 261 // constraint declarations doesn't require any MLIR to be generated, only uses 262 // of it do. 263 } 264 265 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { 266 // All PDLL rewrites get inlined when called, and the main native 267 // rewrite declarations doesn't require any MLIR to be generated, only uses 268 // of it do. 269 } 270 271 void CodeGen::genImpl(const ast::PatternDecl *decl) { 272 const ast::Name *name = decl->getName(); 273 274 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it 275 // here. 276 pdl::PatternOp pattern = builder.create<pdl::PatternOp>( 277 genLoc(decl->getLoc()), decl->getBenefit(), 278 name ? Optional<StringRef>(name->getName()) : Optional<StringRef>()); 279 280 OpBuilder::InsertionGuard savedInsertPoint(builder); 281 builder.setInsertionPointToStart(pattern.getBody()); 282 gen(decl->getBody()); 283 } 284 285 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) { 286 auto it = variables.begin(varDecl); 287 if (it != variables.end()) 288 return *it; 289 290 // If the variable has an initial value, use that as the base value. 291 // Otherwise, generate a value using the constraint list. 292 SmallVector<Value> values; 293 if (const ast::Expr *initExpr = varDecl->getInitExpr()) 294 values = genExpr(initExpr); 295 else 296 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc()))); 297 298 // Apply the constraints of the values of the variable. 299 applyVarConstraints(varDecl, values); 300 301 variables.insert(varDecl, values); 302 return values; 303 } 304 305 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, 306 Location loc) { 307 // A functor used to generate expressions nested 308 auto getTypeConstraint = [&]() -> Value { 309 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { 310 Value typeValue = 311 TypeSwitch<const ast::Node *, Value>(constraint.constraint) 312 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 313 ast::ValueRangeConstraintDecl>([&, this](auto *cst) -> Value { 314 if (auto *typeConstraintExpr = cst->getTypeExpr()) 315 return genSingleExpr(typeConstraintExpr); 316 return Value(); 317 }) 318 .Default(Value()); 319 if (typeValue) 320 return typeValue; 321 } 322 return Value(); 323 }; 324 325 // Generate a value based on the type of the variable. 326 ast::Type type = varDecl->getType(); 327 Type mlirType = genType(type); 328 if (type.isa<ast::ValueType>()) 329 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint()); 330 if (type.isa<ast::TypeType>()) 331 return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr()); 332 if (type.isa<ast::AttributeType>()) 333 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint()); 334 if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) { 335 Value operands = builder.create<pdl::OperandsOp>( 336 loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()), 337 /*type=*/Value()); 338 Value results = builder.create<pdl::TypesOp>( 339 loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()), 340 /*types=*/ArrayAttr()); 341 return builder.create<pdl::OperationOp>(loc, opType.getName(), operands, 342 llvm::None, ValueRange(), results); 343 } 344 345 if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) { 346 ast::Type eleTy = rangeTy.getElementType(); 347 if (eleTy.isa<ast::ValueType>()) 348 return builder.create<pdl::OperandsOp>(loc, mlirType, 349 getTypeConstraint()); 350 if (eleTy.isa<ast::TypeType>()) 351 return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr()); 352 } 353 354 llvm_unreachable("invalid non-initialized variable type"); 355 } 356 357 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, 358 ValueRange values) { 359 // Generate calls to any user constraints that were attached via the 360 // constraint list. 361 for (const ast::ConstraintRef &ref : varDecl->getConstraints()) 362 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint)) 363 genConstraintCall(userCst, genLoc(ref.referenceLoc), values); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // CodeGen: Expressions 368 //===----------------------------------------------------------------------===// 369 370 Value CodeGen::genSingleExpr(const ast::Expr *expr) { 371 return TypeSwitch<const ast::Expr *, Value>(expr) 372 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr, 373 const ast::OperationExpr, const ast::TypeExpr>( 374 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 375 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 376 [&](auto derivedNode) { 377 SmallVector<Value> results = this->genExprImpl(derivedNode); 378 assert(results.size() == 1 && "expected single expression result"); 379 return results[0]; 380 }); 381 } 382 383 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) { 384 return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr) 385 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 386 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 387 .Default([&](const ast::Expr *expr) -> SmallVector<Value> { 388 return {genSingleExpr(expr)}; 389 }); 390 } 391 392 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { 393 Attribute attr = parseAttribute(expr->getValue(), builder.getContext()); 394 assert(attr && "invalid MLIR attribute data"); 395 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr); 396 } 397 398 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) { 399 Location loc = genLoc(expr->getLoc()); 400 SmallVector<Value> arguments; 401 for (const ast::Expr *arg : expr->getArguments()) 402 arguments.push_back(genSingleExpr(arg)); 403 404 // Resolve the callable expression of this call. 405 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr()); 406 assert(callableExpr && "unhandled CallExpr callable"); 407 408 // Generate the PDL based on the type of callable. 409 const ast::Decl *callable = callableExpr->getDecl(); 410 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable)) 411 return genConstraintCall(decl, loc, arguments); 412 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable)) 413 return genRewriteCall(decl, loc, arguments); 414 llvm_unreachable("unhandled CallExpr callable"); 415 } 416 417 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { 418 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl())) 419 return genVar(varDecl); 420 llvm_unreachable("unknown decl reference expression"); 421 } 422 423 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { 424 Location loc = genLoc(expr->getLoc()); 425 StringRef name = expr->getMemberName(); 426 SmallVector<Value> parentExprs = genExpr(expr->getParentExpr()); 427 ast::Type parentType = expr->getParentExpr()->getType(); 428 429 // Handle operation based member access. 430 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 431 if (isa<ast::AllResultsMemberAccessExpr>(expr)) { 432 Type mlirType = genType(expr->getType()); 433 if (mlirType.isa<pdl::ValueType>()) 434 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0], 435 builder.getI32IntegerAttr(0)); 436 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); 437 } 438 llvm_unreachable("unhandled operation member access expression"); 439 } 440 441 // Handle tuple based member access. 442 if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 443 auto elementNames = tupleType.getElementNames(); 444 445 // The index is either a numeric index, or a name. 446 unsigned index = 0; 447 if (llvm::isDigit(name[0])) 448 name.getAsInteger(/*Radix=*/10, index); 449 else 450 index = llvm::find(elementNames, name) - elementNames.begin(); 451 452 assert(index < parentExprs.size() && "invalid result index"); 453 return parentExprs[index]; 454 } 455 456 llvm_unreachable("unhandled member access expression"); 457 } 458 459 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { 460 Location loc = genLoc(expr->getLoc()); 461 Optional<StringRef> opName = expr->getName(); 462 463 // Operands. 464 SmallVector<Value> operands; 465 for (const ast::Expr *operand : expr->getOperands()) 466 operands.push_back(genSingleExpr(operand)); 467 468 // Attributes. 469 SmallVector<StringRef> attrNames; 470 SmallVector<Value> attrValues; 471 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { 472 attrNames.push_back(attr->getName().getName()); 473 attrValues.push_back(genSingleExpr(attr->getValue())); 474 } 475 476 // Results. 477 SmallVector<Value> results; 478 for (const ast::Expr *result : expr->getResultTypes()) 479 results.push_back(genSingleExpr(result)); 480 481 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames, 482 attrValues, results); 483 } 484 485 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) { 486 SmallVector<Value> elements; 487 for (const ast::Expr *element : expr->getElements()) 488 elements.push_back(genSingleExpr(element)); 489 return elements; 490 } 491 492 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { 493 Type type = parseType(expr->getValue(), builder.getContext()); 494 assert(type && "invalid MLIR type data"); 495 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()), 496 builder.getType<pdl::TypeType>(), 497 TypeAttr::get(type)); 498 } 499 500 SmallVector<Value> 501 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, 502 ValueRange inputs) { 503 // Apply any constraints defined on the arguments to the input values. 504 for (auto it : llvm::zip(decl->getInputs(), inputs)) 505 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 506 507 // Generate the constraint call. 508 SmallVector<Value> results = 509 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc, 510 inputs); 511 512 // Apply any constraints defined on the results of the constraint. 513 for (auto it : llvm::zip(decl->getResults(), results)) 514 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 515 return results; 516 } 517 518 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, 519 Location loc, ValueRange inputs) { 520 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc, 521 inputs); 522 } 523 524 template <typename PDLOpT, typename T> 525 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl, 526 Location loc, 527 ValueRange inputs) { 528 const ast::CompoundStmt *cstBody = decl->getBody(); 529 530 // If the decl doesn't have a statement body, it is a native decl. 531 if (!cstBody) { 532 ast::Type declResultType = decl->getResultType(); 533 SmallVector<Type> resultTypes; 534 if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) { 535 for (ast::Type type : tupleType.getElementTypes()) 536 resultTypes.push_back(genType(type)); 537 } else { 538 resultTypes.push_back(genType(declResultType)); 539 } 540 541 // FIXME: We currently do not have a modeling for the "constant params" 542 // support PDL provides. We should either figure out a modeling for this, or 543 // refactor the support within PDL to be something a bit more reasonable for 544 // what we need as a frontend. 545 Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes, 546 decl->getName().getName(), inputs, 547 /*params=*/ArrayAttr()); 548 return pdlOp->getResults(); 549 } 550 551 // Otherwise, this is a PDLL decl. 552 VariableMapTy::ScopeTy varScope(variables); 553 554 // Map the inputs of the call to the decl arguments. 555 // Note: This is only valid because we do not support recursion, meaning 556 // we don't need to worry about conflicting mappings here. 557 for (auto it : llvm::zip(inputs, decl->getInputs())) 558 variables.insert(std::get<1>(it), {std::get<0>(it)}); 559 560 // Visit the body of the call as normal. 561 gen(cstBody); 562 563 // If the decl has no results, there is nothing to do. 564 if (cstBody->getChildren().empty()) 565 return SmallVector<Value>(); 566 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back()); 567 if (!returnStmt) 568 return SmallVector<Value>(); 569 570 // Otherwise, grab the results from the return statement. 571 return genExpr(returnStmt->getResultExpr()); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // MLIRGen 576 //===----------------------------------------------------------------------===// 577 578 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR( 579 MLIRContext *mlirContext, const ast::Context &context, 580 const llvm::SourceMgr &sourceMgr, const ast::Module &module) { 581 CodeGen codegen(mlirContext, context, sourceMgr); 582 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module); 583 if (failed(verify(*mlirModule))) 584 return nullptr; 585 return mlirModule; 586 } 587