1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node. 47 Block *generateMatcher(MatcherNode &node); 48 49 /// Get or create an access to the provided positional value within the 50 /// current block. 51 Value getValueAt(Block *cur, Position *pos); 52 53 /// Create an interpreter predicate operation, branching to the provided true 54 /// and false destinations. 55 void generatePredicate(Block *currentBlock, Qualifier *question, 56 Qualifier *answer, Value val, Block *trueDest, 57 Block *falseDest); 58 59 /// Create an interpreter switch predicate operation, with a provided default 60 /// and several case destinations. 61 void generateSwitch(SwitchNode *switchNode, Block *currentBlock, 62 Qualifier *question, Value val, Block *defaultDest); 63 64 /// Create the interpreter operations to record a successful pattern match. 65 void generateRecordMatch(Block *currentBlock, Block *nextBlock, 66 pdl::PatternOp pattern); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. 104 void generateOperationResultTypeRewriter( 105 pdl::OperationOp op, SmallVectorImpl<Value> &types, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 109 /// A builder to use when generating interpreter operations. 110 OpBuilder builder; 111 112 /// The matcher function used for all match related logic within PDL patterns. 113 FuncOp matcherFunc; 114 115 /// The rewriter module containing the all rewrite related logic within PDL 116 /// patterns. 117 ModuleOp rewriterModule; 118 119 /// The symbol table of the rewriter module used for insertion. 120 SymbolTable rewriterSymbolTable; 121 122 /// A scoped map connecting a position with the corresponding interpreter 123 /// value. 124 ValueMap values; 125 126 /// A stack of blocks used as the failure destination for matcher nodes that 127 /// don't have an explicit failure path. 128 SmallVector<Block *, 8> failureBlockStack; 129 130 /// A mapping between values defined in a pattern match, and the corresponding 131 /// positional value. 132 DenseMap<Value, Position *> valueToPosition; 133 134 /// The set of operation values whose whose location will be used for newly 135 /// generated operations. 136 SetVector<Value> locOps; 137 }; 138 } // end anonymous namespace 139 140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 141 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 142 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 143 144 void PatternLowering::lower(ModuleOp module) { 145 PredicateUniquer predicateUniquer; 146 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 147 148 // Define top-level scope for the arguments to the matcher function. 149 ValueMapScope topLevelValueScope(values); 150 151 // Insert the root operation, i.e. argument to the matcher, at the root 152 // position. 153 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 154 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 155 156 // Generate a root matcher node from the provided PDL module. 157 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 158 module, predicateBuilder, valueToPosition); 159 Block *firstMatcherBlock = generateMatcher(*root); 160 161 // After generation, merged the first matched block into the entry. 162 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 163 firstMatcherBlock->getOperations()); 164 firstMatcherBlock->erase(); 165 } 166 167 Block *PatternLowering::generateMatcher(MatcherNode &node) { 168 // Push a new scope for the values used by this matcher. 169 Block *block = matcherFunc.addBlock(); 170 ValueMapScope scope(values); 171 172 // If this is the return node, simply insert the corresponding interpreter 173 // finalize. 174 if (isa<ExitNode>(node)) { 175 builder.setInsertionPointToEnd(block); 176 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 177 return block; 178 } 179 180 // If this node contains a position, get the corresponding value for this 181 // block. 182 Position *position = node.getPosition(); 183 Value val = position ? getValueAt(block, position) : Value(); 184 185 // Get the next block in the match sequence. 186 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 187 Block *nextBlock; 188 if (failureNode) { 189 nextBlock = generateMatcher(*failureNode); 190 failureBlockStack.push_back(nextBlock); 191 } else { 192 assert(!failureBlockStack.empty() && "expected valid failure block"); 193 nextBlock = failureBlockStack.back(); 194 } 195 196 // If this value corresponds to an operation, record that we are going to use 197 // its location as part of a fused location. 198 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 199 if (isOperationValue) 200 locOps.insert(val); 201 202 // Generate code for a boolean predicate node. 203 if (auto *boolNode = dyn_cast<BoolNode>(&node)) { 204 auto *child = generateMatcher(*boolNode->getSuccessNode()); 205 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, 206 child, nextBlock); 207 208 // Generate code for a switch node. 209 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) { 210 generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); 211 212 // Generate code for a success node. 213 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) { 214 generateRecordMatch(block, nextBlock, successNode->getPattern()); 215 } 216 217 if (failureNode) 218 failureBlockStack.pop_back(); 219 if (isOperationValue) 220 locOps.remove(val); 221 return block; 222 } 223 224 Value PatternLowering::getValueAt(Block *cur, Position *pos) { 225 if (Value val = values.lookup(pos)) 226 return val; 227 228 // Get the value for the parent position. 229 Value parentVal = getValueAt(cur, pos->getParent()); 230 231 // TODO: Use a location from the position. 232 Location loc = parentVal.getLoc(); 233 builder.setInsertionPointToEnd(cur); 234 Value value; 235 switch (pos->getKind()) { 236 case Predicates::OperationPos: 237 value = builder.create<pdl_interp::GetDefiningOpOp>( 238 loc, builder.getType<pdl::OperationType>(), parentVal); 239 break; 240 case Predicates::OperandPos: { 241 auto *operandPos = cast<OperandPosition>(pos); 242 value = builder.create<pdl_interp::GetOperandOp>( 243 loc, builder.getType<pdl::ValueType>(), parentVal, 244 operandPos->getOperandNumber()); 245 break; 246 } 247 case Predicates::OperandGroupPos: { 248 auto *operandPos = cast<OperandGroupPosition>(pos); 249 Type valueTy = builder.getType<pdl::ValueType>(); 250 value = builder.create<pdl_interp::GetOperandsOp>( 251 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 252 parentVal, operandPos->getOperandGroupNumber()); 253 break; 254 } 255 case Predicates::AttributePos: { 256 auto *attrPos = cast<AttributePosition>(pos); 257 value = builder.create<pdl_interp::GetAttributeOp>( 258 loc, builder.getType<pdl::AttributeType>(), parentVal, 259 attrPos->getName().strref()); 260 break; 261 } 262 case Predicates::TypePos: { 263 if (parentVal.getType().isa<pdl::AttributeType>()) 264 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 265 else 266 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 267 break; 268 } 269 case Predicates::ResultPos: { 270 auto *resPos = cast<ResultPosition>(pos); 271 value = builder.create<pdl_interp::GetResultOp>( 272 loc, builder.getType<pdl::ValueType>(), parentVal, 273 resPos->getResultNumber()); 274 break; 275 } 276 case Predicates::ResultGroupPos: { 277 auto *resPos = cast<ResultGroupPosition>(pos); 278 Type valueTy = builder.getType<pdl::ValueType>(); 279 value = builder.create<pdl_interp::GetResultsOp>( 280 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 281 parentVal, resPos->getResultGroupNumber()); 282 break; 283 } 284 default: 285 llvm_unreachable("Generating unknown Position getter"); 286 break; 287 } 288 values.insert(pos, value); 289 return value; 290 } 291 292 void PatternLowering::generatePredicate(Block *currentBlock, 293 Qualifier *question, Qualifier *answer, 294 Value val, Block *trueDest, 295 Block *falseDest) { 296 builder.setInsertionPointToEnd(currentBlock); 297 Location loc = val.getLoc(); 298 Predicates::Kind kind = question->getKind(); 299 switch (kind) { 300 case Predicates::IsNotNullQuestion: 301 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest); 302 break; 303 case Predicates::OperationNameQuestion: { 304 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 305 builder.create<pdl_interp::CheckOperationNameOp>( 306 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); 307 break; 308 } 309 case Predicates::TypeQuestion: { 310 auto *ans = cast<TypeAnswer>(answer); 311 if (val.getType().isa<pdl::RangeType>()) 312 builder.create<pdl_interp::CheckTypesOp>( 313 loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest); 314 else 315 builder.create<pdl_interp::CheckTypeOp>( 316 loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest); 317 break; 318 } 319 case Predicates::AttributeQuestion: { 320 auto *ans = cast<AttributeAnswer>(answer); 321 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 322 trueDest, falseDest); 323 break; 324 } 325 case Predicates::OperandCountAtLeastQuestion: 326 case Predicates::OperandCountQuestion: 327 builder.create<pdl_interp::CheckOperandCountOp>( 328 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 329 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 330 trueDest, falseDest); 331 break; 332 case Predicates::ResultCountAtLeastQuestion: 333 case Predicates::ResultCountQuestion: 334 builder.create<pdl_interp::CheckResultCountOp>( 335 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 336 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 337 trueDest, falseDest); 338 break; 339 case Predicates::EqualToQuestion: { 340 auto *equalToQuestion = cast<EqualToQuestion>(question); 341 builder.create<pdl_interp::AreEqualOp>( 342 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), 343 trueDest, falseDest); 344 break; 345 } 346 case Predicates::ConstraintQuestion: { 347 auto *cstQuestion = cast<ConstraintQuestion>(question); 348 SmallVector<Value, 2> args; 349 for (Position *position : std::get<1>(cstQuestion->getValue())) 350 args.push_back(getValueAt(currentBlock, position)); 351 builder.create<pdl_interp::ApplyConstraintOp>( 352 loc, std::get<0>(cstQuestion->getValue()), args, 353 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest, 354 falseDest); 355 break; 356 } 357 default: 358 llvm_unreachable("Generating unknown Predicate operation"); 359 } 360 } 361 362 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 363 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 364 llvm::MapVector<Qualifier *, Block *> &dests) { 365 std::vector<ValT> values; 366 std::vector<Block *> blocks; 367 values.reserve(dests.size()); 368 blocks.reserve(dests.size()); 369 for (const auto &it : dests) { 370 blocks.push_back(it.second); 371 values.push_back(cast<PredT>(it.first)->getValue()); 372 } 373 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 374 } 375 376 void PatternLowering::generateSwitch(SwitchNode *switchNode, 377 Block *currentBlock, Qualifier *question, 378 Value val, Block *defaultDest) { 379 // If the switch question is not an exact answer, i.e. for the `at_least` 380 // cases, we generate a special block sequence. 381 Predicates::Kind kind = question->getKind(); 382 if (kind == Predicates::OperandCountAtLeastQuestion || 383 kind == Predicates::ResultCountAtLeastQuestion) { 384 // Order the children such that the cases are in reverse numerical order. 385 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 386 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 387 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 388 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 389 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 390 }); 391 392 // Build the destination for each child using the next highest child as a 393 // a failure destination. This essentially creates the following control 394 // flow: 395 // 396 // if (operand_count < 1) 397 // goto failure 398 // if (child1.match()) 399 // ... 400 // 401 // if (operand_count < 2) 402 // goto failure 403 // if (child2.match()) 404 // ... 405 // 406 // failure: 407 // ... 408 // 409 failureBlockStack.push_back(defaultDest); 410 for (unsigned idx : sortedChildren) { 411 auto &child = switchNode->getChild(idx); 412 Block *childBlock = generateMatcher(*child.second); 413 Block *predicateBlock = builder.createBlock(childBlock); 414 generatePredicate(predicateBlock, question, child.first, val, childBlock, 415 defaultDest); 416 failureBlockStack.back() = predicateBlock; 417 } 418 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 419 currentBlock->getOperations().splice(currentBlock->end(), 420 firstPredicateBlock->getOperations()); 421 firstPredicateBlock->erase(); 422 return; 423 } 424 425 // Otherwise, generate each of the children and generate an interpreter 426 // switch. 427 llvm::MapVector<Qualifier *, Block *> children; 428 for (auto &it : switchNode->getChildren()) 429 children.insert({it.first, generateMatcher(*it.second)}); 430 builder.setInsertionPointToEnd(currentBlock); 431 432 switch (question->getKind()) { 433 case Predicates::OperandCountQuestion: 434 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 435 int32_t>(val, defaultDest, builder, children); 436 case Predicates::ResultCountQuestion: 437 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 438 int32_t>(val, defaultDest, builder, children); 439 case Predicates::OperationNameQuestion: 440 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 441 OperationNameAnswer>(val, defaultDest, builder, 442 children); 443 case Predicates::TypeQuestion: 444 if (val.getType().isa<pdl::RangeType>()) { 445 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 446 val, defaultDest, builder, children); 447 } 448 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 449 val, defaultDest, builder, children); 450 case Predicates::AttributeQuestion: 451 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 452 val, defaultDest, builder, children); 453 default: 454 llvm_unreachable("Generating unknown switch predicate."); 455 } 456 } 457 458 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, 459 pdl::PatternOp pattern) { 460 // Generate a rewriter for the pattern this success node represents, and track 461 // any values used from the match region. 462 SmallVector<Position *, 8> usedMatchValues; 463 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 464 465 // Process any values used in the rewrite that are defined in the match. 466 std::vector<Value> mappedMatchValues; 467 mappedMatchValues.reserve(usedMatchValues.size()); 468 for (Position *position : usedMatchValues) 469 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 470 471 // Collect the set of operations generated by the rewriter. 472 SmallVector<StringRef, 4> generatedOps; 473 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 474 generatedOps.push_back(*op.name()); 475 ArrayAttr generatedOpsAttr; 476 if (!generatedOps.empty()) 477 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 478 479 // Grab the root kind if present. 480 StringAttr rootKindAttr; 481 if (Optional<StringRef> rootKind = pattern.getRootKind()) 482 rootKindAttr = builder.getStringAttr(*rootKind); 483 484 builder.setInsertionPointToEnd(currentBlock); 485 builder.create<pdl_interp::RecordMatchOp>( 486 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 487 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 488 nextBlock); 489 } 490 491 SymbolRefAttr PatternLowering::generateRewriter( 492 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 493 FuncOp rewriterFunc = 494 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 495 builder.getFunctionType(llvm::None, llvm::None)); 496 rewriterSymbolTable.insert(rewriterFunc); 497 498 // Generate the rewriter function body. 499 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 500 501 // Map an input operand of the pattern to a generated interpreter value. 502 DenseMap<Value, Value> rewriteValues; 503 auto mapRewriteValue = [&](Value oldValue) { 504 Value &newValue = rewriteValues[oldValue]; 505 if (newValue) 506 return newValue; 507 508 // Prefer materializing constants directly when possible. 509 Operation *oldOp = oldValue.getDefiningOp(); 510 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 511 if (Attribute value = attrOp.valueAttr()) { 512 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 513 attrOp.getLoc(), value); 514 } 515 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 516 if (TypeAttr type = typeOp.typeAttr()) { 517 return newValue = builder.create<pdl_interp::CreateTypeOp>( 518 typeOp.getLoc(), type); 519 } 520 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 521 if (ArrayAttr type = typeOp.typesAttr()) { 522 return newValue = builder.create<pdl_interp::CreateTypesOp>( 523 typeOp.getLoc(), typeOp.getType(), type); 524 } 525 } 526 527 // Otherwise, add this as an input to the rewriter. 528 Position *inputPos = valueToPosition.lookup(oldValue); 529 assert(inputPos && "expected value to be a pattern input"); 530 usedMatchValues.push_back(inputPos); 531 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 532 }; 533 534 // If this is a custom rewriter, simply dispatch to the registered rewrite 535 // method. 536 pdl::RewriteOp rewriter = pattern.getRewriter(); 537 if (StringAttr rewriteName = rewriter.nameAttr()) { 538 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 539 SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root())); 540 args.append(mappedArgs.begin(), mappedArgs.end()); 541 builder.create<pdl_interp::ApplyRewriteOp>( 542 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 543 rewriter.externalConstParamsAttr()); 544 } else { 545 // Otherwise this is a dag rewriter defined using PDL operations. 546 for (Operation &rewriteOp : *rewriter.getBody()) { 547 llvm::TypeSwitch<Operation *>(&rewriteOp) 548 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 549 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 550 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 551 this->generateRewriter(op, rewriteValues, mapRewriteValue); 552 }); 553 } 554 } 555 556 // Update the signature of the rewrite function. 557 rewriterFunc.setType(builder.getFunctionType( 558 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 559 /*results=*/llvm::None)); 560 561 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 562 return builder.getSymbolRefAttr( 563 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 564 builder.getSymbolRefAttr(rewriterFunc)); 565 } 566 567 void PatternLowering::generateRewriter( 568 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 569 function_ref<Value(Value)> mapRewriteValue) { 570 SmallVector<Value, 2> arguments; 571 for (Value argument : rewriteOp.args()) 572 arguments.push_back(mapRewriteValue(argument)); 573 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 574 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 575 arguments, rewriteOp.constParamsAttr()); 576 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 577 rewriteValues[std::get<0>(it)] = std::get<1>(it); 578 } 579 580 void PatternLowering::generateRewriter( 581 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 582 function_ref<Value(Value)> mapRewriteValue) { 583 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 584 attrOp.getLoc(), attrOp.valueAttr()); 585 rewriteValues[attrOp] = newAttr; 586 } 587 588 void PatternLowering::generateRewriter( 589 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 590 function_ref<Value(Value)> mapRewriteValue) { 591 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 592 mapRewriteValue(eraseOp.operation())); 593 } 594 595 void PatternLowering::generateRewriter( 596 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 597 function_ref<Value(Value)> mapRewriteValue) { 598 SmallVector<Value, 4> operands; 599 for (Value operand : operationOp.operands()) 600 operands.push_back(mapRewriteValue(operand)); 601 602 SmallVector<Value, 4> attributes; 603 for (Value attr : operationOp.attributes()) 604 attributes.push_back(mapRewriteValue(attr)); 605 606 SmallVector<Value, 2> types; 607 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 608 mapRewriteValue); 609 610 // Create the new operation. 611 Location loc = operationOp.getLoc(); 612 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 613 loc, *operationOp.name(), types, operands, attributes, 614 operationOp.attributeNames()); 615 rewriteValues[operationOp.op()] = createdOp; 616 617 // Generate accesses for any results that have their types constrained. 618 // Handle the case where there is a single range representing all of the 619 // result types. 620 OperandRange resultTys = operationOp.types(); 621 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 622 Value &type = rewriteValues[resultTys[0]]; 623 if (!type) { 624 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 625 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 626 } 627 return; 628 } 629 630 // Otherwise, populate the individual results. 631 bool seenVariableLength = false; 632 Type valueTy = builder.getType<pdl::ValueType>(); 633 Type valueRangeTy = pdl::RangeType::get(valueTy); 634 for (auto it : llvm::enumerate(resultTys)) { 635 Value &type = rewriteValues[it.value()]; 636 if (type) 637 continue; 638 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 639 seenVariableLength |= isVariadic; 640 641 // After a variable length result has been seen, we need to use result 642 // groups because the exact index of the result is not statically known. 643 Value resultVal; 644 if (seenVariableLength) 645 resultVal = builder.create<pdl_interp::GetResultsOp>( 646 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 647 else 648 resultVal = builder.create<pdl_interp::GetResultOp>( 649 loc, valueTy, createdOp, it.index()); 650 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 651 } 652 } 653 654 void PatternLowering::generateRewriter( 655 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 656 function_ref<Value(Value)> mapRewriteValue) { 657 SmallVector<Value, 4> replOperands; 658 659 // If the replacement was another operation, get its results. `pdl` allows 660 // for using an operation for simplicitly, but the interpreter isn't as 661 // user facing. 662 if (Value replOp = replaceOp.replOperation()) { 663 // Don't use replace if we know the replaced operation has no results. 664 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 665 if (!opOp || !opOp.types().empty()) { 666 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 667 replOp.getLoc(), mapRewriteValue(replOp))); 668 } 669 } else { 670 for (Value operand : replaceOp.replValues()) 671 replOperands.push_back(mapRewriteValue(operand)); 672 } 673 674 // If there are no replacement values, just create an erase instead. 675 if (replOperands.empty()) { 676 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 677 mapRewriteValue(replaceOp.operation())); 678 return; 679 } 680 681 builder.create<pdl_interp::ReplaceOp>( 682 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 683 } 684 685 void PatternLowering::generateRewriter( 686 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 687 function_ref<Value(Value)> mapRewriteValue) { 688 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 689 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 690 mapRewriteValue(resultOp.parent()), resultOp.index()); 691 } 692 693 void PatternLowering::generateRewriter( 694 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 695 function_ref<Value(Value)> mapRewriteValue) { 696 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 697 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 698 resultOp.index()); 699 } 700 701 void PatternLowering::generateRewriter( 702 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 703 function_ref<Value(Value)> mapRewriteValue) { 704 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 705 // type. 706 if (TypeAttr typeAttr = typeOp.typeAttr()) { 707 rewriteValues[typeOp] = 708 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 709 } 710 } 711 712 void PatternLowering::generateRewriter( 713 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 714 function_ref<Value(Value)> mapRewriteValue) { 715 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 716 // type. 717 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 718 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 719 typeOp.getLoc(), typeOp.getType(), typeAttr); 720 } 721 } 722 723 void PatternLowering::generateOperationResultTypeRewriter( 724 pdl::OperationOp op, SmallVectorImpl<Value> &types, 725 DenseMap<Value, Value> &rewriteValues, 726 function_ref<Value(Value)> mapRewriteValue) { 727 // Look for an operation that was replaced by `op`. The result types will be 728 // inferred from the results that were replaced. 729 Block *rewriterBlock = op->getBlock(); 730 Value replacedOp; 731 for (OpOperand &use : op.op().getUses()) { 732 // Check that the use corresponds to a ReplaceOp and that it is the 733 // replacement value, not the operation being replaced. 734 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 735 if (!replOpUser || use.getOperandNumber() == 0) 736 continue; 737 // Make sure the replaced operation was defined before this one. 738 Value replOpVal = replOpUser.operation(); 739 Operation *replacedOp = replOpVal.getDefiningOp(); 740 if (replacedOp->getBlock() == rewriterBlock && 741 !replacedOp->isBeforeInBlock(op)) 742 continue; 743 744 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 745 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 746 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 747 replacedOp->getLoc(), replacedOpResults)); 748 return; 749 } 750 751 // Check if the operation has type inference support. 752 if (op.hasTypeInference()) { 753 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 754 return; 755 } 756 757 // Otherwise, handle inference for each of the result types individually. 758 OperandRange resultTypeValues = op.types(); 759 types.reserve(resultTypeValues.size()); 760 for (auto it : llvm::enumerate(resultTypeValues)) { 761 Value resultType = it.value(); 762 763 // Check for an already translated value. 764 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 765 types.push_back(existingRewriteValue); 766 continue; 767 } 768 769 // Check for an input from the matcher. 770 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 771 types.push_back(mapRewriteValue(resultType)); 772 continue; 773 } 774 775 // The verifier asserts that the result types of each pdl.operation can be 776 // inferred. If we reach here, there is a bug either in the logic above or 777 // in the verifier for pdl.operation. 778 op->emitOpError() << "unable to infer result type for operation"; 779 llvm_unreachable("unable to infer result type for operation"); 780 } 781 } 782 783 //===----------------------------------------------------------------------===// 784 // Conversion Pass 785 //===----------------------------------------------------------------------===// 786 787 namespace { 788 struct PDLToPDLInterpPass 789 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 790 void runOnOperation() final; 791 }; 792 } // namespace 793 794 /// Convert the given module containing PDL pattern operations into a PDL 795 /// Interpreter operations. 796 void PDLToPDLInterpPass::runOnOperation() { 797 ModuleOp module = getOperation(); 798 799 // Create the main matcher function This function contains all of the match 800 // related functionality from patterns in the module. 801 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 802 FuncOp matcherFunc = builder.create<FuncOp>( 803 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 804 builder.getFunctionType(builder.getType<pdl::OperationType>(), 805 /*results=*/llvm::None), 806 /*attrs=*/llvm::None); 807 808 // Create a nested module to hold the functions invoked for rewriting the IR 809 // after a successful match. 810 ModuleOp rewriterModule = builder.create<ModuleOp>( 811 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 812 813 // Generate the code for the patterns within the module. 814 PatternLowering generator(matcherFunc, rewriterModule); 815 generator.lower(module); 816 817 // After generation, delete all of the pattern operations. 818 for (pdl::PatternOp pattern : 819 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 820 pattern.erase(); 821 } 822 823 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 824 return std::make_unique<PDLToPDLInterpPass>(); 825 } 826