1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node, in the specified region. 47 Block *generateMatcher(MatcherNode &node, Region ®ion); 48 49 /// Get or create an access to the provided positional value in the current 50 /// block. This operation may mutate the provided block pointer if nested 51 /// regions (i.e., pdl_interp.iterate) are required. 52 Value getValueAt(Block *¤tBlock, Position *pos); 53 54 /// Create the interpreter predicate operations. This operation may mutate the 55 /// provided current block pointer if nested regions (iterates) are required. 56 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 57 58 /// Create the interpreter switch / predicate operations, with several case 59 /// destinations. This operation never mutates the provided current block 60 /// pointer, because the switch operation does not need Values beyond `val`. 61 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 62 63 /// Create the interpreter operations to record a successful pattern match 64 /// using the contained root operation. This operation may mutate the current 65 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 66 void generate(SuccessNode *successNode, Block *¤tBlock); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. 104 void generateOperationResultTypeRewriter( 105 pdl::OperationOp op, SmallVectorImpl<Value> &types, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 109 /// A builder to use when generating interpreter operations. 110 OpBuilder builder; 111 112 /// The matcher function used for all match related logic within PDL patterns. 113 FuncOp matcherFunc; 114 115 /// The rewriter module containing the all rewrite related logic within PDL 116 /// patterns. 117 ModuleOp rewriterModule; 118 119 /// The symbol table of the rewriter module used for insertion. 120 SymbolTable rewriterSymbolTable; 121 122 /// A scoped map connecting a position with the corresponding interpreter 123 /// value. 124 ValueMap values; 125 126 /// A stack of blocks used as the failure destination for matcher nodes that 127 /// don't have an explicit failure path. 128 SmallVector<Block *, 8> failureBlockStack; 129 130 /// A mapping between values defined in a pattern match, and the corresponding 131 /// positional value. 132 DenseMap<Value, Position *> valueToPosition; 133 134 /// The set of operation values whose whose location will be used for newly 135 /// generated operations. 136 SetVector<Value> locOps; 137 }; 138 } // end anonymous namespace 139 140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 141 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 142 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 143 144 void PatternLowering::lower(ModuleOp module) { 145 PredicateUniquer predicateUniquer; 146 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 147 148 // Define top-level scope for the arguments to the matcher function. 149 ValueMapScope topLevelValueScope(values); 150 151 // Insert the root operation, i.e. argument to the matcher, at the root 152 // position. 153 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 154 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 155 156 // Generate a root matcher node from the provided PDL module. 157 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 158 module, predicateBuilder, valueToPosition); 159 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 160 assert(failureBlockStack.empty() && "failed to empty the stack"); 161 162 // After generation, merged the first matched block into the entry. 163 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 164 firstMatcherBlock->getOperations()); 165 firstMatcherBlock->erase(); 166 } 167 168 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 169 // Push a new scope for the values used by this matcher. 170 Block *block = ®ion.emplaceBlock(); 171 ValueMapScope scope(values); 172 173 // If this is the return node, simply insert the corresponding interpreter 174 // finalize. 175 if (isa<ExitNode>(node)) { 176 builder.setInsertionPointToEnd(block); 177 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 178 return block; 179 } 180 181 // Get the next block in the match sequence. 182 // This is intentionally executed first, before we get the value for the 183 // position associated with the node, so that we preserve an "there exist" 184 // semantics: if getting a value requires an upward traversal (going from a 185 // value to its consumers), we want to perform the check on all the consumers 186 // before we pass control to the failure node. 187 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 188 Block *failureBlock; 189 if (failureNode) { 190 failureBlock = generateMatcher(*failureNode, region); 191 failureBlockStack.push_back(failureBlock); 192 } else { 193 assert(!failureBlockStack.empty() && "expected valid failure block"); 194 failureBlock = failureBlockStack.back(); 195 } 196 197 // If this node contains a position, get the corresponding value for this 198 // block. 199 Block *currentBlock = block; 200 Position *position = node.getPosition(); 201 Value val = position ? getValueAt(currentBlock, position) : Value(); 202 203 // If this value corresponds to an operation, record that we are going to use 204 // its location as part of a fused location. 205 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 206 if (isOperationValue) 207 locOps.insert(val); 208 209 // Dispatch to the correct method based on derived node type. 210 TypeSwitch<MatcherNode *>(&node) 211 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 212 this->generate(derivedNode, currentBlock, val); 213 }) 214 .Case([&](SuccessNode *successNode) { 215 generate(successNode, currentBlock); 216 }); 217 218 // Pop all the failure blocks that were inserted due to nesting of 219 // pdl_interp.iterate. 220 while (failureBlockStack.back() != failureBlock) { 221 failureBlockStack.pop_back(); 222 assert(!failureBlockStack.empty() && "unable to locate failure block"); 223 } 224 225 // Pop the new failure block. 226 if (failureNode) 227 failureBlockStack.pop_back(); 228 229 if (isOperationValue) 230 locOps.remove(val); 231 232 return block; 233 } 234 235 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 236 if (Value val = values.lookup(pos)) 237 return val; 238 239 // Get the value for the parent position. 240 Value parentVal = getValueAt(currentBlock, pos->getParent()); 241 242 // TODO: Use a location from the position. 243 Location loc = parentVal.getLoc(); 244 builder.setInsertionPointToEnd(currentBlock); 245 Value value; 246 switch (pos->getKind()) { 247 case Predicates::OperationPos: { 248 auto *operationPos = cast<OperationPosition>(pos); 249 if (!operationPos->isUpward()) { 250 // Standard (downward) traversal which directly follows the defining op. 251 value = builder.create<pdl_interp::GetDefiningOpOp>( 252 loc, builder.getType<pdl::OperationType>(), parentVal); 253 break; 254 } 255 256 // The first operation retrieves the representative value of a range. 257 // This applies only when the parent is a range of values. 258 if (parentVal.getType().isa<pdl::RangeType>()) 259 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 260 else 261 value = parentVal; 262 263 // The second operation retrieves the users. 264 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 265 266 // The third operation iterates over them. 267 assert(!failureBlockStack.empty() && "expected valid failure block"); 268 auto foreach = builder.create<pdl_interp::ForEachOp>( 269 loc, value, failureBlockStack.back(), /*initLoop=*/true); 270 value = foreach.getLoopVariable(); 271 272 // Create the success and continuation blocks. 273 Block *successBlock = builder.createBlock(&foreach.region()); 274 Block *continueBlock = builder.createBlock(successBlock); 275 builder.create<pdl_interp::ContinueOp>(loc); 276 failureBlockStack.push_back(continueBlock); 277 278 // The fourth operation extracts the operand(s) of the user at the specified 279 // index (which can be None, indicating all operands). 280 builder.setInsertionPointToStart(&foreach.region().front()); 281 Value operands = builder.create<pdl_interp::GetOperandsOp>( 282 loc, parentVal.getType(), value, operationPos->getIndex()); 283 284 // The fifth operation compares the operands to the parent value / range. 285 builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands, 286 successBlock, continueBlock); 287 currentBlock = successBlock; 288 break; 289 } 290 case Predicates::OperandPos: { 291 auto *operandPos = cast<OperandPosition>(pos); 292 value = builder.create<pdl_interp::GetOperandOp>( 293 loc, builder.getType<pdl::ValueType>(), parentVal, 294 operandPos->getOperandNumber()); 295 break; 296 } 297 case Predicates::OperandGroupPos: { 298 auto *operandPos = cast<OperandGroupPosition>(pos); 299 Type valueTy = builder.getType<pdl::ValueType>(); 300 value = builder.create<pdl_interp::GetOperandsOp>( 301 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 302 parentVal, operandPos->getOperandGroupNumber()); 303 break; 304 } 305 case Predicates::AttributePos: { 306 auto *attrPos = cast<AttributePosition>(pos); 307 value = builder.create<pdl_interp::GetAttributeOp>( 308 loc, builder.getType<pdl::AttributeType>(), parentVal, 309 attrPos->getName().strref()); 310 break; 311 } 312 case Predicates::TypePos: { 313 if (parentVal.getType().isa<pdl::AttributeType>()) 314 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 315 else 316 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 317 break; 318 } 319 case Predicates::ResultPos: { 320 auto *resPos = cast<ResultPosition>(pos); 321 value = builder.create<pdl_interp::GetResultOp>( 322 loc, builder.getType<pdl::ValueType>(), parentVal, 323 resPos->getResultNumber()); 324 break; 325 } 326 case Predicates::ResultGroupPos: { 327 auto *resPos = cast<ResultGroupPosition>(pos); 328 Type valueTy = builder.getType<pdl::ValueType>(); 329 value = builder.create<pdl_interp::GetResultsOp>( 330 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 331 parentVal, resPos->getResultGroupNumber()); 332 break; 333 } 334 default: 335 llvm_unreachable("Generating unknown Position getter"); 336 break; 337 } 338 339 values.insert(pos, value); 340 return value; 341 } 342 343 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 344 Value val) { 345 Location loc = val.getLoc(); 346 Qualifier *question = boolNode->getQuestion(); 347 Qualifier *answer = boolNode->getAnswer(); 348 Region *region = currentBlock->getParent(); 349 350 // Execute the getValue queries first, so that we create success 351 // matcher in the correct (possibly nested) region. 352 SmallVector<Value> args; 353 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 354 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 355 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 356 for (Position *position : std::get<1>(cstQuestion->getValue())) 357 args.push_back(getValueAt(currentBlock, position)); 358 } 359 360 // Generate the matcher in the current (potentially nested) region 361 // and get the failure successor. 362 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 363 Block *failure = failureBlockStack.back(); 364 365 // Finally, create the predicate. 366 builder.setInsertionPointToEnd(currentBlock); 367 Predicates::Kind kind = question->getKind(); 368 switch (kind) { 369 case Predicates::IsNotNullQuestion: 370 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 371 break; 372 case Predicates::OperationNameQuestion: { 373 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 374 builder.create<pdl_interp::CheckOperationNameOp>( 375 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 376 break; 377 } 378 case Predicates::TypeQuestion: { 379 auto *ans = cast<TypeAnswer>(answer); 380 if (val.getType().isa<pdl::RangeType>()) 381 builder.create<pdl_interp::CheckTypesOp>( 382 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 383 else 384 builder.create<pdl_interp::CheckTypeOp>( 385 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 386 break; 387 } 388 case Predicates::AttributeQuestion: { 389 auto *ans = cast<AttributeAnswer>(answer); 390 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 391 success, failure); 392 break; 393 } 394 case Predicates::OperandCountAtLeastQuestion: 395 case Predicates::OperandCountQuestion: 396 builder.create<pdl_interp::CheckOperandCountOp>( 397 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 398 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 399 success, failure); 400 break; 401 case Predicates::ResultCountAtLeastQuestion: 402 case Predicates::ResultCountQuestion: 403 builder.create<pdl_interp::CheckResultCountOp>( 404 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 405 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 406 success, failure); 407 break; 408 case Predicates::EqualToQuestion: { 409 bool trueAnswer = isa<TrueAnswer>(answer); 410 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 411 trueAnswer ? success : failure, 412 trueAnswer ? failure : success); 413 break; 414 } 415 case Predicates::ConstraintQuestion: { 416 auto value = cast<ConstraintQuestion>(question)->getValue(); 417 builder.create<pdl_interp::ApplyConstraintOp>( 418 loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(), 419 success, failure); 420 break; 421 } 422 default: 423 llvm_unreachable("Generating unknown Predicate operation"); 424 } 425 } 426 427 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 428 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 429 llvm::MapVector<Qualifier *, Block *> &dests) { 430 std::vector<ValT> values; 431 std::vector<Block *> blocks; 432 values.reserve(dests.size()); 433 blocks.reserve(dests.size()); 434 for (const auto &it : dests) { 435 blocks.push_back(it.second); 436 values.push_back(cast<PredT>(it.first)->getValue()); 437 } 438 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 439 } 440 441 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 442 Value val) { 443 Qualifier *question = switchNode->getQuestion(); 444 Region *region = currentBlock->getParent(); 445 Block *defaultDest = failureBlockStack.back(); 446 447 // If the switch question is not an exact answer, i.e. for the `at_least` 448 // cases, we generate a special block sequence. 449 Predicates::Kind kind = question->getKind(); 450 if (kind == Predicates::OperandCountAtLeastQuestion || 451 kind == Predicates::ResultCountAtLeastQuestion) { 452 // Order the children such that the cases are in reverse numerical order. 453 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 454 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 455 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 456 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 457 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 458 }); 459 460 // Build the destination for each child using the next highest child as a 461 // a failure destination. This essentially creates the following control 462 // flow: 463 // 464 // if (operand_count < 1) 465 // goto failure 466 // if (child1.match()) 467 // ... 468 // 469 // if (operand_count < 2) 470 // goto failure 471 // if (child2.match()) 472 // ... 473 // 474 // failure: 475 // ... 476 // 477 failureBlockStack.push_back(defaultDest); 478 Location loc = val.getLoc(); 479 for (unsigned idx : sortedChildren) { 480 auto &child = switchNode->getChild(idx); 481 Block *childBlock = generateMatcher(*child.second, *region); 482 Block *predicateBlock = builder.createBlock(childBlock); 483 builder.setInsertionPointToEnd(predicateBlock); 484 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 485 switch (kind) { 486 case Predicates::OperandCountAtLeastQuestion: 487 builder.create<pdl_interp::CheckOperandCountOp>( 488 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 489 break; 490 case Predicates::ResultCountAtLeastQuestion: 491 builder.create<pdl_interp::CheckResultCountOp>( 492 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 493 break; 494 default: 495 llvm_unreachable("Generating invalid AtLeast operation"); 496 } 497 failureBlockStack.back() = predicateBlock; 498 } 499 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 500 currentBlock->getOperations().splice(currentBlock->end(), 501 firstPredicateBlock->getOperations()); 502 firstPredicateBlock->erase(); 503 return; 504 } 505 506 // Otherwise, generate each of the children and generate an interpreter 507 // switch. 508 llvm::MapVector<Qualifier *, Block *> children; 509 for (auto &it : switchNode->getChildren()) 510 children.insert({it.first, generateMatcher(*it.second, *region)}); 511 builder.setInsertionPointToEnd(currentBlock); 512 513 switch (question->getKind()) { 514 case Predicates::OperandCountQuestion: 515 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 516 int32_t>(val, defaultDest, builder, children); 517 case Predicates::ResultCountQuestion: 518 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 519 int32_t>(val, defaultDest, builder, children); 520 case Predicates::OperationNameQuestion: 521 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 522 OperationNameAnswer>(val, defaultDest, builder, 523 children); 524 case Predicates::TypeQuestion: 525 if (val.getType().isa<pdl::RangeType>()) { 526 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 527 val, defaultDest, builder, children); 528 } 529 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 530 val, defaultDest, builder, children); 531 case Predicates::AttributeQuestion: 532 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 533 val, defaultDest, builder, children); 534 default: 535 llvm_unreachable("Generating unknown switch predicate."); 536 } 537 } 538 539 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 540 pdl::PatternOp pattern = successNode->getPattern(); 541 Value root = successNode->getRoot(); 542 543 // Generate a rewriter for the pattern this success node represents, and track 544 // any values used from the match region. 545 SmallVector<Position *, 8> usedMatchValues; 546 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 547 548 // Process any values used in the rewrite that are defined in the match. 549 std::vector<Value> mappedMatchValues; 550 mappedMatchValues.reserve(usedMatchValues.size()); 551 for (Position *position : usedMatchValues) 552 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 553 554 // Collect the set of operations generated by the rewriter. 555 SmallVector<StringRef, 4> generatedOps; 556 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 557 generatedOps.push_back(*op.name()); 558 ArrayAttr generatedOpsAttr; 559 if (!generatedOps.empty()) 560 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 561 562 // Grab the root kind if present. 563 StringAttr rootKindAttr; 564 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 565 if (Optional<StringRef> rootKind = rootOp.name()) 566 rootKindAttr = builder.getStringAttr(*rootKind); 567 568 builder.setInsertionPointToEnd(currentBlock); 569 builder.create<pdl_interp::RecordMatchOp>( 570 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 571 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 572 failureBlockStack.back()); 573 } 574 575 SymbolRefAttr PatternLowering::generateRewriter( 576 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 577 FuncOp rewriterFunc = 578 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 579 builder.getFunctionType(llvm::None, llvm::None)); 580 rewriterSymbolTable.insert(rewriterFunc); 581 582 // Generate the rewriter function body. 583 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 584 585 // Map an input operand of the pattern to a generated interpreter value. 586 DenseMap<Value, Value> rewriteValues; 587 auto mapRewriteValue = [&](Value oldValue) { 588 Value &newValue = rewriteValues[oldValue]; 589 if (newValue) 590 return newValue; 591 592 // Prefer materializing constants directly when possible. 593 Operation *oldOp = oldValue.getDefiningOp(); 594 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 595 if (Attribute value = attrOp.valueAttr()) { 596 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 597 attrOp.getLoc(), value); 598 } 599 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 600 if (TypeAttr type = typeOp.typeAttr()) { 601 return newValue = builder.create<pdl_interp::CreateTypeOp>( 602 typeOp.getLoc(), type); 603 } 604 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 605 if (ArrayAttr type = typeOp.typesAttr()) { 606 return newValue = builder.create<pdl_interp::CreateTypesOp>( 607 typeOp.getLoc(), typeOp.getType(), type); 608 } 609 } 610 611 // Otherwise, add this as an input to the rewriter. 612 Position *inputPos = valueToPosition.lookup(oldValue); 613 assert(inputPos && "expected value to be a pattern input"); 614 usedMatchValues.push_back(inputPos); 615 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 616 }; 617 618 // If this is a custom rewriter, simply dispatch to the registered rewrite 619 // method. 620 pdl::RewriteOp rewriter = pattern.getRewriter(); 621 if (StringAttr rewriteName = rewriter.nameAttr()) { 622 SmallVector<Value> args; 623 if (rewriter.root()) 624 args.push_back(mapRewriteValue(rewriter.root())); 625 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 626 args.append(mappedArgs.begin(), mappedArgs.end()); 627 builder.create<pdl_interp::ApplyRewriteOp>( 628 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 629 rewriter.externalConstParamsAttr()); 630 } else { 631 // Otherwise this is a dag rewriter defined using PDL operations. 632 for (Operation &rewriteOp : *rewriter.getBody()) { 633 llvm::TypeSwitch<Operation *>(&rewriteOp) 634 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 635 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 636 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 637 this->generateRewriter(op, rewriteValues, mapRewriteValue); 638 }); 639 } 640 } 641 642 // Update the signature of the rewrite function. 643 rewriterFunc.setType(builder.getFunctionType( 644 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 645 /*results=*/llvm::None)); 646 647 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 648 return SymbolRefAttr::get( 649 builder.getContext(), 650 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 651 SymbolRefAttr::get(rewriterFunc)); 652 } 653 654 void PatternLowering::generateRewriter( 655 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 656 function_ref<Value(Value)> mapRewriteValue) { 657 SmallVector<Value, 2> arguments; 658 for (Value argument : rewriteOp.args()) 659 arguments.push_back(mapRewriteValue(argument)); 660 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 661 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 662 arguments, rewriteOp.constParamsAttr()); 663 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 664 rewriteValues[std::get<0>(it)] = std::get<1>(it); 665 } 666 667 void PatternLowering::generateRewriter( 668 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 669 function_ref<Value(Value)> mapRewriteValue) { 670 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 671 attrOp.getLoc(), attrOp.valueAttr()); 672 rewriteValues[attrOp] = newAttr; 673 } 674 675 void PatternLowering::generateRewriter( 676 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 677 function_ref<Value(Value)> mapRewriteValue) { 678 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 679 mapRewriteValue(eraseOp.operation())); 680 } 681 682 void PatternLowering::generateRewriter( 683 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 684 function_ref<Value(Value)> mapRewriteValue) { 685 SmallVector<Value, 4> operands; 686 for (Value operand : operationOp.operands()) 687 operands.push_back(mapRewriteValue(operand)); 688 689 SmallVector<Value, 4> attributes; 690 for (Value attr : operationOp.attributes()) 691 attributes.push_back(mapRewriteValue(attr)); 692 693 SmallVector<Value, 2> types; 694 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 695 mapRewriteValue); 696 697 // Create the new operation. 698 Location loc = operationOp.getLoc(); 699 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 700 loc, *operationOp.name(), types, operands, attributes, 701 operationOp.attributeNames()); 702 rewriteValues[operationOp.op()] = createdOp; 703 704 // Generate accesses for any results that have their types constrained. 705 // Handle the case where there is a single range representing all of the 706 // result types. 707 OperandRange resultTys = operationOp.types(); 708 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 709 Value &type = rewriteValues[resultTys[0]]; 710 if (!type) { 711 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 712 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 713 } 714 return; 715 } 716 717 // Otherwise, populate the individual results. 718 bool seenVariableLength = false; 719 Type valueTy = builder.getType<pdl::ValueType>(); 720 Type valueRangeTy = pdl::RangeType::get(valueTy); 721 for (auto it : llvm::enumerate(resultTys)) { 722 Value &type = rewriteValues[it.value()]; 723 if (type) 724 continue; 725 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 726 seenVariableLength |= isVariadic; 727 728 // After a variable length result has been seen, we need to use result 729 // groups because the exact index of the result is not statically known. 730 Value resultVal; 731 if (seenVariableLength) 732 resultVal = builder.create<pdl_interp::GetResultsOp>( 733 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 734 else 735 resultVal = builder.create<pdl_interp::GetResultOp>( 736 loc, valueTy, createdOp, it.index()); 737 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 738 } 739 } 740 741 void PatternLowering::generateRewriter( 742 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 743 function_ref<Value(Value)> mapRewriteValue) { 744 SmallVector<Value, 4> replOperands; 745 746 // If the replacement was another operation, get its results. `pdl` allows 747 // for using an operation for simplicitly, but the interpreter isn't as 748 // user facing. 749 if (Value replOp = replaceOp.replOperation()) { 750 // Don't use replace if we know the replaced operation has no results. 751 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 752 if (!opOp || !opOp.types().empty()) { 753 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 754 replOp.getLoc(), mapRewriteValue(replOp))); 755 } 756 } else { 757 for (Value operand : replaceOp.replValues()) 758 replOperands.push_back(mapRewriteValue(operand)); 759 } 760 761 // If there are no replacement values, just create an erase instead. 762 if (replOperands.empty()) { 763 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 764 mapRewriteValue(replaceOp.operation())); 765 return; 766 } 767 768 builder.create<pdl_interp::ReplaceOp>( 769 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 770 } 771 772 void PatternLowering::generateRewriter( 773 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 774 function_ref<Value(Value)> mapRewriteValue) { 775 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 776 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 777 mapRewriteValue(resultOp.parent()), resultOp.index()); 778 } 779 780 void PatternLowering::generateRewriter( 781 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 782 function_ref<Value(Value)> mapRewriteValue) { 783 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 784 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 785 resultOp.index()); 786 } 787 788 void PatternLowering::generateRewriter( 789 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 790 function_ref<Value(Value)> mapRewriteValue) { 791 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 792 // type. 793 if (TypeAttr typeAttr = typeOp.typeAttr()) { 794 rewriteValues[typeOp] = 795 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 796 } 797 } 798 799 void PatternLowering::generateRewriter( 800 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 801 function_ref<Value(Value)> mapRewriteValue) { 802 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 803 // type. 804 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 805 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 806 typeOp.getLoc(), typeOp.getType(), typeAttr); 807 } 808 } 809 810 void PatternLowering::generateOperationResultTypeRewriter( 811 pdl::OperationOp op, SmallVectorImpl<Value> &types, 812 DenseMap<Value, Value> &rewriteValues, 813 function_ref<Value(Value)> mapRewriteValue) { 814 // Look for an operation that was replaced by `op`. The result types will be 815 // inferred from the results that were replaced. 816 Block *rewriterBlock = op->getBlock(); 817 Value replacedOp; 818 for (OpOperand &use : op.op().getUses()) { 819 // Check that the use corresponds to a ReplaceOp and that it is the 820 // replacement value, not the operation being replaced. 821 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 822 if (!replOpUser || use.getOperandNumber() == 0) 823 continue; 824 // Make sure the replaced operation was defined before this one. 825 Value replOpVal = replOpUser.operation(); 826 Operation *replacedOp = replOpVal.getDefiningOp(); 827 if (replacedOp->getBlock() == rewriterBlock && 828 !replacedOp->isBeforeInBlock(op)) 829 continue; 830 831 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 832 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 833 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 834 replacedOp->getLoc(), replacedOpResults)); 835 return; 836 } 837 838 // Check if the operation has type inference support. 839 if (op.hasTypeInference()) { 840 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 841 return; 842 } 843 844 // Otherwise, handle inference for each of the result types individually. 845 OperandRange resultTypeValues = op.types(); 846 types.reserve(resultTypeValues.size()); 847 for (auto it : llvm::enumerate(resultTypeValues)) { 848 Value resultType = it.value(); 849 850 // Check for an already translated value. 851 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 852 types.push_back(existingRewriteValue); 853 continue; 854 } 855 856 // Check for an input from the matcher. 857 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 858 types.push_back(mapRewriteValue(resultType)); 859 continue; 860 } 861 862 // The verifier asserts that the result types of each pdl.operation can be 863 // inferred. If we reach here, there is a bug either in the logic above or 864 // in the verifier for pdl.operation. 865 op->emitOpError() << "unable to infer result type for operation"; 866 llvm_unreachable("unable to infer result type for operation"); 867 } 868 } 869 870 //===----------------------------------------------------------------------===// 871 // Conversion Pass 872 //===----------------------------------------------------------------------===// 873 874 namespace { 875 struct PDLToPDLInterpPass 876 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 877 void runOnOperation() final; 878 }; 879 } // namespace 880 881 /// Convert the given module containing PDL pattern operations into a PDL 882 /// Interpreter operations. 883 void PDLToPDLInterpPass::runOnOperation() { 884 ModuleOp module = getOperation(); 885 886 // Create the main matcher function This function contains all of the match 887 // related functionality from patterns in the module. 888 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 889 FuncOp matcherFunc = builder.create<FuncOp>( 890 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 891 builder.getFunctionType(builder.getType<pdl::OperationType>(), 892 /*results=*/llvm::None), 893 /*attrs=*/llvm::None); 894 895 // Create a nested module to hold the functions invoked for rewriting the IR 896 // after a successful match. 897 ModuleOp rewriterModule = builder.create<ModuleOp>( 898 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 899 900 // Generate the code for the patterns within the module. 901 PatternLowering generator(matcherFunc, rewriterModule); 902 generator.lower(module); 903 904 // After generation, delete all of the pattern operations. 905 for (pdl::PatternOp pattern : 906 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 907 pattern.erase(); 908 } 909 910 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 911 return std::make_unique<PDLToPDLInterpPass>(); 912 } 913