1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node, in the specified region. 47 Block *generateMatcher(MatcherNode &node, Region ®ion); 48 49 /// Get or create an access to the provided positional value in the current 50 /// block. This operation may mutate the provided block pointer if nested 51 /// regions (i.e., pdl_interp.iterate) are required. 52 Value getValueAt(Block *¤tBlock, Position *pos); 53 54 /// Create the interpreter predicate operations. This operation may mutate the 55 /// provided current block pointer if nested regions (iterates) are required. 56 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 57 58 /// Create the interpreter switch / predicate operations, with several case 59 /// destinations. This operation never mutates the provided current block 60 /// pointer, because the switch operation does not need Values beyond `val`. 61 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 62 63 /// Create the interpreter operations to record a successful pattern match 64 /// using the contained root operation. This operation may mutate the current 65 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 66 void generate(SuccessNode *successNode, Block *¤tBlock); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. 104 void generateOperationResultTypeRewriter( 105 pdl::OperationOp op, SmallVectorImpl<Value> &types, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 109 /// A builder to use when generating interpreter operations. 110 OpBuilder builder; 111 112 /// The matcher function used for all match related logic within PDL patterns. 113 FuncOp matcherFunc; 114 115 /// The rewriter module containing the all rewrite related logic within PDL 116 /// patterns. 117 ModuleOp rewriterModule; 118 119 /// The symbol table of the rewriter module used for insertion. 120 SymbolTable rewriterSymbolTable; 121 122 /// A scoped map connecting a position with the corresponding interpreter 123 /// value. 124 ValueMap values; 125 126 /// A stack of blocks used as the failure destination for matcher nodes that 127 /// don't have an explicit failure path. 128 SmallVector<Block *, 8> failureBlockStack; 129 130 /// A mapping between values defined in a pattern match, and the corresponding 131 /// positional value. 132 DenseMap<Value, Position *> valueToPosition; 133 134 /// The set of operation values whose whose location will be used for newly 135 /// generated operations. 136 SetVector<Value> locOps; 137 }; 138 } // end anonymous namespace 139 140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 141 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 142 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 143 144 void PatternLowering::lower(ModuleOp module) { 145 PredicateUniquer predicateUniquer; 146 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 147 148 // Define top-level scope for the arguments to the matcher function. 149 ValueMapScope topLevelValueScope(values); 150 151 // Insert the root operation, i.e. argument to the matcher, at the root 152 // position. 153 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 154 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 155 156 // Generate a root matcher node from the provided PDL module. 157 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 158 module, predicateBuilder, valueToPosition); 159 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 160 assert(failureBlockStack.empty() && "failed to empty the stack"); 161 162 // After generation, merged the first matched block into the entry. 163 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 164 firstMatcherBlock->getOperations()); 165 firstMatcherBlock->erase(); 166 } 167 168 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 169 // Push a new scope for the values used by this matcher. 170 Block *block = ®ion.emplaceBlock(); 171 ValueMapScope scope(values); 172 173 // If this is the return node, simply insert the corresponding interpreter 174 // finalize. 175 if (isa<ExitNode>(node)) { 176 builder.setInsertionPointToEnd(block); 177 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 178 return block; 179 } 180 181 // Get the next block in the match sequence. 182 // This is intentionally executed first, before we get the value for the 183 // position associated with the node, so that we preserve an "there exist" 184 // semantics: if getting a value requires an upward traversal (going from a 185 // value to its consumers), we want to perform the check on all the consumers 186 // before we pass control to the failure node. 187 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 188 Block *failureBlock; 189 if (failureNode) { 190 failureBlock = generateMatcher(*failureNode, region); 191 failureBlockStack.push_back(failureBlock); 192 } else { 193 assert(!failureBlockStack.empty() && "expected valid failure block"); 194 failureBlock = failureBlockStack.back(); 195 } 196 197 // If this node contains a position, get the corresponding value for this 198 // block. 199 Block *currentBlock = block; 200 Position *position = node.getPosition(); 201 Value val = position ? getValueAt(currentBlock, position) : Value(); 202 203 // If this value corresponds to an operation, record that we are going to use 204 // its location as part of a fused location. 205 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 206 if (isOperationValue) 207 locOps.insert(val); 208 209 // Dispatch to the correct method based on derived node type. 210 TypeSwitch<MatcherNode *>(&node) 211 .Case<BoolNode, SwitchNode>( 212 [&](auto *derivedNode) { generate(derivedNode, currentBlock, val); }) 213 .Case([&](SuccessNode *successNode) { 214 generate(successNode, currentBlock); 215 }); 216 217 // Pop all the failure blocks that were inserted due to nesting of 218 // pdl_interp.iterate. 219 while (failureBlockStack.back() != failureBlock) { 220 failureBlockStack.pop_back(); 221 assert(!failureBlockStack.empty() && "unable to locate failure block"); 222 } 223 224 // Pop the new failure block. 225 if (failureNode) 226 failureBlockStack.pop_back(); 227 228 if (isOperationValue) 229 locOps.remove(val); 230 231 return block; 232 } 233 234 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 235 if (Value val = values.lookup(pos)) 236 return val; 237 238 // Get the value for the parent position. 239 Value parentVal = getValueAt(currentBlock, pos->getParent()); 240 241 // TODO: Use a location from the position. 242 Location loc = parentVal.getLoc(); 243 builder.setInsertionPointToEnd(currentBlock); 244 Value value; 245 switch (pos->getKind()) { 246 case Predicates::OperationPos: { 247 auto *operationPos = cast<OperationPosition>(pos); 248 if (!operationPos->isUpward()) { 249 // Standard (downward) traversal which directly follows the defining op. 250 value = builder.create<pdl_interp::GetDefiningOpOp>( 251 loc, builder.getType<pdl::OperationType>(), parentVal); 252 break; 253 } 254 255 // The first operation retrieves the representative value of a range. 256 // This applies only when the parent is a range of values. 257 if (parentVal.getType().isa<pdl::RangeType>()) 258 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 259 else 260 value = parentVal; 261 262 // The second operation retrieves the users. 263 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 264 265 // The third operation iterates over them. 266 assert(!failureBlockStack.empty() && "expected valid failure block"); 267 auto foreach = builder.create<pdl_interp::ForEachOp>( 268 loc, value, failureBlockStack.back(), /*initLoop=*/true); 269 value = foreach.getLoopVariable(); 270 271 // Create the success and continuation blocks. 272 Block *successBlock = builder.createBlock(&foreach.region()); 273 Block *continueBlock = builder.createBlock(successBlock); 274 builder.create<pdl_interp::ContinueOp>(loc); 275 failureBlockStack.push_back(continueBlock); 276 277 // The fourth operation extracts the operand(s) of the user at the specified 278 // index (which can be None, indicating all operands). 279 builder.setInsertionPointToStart(&foreach.region().front()); 280 Value operands = builder.create<pdl_interp::GetOperandsOp>( 281 loc, parentVal.getType(), value, operationPos->getIndex()); 282 283 // The fifth operation compares the operands to the parent value / range. 284 builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands, 285 successBlock, continueBlock); 286 currentBlock = successBlock; 287 break; 288 } 289 case Predicates::OperandPos: { 290 auto *operandPos = cast<OperandPosition>(pos); 291 value = builder.create<pdl_interp::GetOperandOp>( 292 loc, builder.getType<pdl::ValueType>(), parentVal, 293 operandPos->getOperandNumber()); 294 break; 295 } 296 case Predicates::OperandGroupPos: { 297 auto *operandPos = cast<OperandGroupPosition>(pos); 298 Type valueTy = builder.getType<pdl::ValueType>(); 299 value = builder.create<pdl_interp::GetOperandsOp>( 300 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 301 parentVal, operandPos->getOperandGroupNumber()); 302 break; 303 } 304 case Predicates::AttributePos: { 305 auto *attrPos = cast<AttributePosition>(pos); 306 value = builder.create<pdl_interp::GetAttributeOp>( 307 loc, builder.getType<pdl::AttributeType>(), parentVal, 308 attrPos->getName().strref()); 309 break; 310 } 311 case Predicates::TypePos: { 312 if (parentVal.getType().isa<pdl::AttributeType>()) 313 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 314 else 315 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 316 break; 317 } 318 case Predicates::ResultPos: { 319 auto *resPos = cast<ResultPosition>(pos); 320 value = builder.create<pdl_interp::GetResultOp>( 321 loc, builder.getType<pdl::ValueType>(), parentVal, 322 resPos->getResultNumber()); 323 break; 324 } 325 case Predicates::ResultGroupPos: { 326 auto *resPos = cast<ResultGroupPosition>(pos); 327 Type valueTy = builder.getType<pdl::ValueType>(); 328 value = builder.create<pdl_interp::GetResultsOp>( 329 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 330 parentVal, resPos->getResultGroupNumber()); 331 break; 332 } 333 default: 334 llvm_unreachable("Generating unknown Position getter"); 335 break; 336 } 337 338 values.insert(pos, value); 339 return value; 340 } 341 342 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 343 Value val) { 344 Location loc = val.getLoc(); 345 Qualifier *question = boolNode->getQuestion(); 346 Qualifier *answer = boolNode->getAnswer(); 347 Region *region = currentBlock->getParent(); 348 349 // Execute the getValue queries first, so that we create success 350 // matcher in the correct (possibly nested) region. 351 SmallVector<Value> args; 352 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 353 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 354 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 355 for (Position *position : std::get<1>(cstQuestion->getValue())) 356 args.push_back(getValueAt(currentBlock, position)); 357 } 358 359 // Generate the matcher in the current (potentially nested) region 360 // and get the failure successor. 361 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 362 Block *failure = failureBlockStack.back(); 363 364 // Finally, create the predicate. 365 builder.setInsertionPointToEnd(currentBlock); 366 Predicates::Kind kind = question->getKind(); 367 switch (kind) { 368 case Predicates::IsNotNullQuestion: 369 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 370 break; 371 case Predicates::OperationNameQuestion: { 372 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 373 builder.create<pdl_interp::CheckOperationNameOp>( 374 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 375 break; 376 } 377 case Predicates::TypeQuestion: { 378 auto *ans = cast<TypeAnswer>(answer); 379 if (val.getType().isa<pdl::RangeType>()) 380 builder.create<pdl_interp::CheckTypesOp>( 381 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 382 else 383 builder.create<pdl_interp::CheckTypeOp>( 384 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 385 break; 386 } 387 case Predicates::AttributeQuestion: { 388 auto *ans = cast<AttributeAnswer>(answer); 389 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 390 success, failure); 391 break; 392 } 393 case Predicates::OperandCountAtLeastQuestion: 394 case Predicates::OperandCountQuestion: 395 builder.create<pdl_interp::CheckOperandCountOp>( 396 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 397 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 398 success, failure); 399 break; 400 case Predicates::ResultCountAtLeastQuestion: 401 case Predicates::ResultCountQuestion: 402 builder.create<pdl_interp::CheckResultCountOp>( 403 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 404 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 405 success, failure); 406 break; 407 case Predicates::EqualToQuestion: { 408 bool trueAnswer = isa<TrueAnswer>(answer); 409 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 410 trueAnswer ? success : failure, 411 trueAnswer ? failure : success); 412 break; 413 } 414 case Predicates::ConstraintQuestion: { 415 auto value = cast<ConstraintQuestion>(question)->getValue(); 416 builder.create<pdl_interp::ApplyConstraintOp>( 417 loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(), 418 success, failure); 419 break; 420 } 421 default: 422 llvm_unreachable("Generating unknown Predicate operation"); 423 } 424 } 425 426 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 427 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 428 llvm::MapVector<Qualifier *, Block *> &dests) { 429 std::vector<ValT> values; 430 std::vector<Block *> blocks; 431 values.reserve(dests.size()); 432 blocks.reserve(dests.size()); 433 for (const auto &it : dests) { 434 blocks.push_back(it.second); 435 values.push_back(cast<PredT>(it.first)->getValue()); 436 } 437 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 438 } 439 440 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 441 Value val) { 442 Qualifier *question = switchNode->getQuestion(); 443 Region *region = currentBlock->getParent(); 444 Block *defaultDest = failureBlockStack.back(); 445 446 // If the switch question is not an exact answer, i.e. for the `at_least` 447 // cases, we generate a special block sequence. 448 Predicates::Kind kind = question->getKind(); 449 if (kind == Predicates::OperandCountAtLeastQuestion || 450 kind == Predicates::ResultCountAtLeastQuestion) { 451 // Order the children such that the cases are in reverse numerical order. 452 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 453 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 454 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 455 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 456 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 457 }); 458 459 // Build the destination for each child using the next highest child as a 460 // a failure destination. This essentially creates the following control 461 // flow: 462 // 463 // if (operand_count < 1) 464 // goto failure 465 // if (child1.match()) 466 // ... 467 // 468 // if (operand_count < 2) 469 // goto failure 470 // if (child2.match()) 471 // ... 472 // 473 // failure: 474 // ... 475 // 476 failureBlockStack.push_back(defaultDest); 477 Location loc = val.getLoc(); 478 for (unsigned idx : sortedChildren) { 479 auto &child = switchNode->getChild(idx); 480 Block *childBlock = generateMatcher(*child.second, *region); 481 Block *predicateBlock = builder.createBlock(childBlock); 482 builder.setInsertionPointToEnd(predicateBlock); 483 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 484 switch (kind) { 485 case Predicates::OperandCountAtLeastQuestion: 486 builder.create<pdl_interp::CheckOperandCountOp>( 487 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 488 break; 489 case Predicates::ResultCountAtLeastQuestion: 490 builder.create<pdl_interp::CheckResultCountOp>( 491 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 492 break; 493 default: 494 llvm_unreachable("Generating invalid AtLeast operation"); 495 } 496 failureBlockStack.back() = predicateBlock; 497 } 498 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 499 currentBlock->getOperations().splice(currentBlock->end(), 500 firstPredicateBlock->getOperations()); 501 firstPredicateBlock->erase(); 502 return; 503 } 504 505 // Otherwise, generate each of the children and generate an interpreter 506 // switch. 507 llvm::MapVector<Qualifier *, Block *> children; 508 for (auto &it : switchNode->getChildren()) 509 children.insert({it.first, generateMatcher(*it.second, *region)}); 510 builder.setInsertionPointToEnd(currentBlock); 511 512 switch (question->getKind()) { 513 case Predicates::OperandCountQuestion: 514 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 515 int32_t>(val, defaultDest, builder, children); 516 case Predicates::ResultCountQuestion: 517 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 518 int32_t>(val, defaultDest, builder, children); 519 case Predicates::OperationNameQuestion: 520 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 521 OperationNameAnswer>(val, defaultDest, builder, 522 children); 523 case Predicates::TypeQuestion: 524 if (val.getType().isa<pdl::RangeType>()) { 525 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 526 val, defaultDest, builder, children); 527 } 528 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 529 val, defaultDest, builder, children); 530 case Predicates::AttributeQuestion: 531 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 532 val, defaultDest, builder, children); 533 default: 534 llvm_unreachable("Generating unknown switch predicate."); 535 } 536 } 537 538 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 539 pdl::PatternOp pattern = successNode->getPattern(); 540 Value root = successNode->getRoot(); 541 542 // Generate a rewriter for the pattern this success node represents, and track 543 // any values used from the match region. 544 SmallVector<Position *, 8> usedMatchValues; 545 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 546 547 // Process any values used in the rewrite that are defined in the match. 548 std::vector<Value> mappedMatchValues; 549 mappedMatchValues.reserve(usedMatchValues.size()); 550 for (Position *position : usedMatchValues) 551 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 552 553 // Collect the set of operations generated by the rewriter. 554 SmallVector<StringRef, 4> generatedOps; 555 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 556 generatedOps.push_back(*op.name()); 557 ArrayAttr generatedOpsAttr; 558 if (!generatedOps.empty()) 559 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 560 561 // Grab the root kind if present. 562 StringAttr rootKindAttr; 563 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 564 if (Optional<StringRef> rootKind = rootOp.name()) 565 rootKindAttr = builder.getStringAttr(*rootKind); 566 567 builder.setInsertionPointToEnd(currentBlock); 568 builder.create<pdl_interp::RecordMatchOp>( 569 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 570 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 571 failureBlockStack.back()); 572 } 573 574 SymbolRefAttr PatternLowering::generateRewriter( 575 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 576 FuncOp rewriterFunc = 577 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 578 builder.getFunctionType(llvm::None, llvm::None)); 579 rewriterSymbolTable.insert(rewriterFunc); 580 581 // Generate the rewriter function body. 582 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 583 584 // Map an input operand of the pattern to a generated interpreter value. 585 DenseMap<Value, Value> rewriteValues; 586 auto mapRewriteValue = [&](Value oldValue) { 587 Value &newValue = rewriteValues[oldValue]; 588 if (newValue) 589 return newValue; 590 591 // Prefer materializing constants directly when possible. 592 Operation *oldOp = oldValue.getDefiningOp(); 593 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 594 if (Attribute value = attrOp.valueAttr()) { 595 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 596 attrOp.getLoc(), value); 597 } 598 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 599 if (TypeAttr type = typeOp.typeAttr()) { 600 return newValue = builder.create<pdl_interp::CreateTypeOp>( 601 typeOp.getLoc(), type); 602 } 603 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 604 if (ArrayAttr type = typeOp.typesAttr()) { 605 return newValue = builder.create<pdl_interp::CreateTypesOp>( 606 typeOp.getLoc(), typeOp.getType(), type); 607 } 608 } 609 610 // Otherwise, add this as an input to the rewriter. 611 Position *inputPos = valueToPosition.lookup(oldValue); 612 assert(inputPos && "expected value to be a pattern input"); 613 usedMatchValues.push_back(inputPos); 614 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 615 }; 616 617 // If this is a custom rewriter, simply dispatch to the registered rewrite 618 // method. 619 pdl::RewriteOp rewriter = pattern.getRewriter(); 620 if (StringAttr rewriteName = rewriter.nameAttr()) { 621 SmallVector<Value> args; 622 if (rewriter.root()) 623 args.push_back(mapRewriteValue(rewriter.root())); 624 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 625 args.append(mappedArgs.begin(), mappedArgs.end()); 626 builder.create<pdl_interp::ApplyRewriteOp>( 627 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 628 rewriter.externalConstParamsAttr()); 629 } else { 630 // Otherwise this is a dag rewriter defined using PDL operations. 631 for (Operation &rewriteOp : *rewriter.getBody()) { 632 llvm::TypeSwitch<Operation *>(&rewriteOp) 633 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 634 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 635 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 636 this->generateRewriter(op, rewriteValues, mapRewriteValue); 637 }); 638 } 639 } 640 641 // Update the signature of the rewrite function. 642 rewriterFunc.setType(builder.getFunctionType( 643 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 644 /*results=*/llvm::None)); 645 646 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 647 return SymbolRefAttr::get( 648 builder.getContext(), 649 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 650 SymbolRefAttr::get(rewriterFunc)); 651 } 652 653 void PatternLowering::generateRewriter( 654 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 655 function_ref<Value(Value)> mapRewriteValue) { 656 SmallVector<Value, 2> arguments; 657 for (Value argument : rewriteOp.args()) 658 arguments.push_back(mapRewriteValue(argument)); 659 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 660 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 661 arguments, rewriteOp.constParamsAttr()); 662 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 663 rewriteValues[std::get<0>(it)] = std::get<1>(it); 664 } 665 666 void PatternLowering::generateRewriter( 667 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 668 function_ref<Value(Value)> mapRewriteValue) { 669 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 670 attrOp.getLoc(), attrOp.valueAttr()); 671 rewriteValues[attrOp] = newAttr; 672 } 673 674 void PatternLowering::generateRewriter( 675 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 676 function_ref<Value(Value)> mapRewriteValue) { 677 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 678 mapRewriteValue(eraseOp.operation())); 679 } 680 681 void PatternLowering::generateRewriter( 682 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 683 function_ref<Value(Value)> mapRewriteValue) { 684 SmallVector<Value, 4> operands; 685 for (Value operand : operationOp.operands()) 686 operands.push_back(mapRewriteValue(operand)); 687 688 SmallVector<Value, 4> attributes; 689 for (Value attr : operationOp.attributes()) 690 attributes.push_back(mapRewriteValue(attr)); 691 692 SmallVector<Value, 2> types; 693 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 694 mapRewriteValue); 695 696 // Create the new operation. 697 Location loc = operationOp.getLoc(); 698 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 699 loc, *operationOp.name(), types, operands, attributes, 700 operationOp.attributeNames()); 701 rewriteValues[operationOp.op()] = createdOp; 702 703 // Generate accesses for any results that have their types constrained. 704 // Handle the case where there is a single range representing all of the 705 // result types. 706 OperandRange resultTys = operationOp.types(); 707 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 708 Value &type = rewriteValues[resultTys[0]]; 709 if (!type) { 710 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 711 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 712 } 713 return; 714 } 715 716 // Otherwise, populate the individual results. 717 bool seenVariableLength = false; 718 Type valueTy = builder.getType<pdl::ValueType>(); 719 Type valueRangeTy = pdl::RangeType::get(valueTy); 720 for (auto it : llvm::enumerate(resultTys)) { 721 Value &type = rewriteValues[it.value()]; 722 if (type) 723 continue; 724 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 725 seenVariableLength |= isVariadic; 726 727 // After a variable length result has been seen, we need to use result 728 // groups because the exact index of the result is not statically known. 729 Value resultVal; 730 if (seenVariableLength) 731 resultVal = builder.create<pdl_interp::GetResultsOp>( 732 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 733 else 734 resultVal = builder.create<pdl_interp::GetResultOp>( 735 loc, valueTy, createdOp, it.index()); 736 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 737 } 738 } 739 740 void PatternLowering::generateRewriter( 741 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 742 function_ref<Value(Value)> mapRewriteValue) { 743 SmallVector<Value, 4> replOperands; 744 745 // If the replacement was another operation, get its results. `pdl` allows 746 // for using an operation for simplicitly, but the interpreter isn't as 747 // user facing. 748 if (Value replOp = replaceOp.replOperation()) { 749 // Don't use replace if we know the replaced operation has no results. 750 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 751 if (!opOp || !opOp.types().empty()) { 752 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 753 replOp.getLoc(), mapRewriteValue(replOp))); 754 } 755 } else { 756 for (Value operand : replaceOp.replValues()) 757 replOperands.push_back(mapRewriteValue(operand)); 758 } 759 760 // If there are no replacement values, just create an erase instead. 761 if (replOperands.empty()) { 762 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 763 mapRewriteValue(replaceOp.operation())); 764 return; 765 } 766 767 builder.create<pdl_interp::ReplaceOp>( 768 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 769 } 770 771 void PatternLowering::generateRewriter( 772 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 773 function_ref<Value(Value)> mapRewriteValue) { 774 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 775 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 776 mapRewriteValue(resultOp.parent()), resultOp.index()); 777 } 778 779 void PatternLowering::generateRewriter( 780 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 781 function_ref<Value(Value)> mapRewriteValue) { 782 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 783 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 784 resultOp.index()); 785 } 786 787 void PatternLowering::generateRewriter( 788 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 789 function_ref<Value(Value)> mapRewriteValue) { 790 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 791 // type. 792 if (TypeAttr typeAttr = typeOp.typeAttr()) { 793 rewriteValues[typeOp] = 794 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 795 } 796 } 797 798 void PatternLowering::generateRewriter( 799 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 800 function_ref<Value(Value)> mapRewriteValue) { 801 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 802 // type. 803 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 804 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 805 typeOp.getLoc(), typeOp.getType(), typeAttr); 806 } 807 } 808 809 void PatternLowering::generateOperationResultTypeRewriter( 810 pdl::OperationOp op, SmallVectorImpl<Value> &types, 811 DenseMap<Value, Value> &rewriteValues, 812 function_ref<Value(Value)> mapRewriteValue) { 813 // Look for an operation that was replaced by `op`. The result types will be 814 // inferred from the results that were replaced. 815 Block *rewriterBlock = op->getBlock(); 816 Value replacedOp; 817 for (OpOperand &use : op.op().getUses()) { 818 // Check that the use corresponds to a ReplaceOp and that it is the 819 // replacement value, not the operation being replaced. 820 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 821 if (!replOpUser || use.getOperandNumber() == 0) 822 continue; 823 // Make sure the replaced operation was defined before this one. 824 Value replOpVal = replOpUser.operation(); 825 Operation *replacedOp = replOpVal.getDefiningOp(); 826 if (replacedOp->getBlock() == rewriterBlock && 827 !replacedOp->isBeforeInBlock(op)) 828 continue; 829 830 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 831 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 832 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 833 replacedOp->getLoc(), replacedOpResults)); 834 return; 835 } 836 837 // Check if the operation has type inference support. 838 if (op.hasTypeInference()) { 839 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 840 return; 841 } 842 843 // Otherwise, handle inference for each of the result types individually. 844 OperandRange resultTypeValues = op.types(); 845 types.reserve(resultTypeValues.size()); 846 for (auto it : llvm::enumerate(resultTypeValues)) { 847 Value resultType = it.value(); 848 849 // Check for an already translated value. 850 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 851 types.push_back(existingRewriteValue); 852 continue; 853 } 854 855 // Check for an input from the matcher. 856 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 857 types.push_back(mapRewriteValue(resultType)); 858 continue; 859 } 860 861 // The verifier asserts that the result types of each pdl.operation can be 862 // inferred. If we reach here, there is a bug either in the logic above or 863 // in the verifier for pdl.operation. 864 op->emitOpError() << "unable to infer result type for operation"; 865 llvm_unreachable("unable to infer result type for operation"); 866 } 867 } 868 869 //===----------------------------------------------------------------------===// 870 // Conversion Pass 871 //===----------------------------------------------------------------------===// 872 873 namespace { 874 struct PDLToPDLInterpPass 875 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 876 void runOnOperation() final; 877 }; 878 } // namespace 879 880 /// Convert the given module containing PDL pattern operations into a PDL 881 /// Interpreter operations. 882 void PDLToPDLInterpPass::runOnOperation() { 883 ModuleOp module = getOperation(); 884 885 // Create the main matcher function This function contains all of the match 886 // related functionality from patterns in the module. 887 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 888 FuncOp matcherFunc = builder.create<FuncOp>( 889 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 890 builder.getFunctionType(builder.getType<pdl::OperationType>(), 891 /*results=*/llvm::None), 892 /*attrs=*/llvm::None); 893 894 // Create a nested module to hold the functions invoked for rewriting the IR 895 // after a successful match. 896 ModuleOp rewriterModule = builder.create<ModuleOp>( 897 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 898 899 // Generate the code for the patterns within the module. 900 PatternLowering generator(matcherFunc, rewriterModule); 901 generator.lower(module); 902 903 // After generation, delete all of the pattern operations. 904 for (pdl::PatternOp pattern : 905 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 906 pattern.erase(); 907 } 908 909 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 910 return std::make_unique<PDLToPDLInterpPass>(); 911 } 912