1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node, in the specified region. 47 Block *generateMatcher(MatcherNode &node, Region ®ion); 48 49 /// Get or create an access to the provided positional value in the current 50 /// block. This operation may mutate the provided block pointer if nested 51 /// regions (i.e., pdl_interp.iterate) are required. 52 Value getValueAt(Block *¤tBlock, Position *pos); 53 54 /// Create the interpreter predicate operations. This operation may mutate the 55 /// provided current block pointer if nested regions (iterates) are required. 56 void generate(BoolNode *boolNode, Block *¤tBlock, Value val); 57 58 /// Create the interpreter switch / predicate operations, with several case 59 /// destinations. This operation never mutates the provided current block 60 /// pointer, because the switch operation does not need Values beyond `val`. 61 void generate(SwitchNode *switchNode, Block *currentBlock, Value val); 62 63 /// Create the interpreter operations to record a successful pattern match 64 /// using the contained root operation. This operation may mutate the current 65 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. 66 void generate(SuccessNode *successNode, Block *¤tBlock); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. 104 void generateOperationResultTypeRewriter( 105 pdl::OperationOp op, SmallVectorImpl<Value> &types, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 109 /// A builder to use when generating interpreter operations. 110 OpBuilder builder; 111 112 /// The matcher function used for all match related logic within PDL patterns. 113 pdl_interp::FuncOp matcherFunc; 114 115 /// The rewriter module containing the all rewrite related logic within PDL 116 /// patterns. 117 ModuleOp rewriterModule; 118 119 /// The symbol table of the rewriter module used for insertion. 120 SymbolTable rewriterSymbolTable; 121 122 /// A scoped map connecting a position with the corresponding interpreter 123 /// value. 124 ValueMap values; 125 126 /// A stack of blocks used as the failure destination for matcher nodes that 127 /// don't have an explicit failure path. 128 SmallVector<Block *, 8> failureBlockStack; 129 130 /// A mapping between values defined in a pattern match, and the corresponding 131 /// positional value. 132 DenseMap<Value, Position *> valueToPosition; 133 134 /// The set of operation values whose whose location will be used for newly 135 /// generated operations. 136 SetVector<Value> locOps; 137 }; 138 } // namespace 139 140 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, 141 ModuleOp rewriterModule) 142 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 143 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 144 145 void PatternLowering::lower(ModuleOp module) { 146 PredicateUniquer predicateUniquer; 147 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 148 149 // Define top-level scope for the arguments to the matcher function. 150 ValueMapScope topLevelValueScope(values); 151 152 // Insert the root operation, i.e. argument to the matcher, at the root 153 // position. 154 Block *matcherEntryBlock = &matcherFunc.front(); 155 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 156 157 // Generate a root matcher node from the provided PDL module. 158 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 159 module, predicateBuilder, valueToPosition); 160 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); 161 assert(failureBlockStack.empty() && "failed to empty the stack"); 162 163 // After generation, merged the first matched block into the entry. 164 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 165 firstMatcherBlock->getOperations()); 166 firstMatcherBlock->erase(); 167 } 168 169 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { 170 // Push a new scope for the values used by this matcher. 171 Block *block = ®ion.emplaceBlock(); 172 ValueMapScope scope(values); 173 174 // If this is the return node, simply insert the corresponding interpreter 175 // finalize. 176 if (isa<ExitNode>(node)) { 177 builder.setInsertionPointToEnd(block); 178 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 179 return block; 180 } 181 182 // Get the next block in the match sequence. 183 // This is intentionally executed first, before we get the value for the 184 // position associated with the node, so that we preserve an "there exist" 185 // semantics: if getting a value requires an upward traversal (going from a 186 // value to its consumers), we want to perform the check on all the consumers 187 // before we pass control to the failure node. 188 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 189 Block *failureBlock; 190 if (failureNode) { 191 failureBlock = generateMatcher(*failureNode, region); 192 failureBlockStack.push_back(failureBlock); 193 } else { 194 assert(!failureBlockStack.empty() && "expected valid failure block"); 195 failureBlock = failureBlockStack.back(); 196 } 197 198 // If this node contains a position, get the corresponding value for this 199 // block. 200 Block *currentBlock = block; 201 Position *position = node.getPosition(); 202 Value val = position ? getValueAt(currentBlock, position) : Value(); 203 204 // If this value corresponds to an operation, record that we are going to use 205 // its location as part of a fused location. 206 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 207 if (isOperationValue) 208 locOps.insert(val); 209 210 // Dispatch to the correct method based on derived node type. 211 TypeSwitch<MatcherNode *>(&node) 212 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { 213 this->generate(derivedNode, currentBlock, val); 214 }) 215 .Case([&](SuccessNode *successNode) { 216 generate(successNode, currentBlock); 217 }); 218 219 // Pop all the failure blocks that were inserted due to nesting of 220 // pdl_interp.iterate. 221 while (failureBlockStack.back() != failureBlock) { 222 failureBlockStack.pop_back(); 223 assert(!failureBlockStack.empty() && "unable to locate failure block"); 224 } 225 226 // Pop the new failure block. 227 if (failureNode) 228 failureBlockStack.pop_back(); 229 230 if (isOperationValue) 231 locOps.remove(val); 232 233 return block; 234 } 235 236 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { 237 if (Value val = values.lookup(pos)) 238 return val; 239 240 // Get the value for the parent position. 241 Value parentVal; 242 if (Position *parent = pos->getParent()) 243 parentVal = getValueAt(currentBlock, parent); 244 245 // TODO: Use a location from the position. 246 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); 247 builder.setInsertionPointToEnd(currentBlock); 248 Value value; 249 switch (pos->getKind()) { 250 case Predicates::OperationPos: { 251 auto *operationPos = cast<OperationPosition>(pos); 252 if (operationPos->isOperandDefiningOp()) 253 // Standard (downward) traversal which directly follows the defining op. 254 value = builder.create<pdl_interp::GetDefiningOpOp>( 255 loc, builder.getType<pdl::OperationType>(), parentVal); 256 else 257 // A passthrough operation position. 258 value = parentVal; 259 break; 260 } 261 case Predicates::UsersPos: { 262 auto *usersPos = cast<UsersPosition>(pos); 263 264 // The first operation retrieves the representative value of a range. 265 // This applies only when the parent is a range of values and we were 266 // requested to use a representative value (e.g., upward traversal). 267 if (parentVal.getType().isa<pdl::RangeType>() && 268 usersPos->useRepresentative()) 269 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); 270 else 271 value = parentVal; 272 273 // The second operation retrieves the users. 274 value = builder.create<pdl_interp::GetUsersOp>(loc, value); 275 break; 276 } 277 case Predicates::ForEachPos: { 278 assert(!failureBlockStack.empty() && "expected valid failure block"); 279 auto foreach = builder.create<pdl_interp::ForEachOp>( 280 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); 281 value = foreach.getLoopVariable(); 282 283 // Create the continuation block. 284 Block *continueBlock = builder.createBlock(&foreach.getRegion()); 285 builder.create<pdl_interp::ContinueOp>(loc); 286 failureBlockStack.push_back(continueBlock); 287 288 currentBlock = &foreach.getRegion().front(); 289 break; 290 } 291 case Predicates::OperandPos: { 292 auto *operandPos = cast<OperandPosition>(pos); 293 value = builder.create<pdl_interp::GetOperandOp>( 294 loc, builder.getType<pdl::ValueType>(), parentVal, 295 operandPos->getOperandNumber()); 296 break; 297 } 298 case Predicates::OperandGroupPos: { 299 auto *operandPos = cast<OperandGroupPosition>(pos); 300 Type valueTy = builder.getType<pdl::ValueType>(); 301 value = builder.create<pdl_interp::GetOperandsOp>( 302 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 303 parentVal, operandPos->getOperandGroupNumber()); 304 break; 305 } 306 case Predicates::AttributePos: { 307 auto *attrPos = cast<AttributePosition>(pos); 308 value = builder.create<pdl_interp::GetAttributeOp>( 309 loc, builder.getType<pdl::AttributeType>(), parentVal, 310 attrPos->getName().strref()); 311 break; 312 } 313 case Predicates::TypePos: { 314 if (parentVal.getType().isa<pdl::AttributeType>()) 315 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 316 else 317 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 318 break; 319 } 320 case Predicates::ResultPos: { 321 auto *resPos = cast<ResultPosition>(pos); 322 value = builder.create<pdl_interp::GetResultOp>( 323 loc, builder.getType<pdl::ValueType>(), parentVal, 324 resPos->getResultNumber()); 325 break; 326 } 327 case Predicates::ResultGroupPos: { 328 auto *resPos = cast<ResultGroupPosition>(pos); 329 Type valueTy = builder.getType<pdl::ValueType>(); 330 value = builder.create<pdl_interp::GetResultsOp>( 331 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 332 parentVal, resPos->getResultGroupNumber()); 333 break; 334 } 335 case Predicates::AttributeLiteralPos: { 336 auto *attrPos = cast<AttributeLiteralPosition>(pos); 337 value = 338 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); 339 break; 340 } 341 case Predicates::TypeLiteralPos: { 342 auto *typePos = cast<TypeLiteralPosition>(pos); 343 Attribute rawTypeAttr = typePos->getValue(); 344 if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>()) 345 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); 346 else 347 value = builder.create<pdl_interp::CreateTypesOp>( 348 loc, rawTypeAttr.cast<ArrayAttr>()); 349 break; 350 } 351 default: 352 llvm_unreachable("Generating unknown Position getter"); 353 break; 354 } 355 356 values.insert(pos, value); 357 return value; 358 } 359 360 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, 361 Value val) { 362 Location loc = val.getLoc(); 363 Qualifier *question = boolNode->getQuestion(); 364 Qualifier *answer = boolNode->getAnswer(); 365 Region *region = currentBlock->getParent(); 366 367 // Execute the getValue queries first, so that we create success 368 // matcher in the correct (possibly nested) region. 369 SmallVector<Value> args; 370 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { 371 args = {getValueAt(currentBlock, equalToQuestion->getValue())}; 372 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { 373 for (Position *position : cstQuestion->getArgs()) 374 args.push_back(getValueAt(currentBlock, position)); 375 } 376 377 // Generate the matcher in the current (potentially nested) region 378 // and get the failure successor. 379 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); 380 Block *failure = failureBlockStack.back(); 381 382 // Finally, create the predicate. 383 builder.setInsertionPointToEnd(currentBlock); 384 Predicates::Kind kind = question->getKind(); 385 switch (kind) { 386 case Predicates::IsNotNullQuestion: 387 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); 388 break; 389 case Predicates::OperationNameQuestion: { 390 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 391 builder.create<pdl_interp::CheckOperationNameOp>( 392 loc, val, opNameAnswer->getValue().getStringRef(), success, failure); 393 break; 394 } 395 case Predicates::TypeQuestion: { 396 auto *ans = cast<TypeAnswer>(answer); 397 if (val.getType().isa<pdl::RangeType>()) 398 builder.create<pdl_interp::CheckTypesOp>( 399 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); 400 else 401 builder.create<pdl_interp::CheckTypeOp>( 402 loc, val, ans->getValue().cast<TypeAttr>(), success, failure); 403 break; 404 } 405 case Predicates::AttributeQuestion: { 406 auto *ans = cast<AttributeAnswer>(answer); 407 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 408 success, failure); 409 break; 410 } 411 case Predicates::OperandCountAtLeastQuestion: 412 case Predicates::OperandCountQuestion: 413 builder.create<pdl_interp::CheckOperandCountOp>( 414 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 415 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 416 success, failure); 417 break; 418 case Predicates::ResultCountAtLeastQuestion: 419 case Predicates::ResultCountQuestion: 420 builder.create<pdl_interp::CheckResultCountOp>( 421 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 422 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 423 success, failure); 424 break; 425 case Predicates::EqualToQuestion: { 426 bool trueAnswer = isa<TrueAnswer>(answer); 427 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), 428 trueAnswer ? success : failure, 429 trueAnswer ? failure : success); 430 break; 431 } 432 case Predicates::ConstraintQuestion: { 433 auto *cstQuestion = cast<ConstraintQuestion>(question); 434 builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(), 435 args, success, failure); 436 break; 437 } 438 default: 439 llvm_unreachable("Generating unknown Predicate operation"); 440 } 441 } 442 443 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 444 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 445 llvm::MapVector<Qualifier *, Block *> &dests) { 446 std::vector<ValT> values; 447 std::vector<Block *> blocks; 448 values.reserve(dests.size()); 449 blocks.reserve(dests.size()); 450 for (const auto &it : dests) { 451 blocks.push_back(it.second); 452 values.push_back(cast<PredT>(it.first)->getValue()); 453 } 454 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 455 } 456 457 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, 458 Value val) { 459 Qualifier *question = switchNode->getQuestion(); 460 Region *region = currentBlock->getParent(); 461 Block *defaultDest = failureBlockStack.back(); 462 463 // If the switch question is not an exact answer, i.e. for the `at_least` 464 // cases, we generate a special block sequence. 465 Predicates::Kind kind = question->getKind(); 466 if (kind == Predicates::OperandCountAtLeastQuestion || 467 kind == Predicates::ResultCountAtLeastQuestion) { 468 // Order the children such that the cases are in reverse numerical order. 469 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 470 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 471 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 472 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 473 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 474 }); 475 476 // Build the destination for each child using the next highest child as a 477 // a failure destination. This essentially creates the following control 478 // flow: 479 // 480 // if (operand_count < 1) 481 // goto failure 482 // if (child1.match()) 483 // ... 484 // 485 // if (operand_count < 2) 486 // goto failure 487 // if (child2.match()) 488 // ... 489 // 490 // failure: 491 // ... 492 // 493 failureBlockStack.push_back(defaultDest); 494 Location loc = val.getLoc(); 495 for (unsigned idx : sortedChildren) { 496 auto &child = switchNode->getChild(idx); 497 Block *childBlock = generateMatcher(*child.second, *region); 498 Block *predicateBlock = builder.createBlock(childBlock); 499 builder.setInsertionPointToEnd(predicateBlock); 500 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); 501 switch (kind) { 502 case Predicates::OperandCountAtLeastQuestion: 503 builder.create<pdl_interp::CheckOperandCountOp>( 504 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 505 break; 506 case Predicates::ResultCountAtLeastQuestion: 507 builder.create<pdl_interp::CheckResultCountOp>( 508 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); 509 break; 510 default: 511 llvm_unreachable("Generating invalid AtLeast operation"); 512 } 513 failureBlockStack.back() = predicateBlock; 514 } 515 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 516 currentBlock->getOperations().splice(currentBlock->end(), 517 firstPredicateBlock->getOperations()); 518 firstPredicateBlock->erase(); 519 return; 520 } 521 522 // Otherwise, generate each of the children and generate an interpreter 523 // switch. 524 llvm::MapVector<Qualifier *, Block *> children; 525 for (auto &it : switchNode->getChildren()) 526 children.insert({it.first, generateMatcher(*it.second, *region)}); 527 builder.setInsertionPointToEnd(currentBlock); 528 529 switch (question->getKind()) { 530 case Predicates::OperandCountQuestion: 531 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 532 int32_t>(val, defaultDest, builder, children); 533 case Predicates::ResultCountQuestion: 534 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 535 int32_t>(val, defaultDest, builder, children); 536 case Predicates::OperationNameQuestion: 537 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 538 OperationNameAnswer>(val, defaultDest, builder, 539 children); 540 case Predicates::TypeQuestion: 541 if (val.getType().isa<pdl::RangeType>()) { 542 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 543 val, defaultDest, builder, children); 544 } 545 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 546 val, defaultDest, builder, children); 547 case Predicates::AttributeQuestion: 548 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 549 val, defaultDest, builder, children); 550 default: 551 llvm_unreachable("Generating unknown switch predicate."); 552 } 553 } 554 555 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { 556 pdl::PatternOp pattern = successNode->getPattern(); 557 Value root = successNode->getRoot(); 558 559 // Generate a rewriter for the pattern this success node represents, and track 560 // any values used from the match region. 561 SmallVector<Position *, 8> usedMatchValues; 562 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 563 564 // Process any values used in the rewrite that are defined in the match. 565 std::vector<Value> mappedMatchValues; 566 mappedMatchValues.reserve(usedMatchValues.size()); 567 for (Position *position : usedMatchValues) 568 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 569 570 // Collect the set of operations generated by the rewriter. 571 SmallVector<StringRef, 4> generatedOps; 572 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 573 generatedOps.push_back(*op.name()); 574 ArrayAttr generatedOpsAttr; 575 if (!generatedOps.empty()) 576 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 577 578 // Grab the root kind if present. 579 StringAttr rootKindAttr; 580 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) 581 if (Optional<StringRef> rootKind = rootOp.name()) 582 rootKindAttr = builder.getStringAttr(*rootKind); 583 584 builder.setInsertionPointToEnd(currentBlock); 585 builder.create<pdl_interp::RecordMatchOp>( 586 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 587 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 588 failureBlockStack.back()); 589 } 590 591 SymbolRefAttr PatternLowering::generateRewriter( 592 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 593 builder.setInsertionPointToEnd(rewriterModule.getBody()); 594 auto rewriterFunc = builder.create<pdl_interp::FuncOp>( 595 pattern.getLoc(), "pdl_generated_rewriter", 596 builder.getFunctionType(llvm::None, llvm::None)); 597 rewriterSymbolTable.insert(rewriterFunc); 598 599 // Generate the rewriter function body. 600 builder.setInsertionPointToEnd(&rewriterFunc.front()); 601 602 // Map an input operand of the pattern to a generated interpreter value. 603 DenseMap<Value, Value> rewriteValues; 604 auto mapRewriteValue = [&](Value oldValue) { 605 Value &newValue = rewriteValues[oldValue]; 606 if (newValue) 607 return newValue; 608 609 // Prefer materializing constants directly when possible. 610 Operation *oldOp = oldValue.getDefiningOp(); 611 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 612 if (Attribute value = attrOp.valueAttr()) { 613 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 614 attrOp.getLoc(), value); 615 } 616 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 617 if (TypeAttr type = typeOp.typeAttr()) { 618 return newValue = builder.create<pdl_interp::CreateTypeOp>( 619 typeOp.getLoc(), type); 620 } 621 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 622 if (ArrayAttr type = typeOp.typesAttr()) { 623 return newValue = builder.create<pdl_interp::CreateTypesOp>( 624 typeOp.getLoc(), typeOp.getType(), type); 625 } 626 } 627 628 // Otherwise, add this as an input to the rewriter. 629 Position *inputPos = valueToPosition.lookup(oldValue); 630 assert(inputPos && "expected value to be a pattern input"); 631 usedMatchValues.push_back(inputPos); 632 return newValue = rewriterFunc.front().addArgument(oldValue.getType(), 633 oldValue.getLoc()); 634 }; 635 636 // If this is a custom rewriter, simply dispatch to the registered rewrite 637 // method. 638 pdl::RewriteOp rewriter = pattern.getRewriter(); 639 if (StringAttr rewriteName = rewriter.nameAttr()) { 640 SmallVector<Value> args; 641 if (rewriter.root()) 642 args.push_back(mapRewriteValue(rewriter.root())); 643 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 644 args.append(mappedArgs.begin(), mappedArgs.end()); 645 builder.create<pdl_interp::ApplyRewriteOp>( 646 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); 647 } else { 648 // Otherwise this is a dag rewriter defined using PDL operations. 649 for (Operation &rewriteOp : *rewriter.getBody()) { 650 llvm::TypeSwitch<Operation *>(&rewriteOp) 651 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 652 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 653 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 654 this->generateRewriter(op, rewriteValues, mapRewriteValue); 655 }); 656 } 657 } 658 659 // Update the signature of the rewrite function. 660 rewriterFunc.setType(builder.getFunctionType( 661 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 662 /*results=*/llvm::None)); 663 664 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 665 return SymbolRefAttr::get( 666 builder.getContext(), 667 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 668 SymbolRefAttr::get(rewriterFunc)); 669 } 670 671 void PatternLowering::generateRewriter( 672 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 673 function_ref<Value(Value)> mapRewriteValue) { 674 SmallVector<Value, 2> arguments; 675 for (Value argument : rewriteOp.args()) 676 arguments.push_back(mapRewriteValue(argument)); 677 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 678 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 679 arguments); 680 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) 681 rewriteValues[std::get<0>(it)] = std::get<1>(it); 682 } 683 684 void PatternLowering::generateRewriter( 685 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 686 function_ref<Value(Value)> mapRewriteValue) { 687 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 688 attrOp.getLoc(), attrOp.valueAttr()); 689 rewriteValues[attrOp] = newAttr; 690 } 691 692 void PatternLowering::generateRewriter( 693 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 694 function_ref<Value(Value)> mapRewriteValue) { 695 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 696 mapRewriteValue(eraseOp.operation())); 697 } 698 699 void PatternLowering::generateRewriter( 700 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 701 function_ref<Value(Value)> mapRewriteValue) { 702 SmallVector<Value, 4> operands; 703 for (Value operand : operationOp.operands()) 704 operands.push_back(mapRewriteValue(operand)); 705 706 SmallVector<Value, 4> attributes; 707 for (Value attr : operationOp.attributes()) 708 attributes.push_back(mapRewriteValue(attr)); 709 710 SmallVector<Value, 2> types; 711 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 712 mapRewriteValue); 713 714 // Create the new operation. 715 Location loc = operationOp.getLoc(); 716 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 717 loc, *operationOp.name(), types, operands, attributes, 718 operationOp.attributeNames()); 719 rewriteValues[operationOp.op()] = createdOp; 720 721 // Generate accesses for any results that have their types constrained. 722 // Handle the case where there is a single range representing all of the 723 // result types. 724 OperandRange resultTys = operationOp.types(); 725 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 726 Value &type = rewriteValues[resultTys[0]]; 727 if (!type) { 728 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 729 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 730 } 731 return; 732 } 733 734 // Otherwise, populate the individual results. 735 bool seenVariableLength = false; 736 Type valueTy = builder.getType<pdl::ValueType>(); 737 Type valueRangeTy = pdl::RangeType::get(valueTy); 738 for (const auto &it : llvm::enumerate(resultTys)) { 739 Value &type = rewriteValues[it.value()]; 740 if (type) 741 continue; 742 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 743 seenVariableLength |= isVariadic; 744 745 // After a variable length result has been seen, we need to use result 746 // groups because the exact index of the result is not statically known. 747 Value resultVal; 748 if (seenVariableLength) 749 resultVal = builder.create<pdl_interp::GetResultsOp>( 750 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 751 else 752 resultVal = builder.create<pdl_interp::GetResultOp>( 753 loc, valueTy, createdOp, it.index()); 754 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 755 } 756 } 757 758 void PatternLowering::generateRewriter( 759 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 760 function_ref<Value(Value)> mapRewriteValue) { 761 SmallVector<Value, 4> replOperands; 762 763 // If the replacement was another operation, get its results. `pdl` allows 764 // for using an operation for simplicitly, but the interpreter isn't as 765 // user facing. 766 if (Value replOp = replaceOp.replOperation()) { 767 // Don't use replace if we know the replaced operation has no results. 768 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 769 if (!opOp || !opOp.types().empty()) { 770 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 771 replOp.getLoc(), mapRewriteValue(replOp))); 772 } 773 } else { 774 for (Value operand : replaceOp.replValues()) 775 replOperands.push_back(mapRewriteValue(operand)); 776 } 777 778 // If there are no replacement values, just create an erase instead. 779 if (replOperands.empty()) { 780 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 781 mapRewriteValue(replaceOp.operation())); 782 return; 783 } 784 785 builder.create<pdl_interp::ReplaceOp>( 786 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 787 } 788 789 void PatternLowering::generateRewriter( 790 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 791 function_ref<Value(Value)> mapRewriteValue) { 792 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 793 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 794 mapRewriteValue(resultOp.parent()), resultOp.index()); 795 } 796 797 void PatternLowering::generateRewriter( 798 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 799 function_ref<Value(Value)> mapRewriteValue) { 800 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 801 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 802 resultOp.index()); 803 } 804 805 void PatternLowering::generateRewriter( 806 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 807 function_ref<Value(Value)> mapRewriteValue) { 808 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 809 // type. 810 if (TypeAttr typeAttr = typeOp.typeAttr()) { 811 rewriteValues[typeOp] = 812 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 813 } 814 } 815 816 void PatternLowering::generateRewriter( 817 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 818 function_ref<Value(Value)> mapRewriteValue) { 819 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 820 // type. 821 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 822 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 823 typeOp.getLoc(), typeOp.getType(), typeAttr); 824 } 825 } 826 827 void PatternLowering::generateOperationResultTypeRewriter( 828 pdl::OperationOp op, SmallVectorImpl<Value> &types, 829 DenseMap<Value, Value> &rewriteValues, 830 function_ref<Value(Value)> mapRewriteValue) { 831 // Look for an operation that was replaced by `op`. The result types will be 832 // inferred from the results that were replaced. 833 Block *rewriterBlock = op->getBlock(); 834 for (OpOperand &use : op.op().getUses()) { 835 // Check that the use corresponds to a ReplaceOp and that it is the 836 // replacement value, not the operation being replaced. 837 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 838 if (!replOpUser || use.getOperandNumber() == 0) 839 continue; 840 // Make sure the replaced operation was defined before this one. 841 Value replOpVal = replOpUser.operation(); 842 Operation *replacedOp = replOpVal.getDefiningOp(); 843 if (replacedOp->getBlock() == rewriterBlock && 844 !replacedOp->isBeforeInBlock(op)) 845 continue; 846 847 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 848 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 849 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 850 replacedOp->getLoc(), replacedOpResults)); 851 return; 852 } 853 854 // Check if the operation has type inference support. 855 if (op.hasTypeInference()) { 856 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 857 return; 858 } 859 860 // Otherwise, handle inference for each of the result types individually. 861 OperandRange resultTypeValues = op.types(); 862 types.reserve(resultTypeValues.size()); 863 for (const auto &it : llvm::enumerate(resultTypeValues)) { 864 Value resultType = it.value(); 865 866 // Check for an already translated value. 867 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 868 types.push_back(existingRewriteValue); 869 continue; 870 } 871 872 // Check for an input from the matcher. 873 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 874 types.push_back(mapRewriteValue(resultType)); 875 continue; 876 } 877 878 // The verifier asserts that the result types of each pdl.operation can be 879 // inferred. If we reach here, there is a bug either in the logic above or 880 // in the verifier for pdl.operation. 881 op->emitOpError() << "unable to infer result type for operation"; 882 llvm_unreachable("unable to infer result type for operation"); 883 } 884 } 885 886 //===----------------------------------------------------------------------===// 887 // Conversion Pass 888 //===----------------------------------------------------------------------===// 889 890 namespace { 891 struct PDLToPDLInterpPass 892 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 893 void runOnOperation() final; 894 }; 895 } // namespace 896 897 /// Convert the given module containing PDL pattern operations into a PDL 898 /// Interpreter operations. 899 void PDLToPDLInterpPass::runOnOperation() { 900 ModuleOp module = getOperation(); 901 902 // Create the main matcher function This function contains all of the match 903 // related functionality from patterns in the module. 904 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 905 auto matcherFunc = builder.create<pdl_interp::FuncOp>( 906 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 907 builder.getFunctionType(builder.getType<pdl::OperationType>(), 908 /*results=*/llvm::None), 909 /*attrs=*/llvm::None); 910 911 // Create a nested module to hold the functions invoked for rewriting the IR 912 // after a successful match. 913 ModuleOp rewriterModule = builder.create<ModuleOp>( 914 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 915 916 // Generate the code for the patterns within the module. 917 PatternLowering generator(matcherFunc, rewriterModule); 918 generator.lower(module); 919 920 // After generation, delete all of the pattern operations. 921 for (pdl::PatternOp pattern : 922 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 923 pattern.erase(); 924 } 925 926 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 927 return std::make_unique<PDLToPDLInterpPass>(); 928 } 929