1 //===- MLIRGen.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" 10 #include "mlir/Dialect/PDL/IR/PDL.h" 11 #include "mlir/Dialect/PDL/IR/PDLOps.h" 12 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Verifier.h" 16 #include "mlir/Parser/Parser.h" 17 #include "mlir/Tools/PDLL/AST/Context.h" 18 #include "mlir/Tools/PDLL/AST/Nodes.h" 19 #include "mlir/Tools/PDLL/AST/Types.h" 20 #include "mlir/Tools/PDLL/ODS/Context.h" 21 #include "mlir/Tools/PDLL/ODS/Operation.h" 22 #include "llvm/ADT/ScopedHashTable.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 26 using namespace mlir; 27 using namespace mlir::pdll; 28 29 //===----------------------------------------------------------------------===// 30 // CodeGen 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 class CodeGen { 35 public: 36 CodeGen(MLIRContext *mlirContext, const ast::Context &context, 37 const llvm::SourceMgr &sourceMgr) 38 : builder(mlirContext), odsContext(context.getODSContext()), 39 sourceMgr(sourceMgr) { 40 // Make sure that the PDL dialect is loaded. 41 mlirContext->loadDialect<pdl::PDLDialect>(); 42 } 43 44 OwningOpRef<ModuleOp> generate(const ast::Module &module); 45 46 private: 47 /// Generate an MLIR location from the given source location. 48 Location genLoc(llvm::SMLoc loc); 49 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); } 50 51 /// Generate an MLIR type from the given source type. 52 Type genType(ast::Type type); 53 54 /// Generate MLIR for the given AST node. 55 void gen(const ast::Node *node); 56 57 //===--------------------------------------------------------------------===// 58 // Statements 59 //===--------------------------------------------------------------------===// 60 61 void genImpl(const ast::CompoundStmt *stmt); 62 void genImpl(const ast::EraseStmt *stmt); 63 void genImpl(const ast::LetStmt *stmt); 64 void genImpl(const ast::ReplaceStmt *stmt); 65 void genImpl(const ast::RewriteStmt *stmt); 66 void genImpl(const ast::ReturnStmt *stmt); 67 68 //===--------------------------------------------------------------------===// 69 // Decls 70 //===--------------------------------------------------------------------===// 71 72 void genImpl(const ast::UserConstraintDecl *decl); 73 void genImpl(const ast::UserRewriteDecl *decl); 74 void genImpl(const ast::PatternDecl *decl); 75 76 /// Generate the set of MLIR values defined for the given variable decl, and 77 /// apply any attached constraints. 78 SmallVector<Value> genVar(const ast::VariableDecl *varDecl); 79 80 /// Generate the value for a variable that does not have an initializer 81 /// expression, i.e. create the PDL value based on the type/constraints of the 82 /// variable. 83 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc); 84 85 /// Apply the constraints of the given variable to `values`, which correspond 86 /// to the MLIR values of the variable. 87 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values); 88 89 //===--------------------------------------------------------------------===// 90 // Expressions 91 //===--------------------------------------------------------------------===// 92 93 Value genSingleExpr(const ast::Expr *expr); 94 SmallVector<Value> genExpr(const ast::Expr *expr); 95 Value genExprImpl(const ast::AttributeExpr *expr); 96 SmallVector<Value> genExprImpl(const ast::CallExpr *expr); 97 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr); 98 Value genExprImpl(const ast::MemberAccessExpr *expr); 99 Value genExprImpl(const ast::OperationExpr *expr); 100 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr); 101 Value genExprImpl(const ast::TypeExpr *expr); 102 103 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl, 104 Location loc, ValueRange inputs); 105 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl, 106 Location loc, ValueRange inputs); 107 template <typename PDLOpT, typename T> 108 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc, 109 ValueRange inputs); 110 111 //===--------------------------------------------------------------------===// 112 // Fields 113 //===--------------------------------------------------------------------===// 114 115 /// The MLIR builder used for building the resultant IR. 116 OpBuilder builder; 117 118 /// A map from variable declarations to the MLIR equivalent. 119 using VariableMapTy = 120 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>; 121 VariableMapTy variables; 122 123 /// A reference to the ODS context. 124 const ods::Context &odsContext; 125 126 /// The source manager of the PDLL ast. 127 const llvm::SourceMgr &sourceMgr; 128 }; 129 } // namespace 130 131 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) { 132 OwningOpRef<ModuleOp> mlirModule = 133 builder.create<ModuleOp>(genLoc(module.getLoc())); 134 builder.setInsertionPointToStart(mlirModule->getBody()); 135 136 // Generate code for each of the decls within the module. 137 for (const ast::Decl *decl : module.getChildren()) 138 gen(decl); 139 140 return mlirModule; 141 } 142 143 Location CodeGen::genLoc(llvm::SMLoc loc) { 144 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc); 145 146 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can 147 // use it here. 148 auto &bufferInfo = sourceMgr.getBufferInfo(fileID); 149 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer()); 150 unsigned column = 151 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1; 152 auto *buffer = sourceMgr.getMemoryBuffer(fileID); 153 154 return FileLineColLoc::get(builder.getContext(), 155 buffer->getBufferIdentifier(), lineNo, column); 156 } 157 158 Type CodeGen::genType(ast::Type type) { 159 return TypeSwitch<ast::Type, Type>(type) 160 .Case([&](ast::AttributeType astType) -> Type { 161 return builder.getType<pdl::AttributeType>(); 162 }) 163 .Case([&](ast::OperationType astType) -> Type { 164 return builder.getType<pdl::OperationType>(); 165 }) 166 .Case([&](ast::TypeType astType) -> Type { 167 return builder.getType<pdl::TypeType>(); 168 }) 169 .Case([&](ast::ValueType astType) -> Type { 170 return builder.getType<pdl::ValueType>(); 171 }) 172 .Case([&](ast::RangeType astType) -> Type { 173 return pdl::RangeType::get(genType(astType.getElementType())); 174 }); 175 } 176 177 void CodeGen::gen(const ast::Node *node) { 178 TypeSwitch<const ast::Node *>(node) 179 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt, 180 const ast::ReplaceStmt, const ast::RewriteStmt, 181 const ast::ReturnStmt, const ast::UserConstraintDecl, 182 const ast::UserRewriteDecl, const ast::PatternDecl>( 183 [&](auto derivedNode) { this->genImpl(derivedNode); }) 184 .Case([&](const ast::Expr *expr) { genExpr(expr); }); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // CodeGen: Statements 189 //===----------------------------------------------------------------------===// 190 191 void CodeGen::genImpl(const ast::CompoundStmt *stmt) { 192 VariableMapTy::ScopeTy varScope(variables); 193 for (const ast::Stmt *childStmt : stmt->getChildren()) 194 gen(childStmt); 195 } 196 197 /// If the given builder is nested under a PDL PatternOp, build a rewrite 198 /// operation and update the builder to nest under it. This is necessary for 199 /// PDLL operation rewrite statements that are directly nested within a Pattern. 200 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, 201 Location loc) { 202 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) { 203 pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>( 204 loc, rootExpr, /*name=*/StringAttr(), 205 /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr()); 206 builder.createBlock(&rewrite.body()); 207 } 208 } 209 210 void CodeGen::genImpl(const ast::EraseStmt *stmt) { 211 OpBuilder::InsertionGuard insertGuard(builder); 212 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 213 Location loc = genLoc(stmt->getLoc()); 214 215 // Make sure we are nested in a RewriteOp. 216 OpBuilder::InsertionGuard guard(builder); 217 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 218 builder.create<pdl::EraseOp>(loc, rootExpr); 219 } 220 221 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); } 222 223 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { 224 OpBuilder::InsertionGuard insertGuard(builder); 225 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 226 Location loc = genLoc(stmt->getLoc()); 227 228 // Make sure we are nested in a RewriteOp. 229 OpBuilder::InsertionGuard guard(builder); 230 checkAndNestUnderRewriteOp(builder, rootExpr, loc); 231 232 SmallVector<Value> replValues; 233 for (ast::Expr *replExpr : stmt->getReplExprs()) 234 replValues.push_back(genSingleExpr(replExpr)); 235 236 // Check to see if the statement has a replacement operation, or a range of 237 // replacement values. 238 bool usesReplOperation = 239 replValues.size() == 1 && 240 replValues.front().getType().isa<pdl::OperationType>(); 241 builder.create<pdl::ReplaceOp>( 242 loc, rootExpr, usesReplOperation ? replValues[0] : Value(), 243 usesReplOperation ? ValueRange() : ValueRange(replValues)); 244 } 245 246 void CodeGen::genImpl(const ast::RewriteStmt *stmt) { 247 OpBuilder::InsertionGuard insertGuard(builder); 248 Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); 249 250 // Make sure we are nested in a RewriteOp. 251 OpBuilder::InsertionGuard guard(builder); 252 checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc())); 253 gen(stmt->getRewriteBody()); 254 } 255 256 void CodeGen::genImpl(const ast::ReturnStmt *stmt) { 257 // ReturnStmt generation is handled by the respective constraint or rewrite 258 // parent node. 259 } 260 261 //===----------------------------------------------------------------------===// 262 // CodeGen: Decls 263 //===----------------------------------------------------------------------===// 264 265 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { 266 // All PDLL constraints get inlined when called, and the main native 267 // constraint declarations doesn't require any MLIR to be generated, only uses 268 // of it do. 269 } 270 271 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { 272 // All PDLL rewrites get inlined when called, and the main native 273 // rewrite declarations doesn't require any MLIR to be generated, only uses 274 // of it do. 275 } 276 277 void CodeGen::genImpl(const ast::PatternDecl *decl) { 278 const ast::Name *name = decl->getName(); 279 280 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it 281 // here. 282 pdl::PatternOp pattern = builder.create<pdl::PatternOp>( 283 genLoc(decl->getLoc()), decl->getBenefit(), 284 name ? Optional<StringRef>(name->getName()) : Optional<StringRef>()); 285 286 OpBuilder::InsertionGuard savedInsertPoint(builder); 287 builder.setInsertionPointToStart(pattern.getBody()); 288 gen(decl->getBody()); 289 } 290 291 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) { 292 auto it = variables.begin(varDecl); 293 if (it != variables.end()) 294 return *it; 295 296 // If the variable has an initial value, use that as the base value. 297 // Otherwise, generate a value using the constraint list. 298 SmallVector<Value> values; 299 if (const ast::Expr *initExpr = varDecl->getInitExpr()) 300 values = genExpr(initExpr); 301 else 302 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc()))); 303 304 // Apply the constraints of the values of the variable. 305 applyVarConstraints(varDecl, values); 306 307 variables.insert(varDecl, values); 308 return values; 309 } 310 311 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, 312 Location loc) { 313 // A functor used to generate expressions nested 314 auto getTypeConstraint = [&]() -> Value { 315 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { 316 Value typeValue = 317 TypeSwitch<const ast::Node *, Value>(constraint.constraint) 318 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, 319 ast::ValueRangeConstraintDecl>([&, this](auto *cst) -> Value { 320 if (auto *typeConstraintExpr = cst->getTypeExpr()) 321 return this->genSingleExpr(typeConstraintExpr); 322 return Value(); 323 }) 324 .Default(Value()); 325 if (typeValue) 326 return typeValue; 327 } 328 return Value(); 329 }; 330 331 // Generate a value based on the type of the variable. 332 ast::Type type = varDecl->getType(); 333 Type mlirType = genType(type); 334 if (type.isa<ast::ValueType>()) 335 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint()); 336 if (type.isa<ast::TypeType>()) 337 return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr()); 338 if (type.isa<ast::AttributeType>()) 339 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint()); 340 if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) { 341 Value operands = builder.create<pdl::OperandsOp>( 342 loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()), 343 /*type=*/Value()); 344 Value results = builder.create<pdl::TypesOp>( 345 loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()), 346 /*types=*/ArrayAttr()); 347 return builder.create<pdl::OperationOp>(loc, opType.getName(), operands, 348 llvm::None, ValueRange(), results); 349 } 350 351 if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) { 352 ast::Type eleTy = rangeTy.getElementType(); 353 if (eleTy.isa<ast::ValueType>()) 354 return builder.create<pdl::OperandsOp>(loc, mlirType, 355 getTypeConstraint()); 356 if (eleTy.isa<ast::TypeType>()) 357 return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr()); 358 } 359 360 llvm_unreachable("invalid non-initialized variable type"); 361 } 362 363 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, 364 ValueRange values) { 365 // Generate calls to any user constraints that were attached via the 366 // constraint list. 367 for (const ast::ConstraintRef &ref : varDecl->getConstraints()) 368 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint)) 369 genConstraintCall(userCst, genLoc(ref.referenceLoc), values); 370 } 371 372 //===----------------------------------------------------------------------===// 373 // CodeGen: Expressions 374 //===----------------------------------------------------------------------===// 375 376 Value CodeGen::genSingleExpr(const ast::Expr *expr) { 377 return TypeSwitch<const ast::Expr *, Value>(expr) 378 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr, 379 const ast::OperationExpr, const ast::TypeExpr>( 380 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 381 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 382 [&](auto derivedNode) { 383 SmallVector<Value> results = this->genExprImpl(derivedNode); 384 assert(results.size() == 1 && "expected single expression result"); 385 return results[0]; 386 }); 387 } 388 389 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) { 390 return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr) 391 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( 392 [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) 393 .Default([&](const ast::Expr *expr) -> SmallVector<Value> { 394 return {genSingleExpr(expr)}; 395 }); 396 } 397 398 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { 399 Attribute attr = parseAttribute(expr->getValue(), builder.getContext()); 400 assert(attr && "invalid MLIR attribute data"); 401 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr); 402 } 403 404 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) { 405 Location loc = genLoc(expr->getLoc()); 406 SmallVector<Value> arguments; 407 for (const ast::Expr *arg : expr->getArguments()) 408 arguments.push_back(genSingleExpr(arg)); 409 410 // Resolve the callable expression of this call. 411 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr()); 412 assert(callableExpr && "unhandled CallExpr callable"); 413 414 // Generate the PDL based on the type of callable. 415 const ast::Decl *callable = callableExpr->getDecl(); 416 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable)) 417 return genConstraintCall(decl, loc, arguments); 418 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable)) 419 return genRewriteCall(decl, loc, arguments); 420 llvm_unreachable("unhandled CallExpr callable"); 421 } 422 423 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { 424 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl())) 425 return genVar(varDecl); 426 llvm_unreachable("unknown decl reference expression"); 427 } 428 429 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { 430 Location loc = genLoc(expr->getLoc()); 431 StringRef name = expr->getMemberName(); 432 SmallVector<Value> parentExprs = genExpr(expr->getParentExpr()); 433 ast::Type parentType = expr->getParentExpr()->getType(); 434 435 // Handle operation based member access. 436 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { 437 if (isa<ast::AllResultsMemberAccessExpr>(expr)) { 438 Type mlirType = genType(expr->getType()); 439 if (mlirType.isa<pdl::ValueType>()) 440 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0], 441 builder.getI32IntegerAttr(0)); 442 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); 443 } 444 445 assert(opType.getName() && "expected valid operation name"); 446 const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName()); 447 assert(odsOp && "expected valid ODS operation information"); 448 449 // Find the result with the member name or by index. 450 ArrayRef<ods::OperandOrResult> results = odsOp->getResults(); 451 unsigned resultIndex = results.size(); 452 if (llvm::isDigit(name[0])) { 453 name.getAsInteger(/*Radix=*/10, resultIndex); 454 } else { 455 auto findFn = [&](const ods::OperandOrResult &result) { 456 return result.getName() == name; 457 }; 458 resultIndex = llvm::find_if(results, findFn) - results.begin(); 459 } 460 assert(resultIndex < results.size() && "invalid result index"); 461 462 // Generate the result access. 463 IntegerAttr index = builder.getI32IntegerAttr(resultIndex); 464 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()), 465 parentExprs[0], index); 466 } 467 468 // Handle tuple based member access. 469 if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { 470 auto elementNames = tupleType.getElementNames(); 471 472 // The index is either a numeric index, or a name. 473 unsigned index = 0; 474 if (llvm::isDigit(name[0])) 475 name.getAsInteger(/*Radix=*/10, index); 476 else 477 index = llvm::find(elementNames, name) - elementNames.begin(); 478 479 assert(index < parentExprs.size() && "invalid result index"); 480 return parentExprs[index]; 481 } 482 483 llvm_unreachable("unhandled member access expression"); 484 } 485 486 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { 487 Location loc = genLoc(expr->getLoc()); 488 Optional<StringRef> opName = expr->getName(); 489 490 // Operands. 491 SmallVector<Value> operands; 492 for (const ast::Expr *operand : expr->getOperands()) 493 operands.push_back(genSingleExpr(operand)); 494 495 // Attributes. 496 SmallVector<StringRef> attrNames; 497 SmallVector<Value> attrValues; 498 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { 499 attrNames.push_back(attr->getName().getName()); 500 attrValues.push_back(genSingleExpr(attr->getValue())); 501 } 502 503 // Results. 504 SmallVector<Value> results; 505 for (const ast::Expr *result : expr->getResultTypes()) 506 results.push_back(genSingleExpr(result)); 507 508 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames, 509 attrValues, results); 510 } 511 512 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) { 513 SmallVector<Value> elements; 514 for (const ast::Expr *element : expr->getElements()) 515 elements.push_back(genSingleExpr(element)); 516 return elements; 517 } 518 519 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { 520 Type type = parseType(expr->getValue(), builder.getContext()); 521 assert(type && "invalid MLIR type data"); 522 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()), 523 builder.getType<pdl::TypeType>(), 524 TypeAttr::get(type)); 525 } 526 527 SmallVector<Value> 528 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, 529 ValueRange inputs) { 530 // Apply any constraints defined on the arguments to the input values. 531 for (auto it : llvm::zip(decl->getInputs(), inputs)) 532 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 533 534 // Generate the constraint call. 535 SmallVector<Value> results = 536 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc, 537 inputs); 538 539 // Apply any constraints defined on the results of the constraint. 540 for (auto it : llvm::zip(decl->getResults(), results)) 541 applyVarConstraints(std::get<0>(it), std::get<1>(it)); 542 return results; 543 } 544 545 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, 546 Location loc, ValueRange inputs) { 547 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc, 548 inputs); 549 } 550 551 template <typename PDLOpT, typename T> 552 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl, 553 Location loc, 554 ValueRange inputs) { 555 const ast::CompoundStmt *cstBody = decl->getBody(); 556 557 // If the decl doesn't have a statement body, it is a native decl. 558 if (!cstBody) { 559 ast::Type declResultType = decl->getResultType(); 560 SmallVector<Type> resultTypes; 561 if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) { 562 for (ast::Type type : tupleType.getElementTypes()) 563 resultTypes.push_back(genType(type)); 564 } else { 565 resultTypes.push_back(genType(declResultType)); 566 } 567 568 // FIXME: We currently do not have a modeling for the "constant params" 569 // support PDL provides. We should either figure out a modeling for this, or 570 // refactor the support within PDL to be something a bit more reasonable for 571 // what we need as a frontend. 572 Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes, 573 decl->getName().getName(), inputs, 574 /*params=*/ArrayAttr()); 575 return pdlOp->getResults(); 576 } 577 578 // Otherwise, this is a PDLL decl. 579 VariableMapTy::ScopeTy varScope(variables); 580 581 // Map the inputs of the call to the decl arguments. 582 // Note: This is only valid because we do not support recursion, meaning 583 // we don't need to worry about conflicting mappings here. 584 for (auto it : llvm::zip(inputs, decl->getInputs())) 585 variables.insert(std::get<1>(it), {std::get<0>(it)}); 586 587 // Visit the body of the call as normal. 588 gen(cstBody); 589 590 // If the decl has no results, there is nothing to do. 591 if (cstBody->getChildren().empty()) 592 return SmallVector<Value>(); 593 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back()); 594 if (!returnStmt) 595 return SmallVector<Value>(); 596 597 // Otherwise, grab the results from the return statement. 598 return genExpr(returnStmt->getResultExpr()); 599 } 600 601 //===----------------------------------------------------------------------===// 602 // MLIRGen 603 //===----------------------------------------------------------------------===// 604 605 OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR( 606 MLIRContext *mlirContext, const ast::Context &context, 607 const llvm::SourceMgr &sourceMgr, const ast::Module &module) { 608 CodeGen codegen(mlirContext, context, sourceMgr); 609 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module); 610 if (failed(verify(*mlirModule))) 611 return nullptr; 612 return mlirModule; 613 } 614