1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/SetVector.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::pdl_to_pdl_interp; 23 24 //===----------------------------------------------------------------------===// 25 // PatternLowering 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 /// This class generators operations within the PDL Interpreter dialect from a 30 /// given module containing PDL pattern operations. 31 struct PatternLowering { 32 public: 33 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 34 35 /// Generate code for matching and rewriting based on the pattern operations 36 /// within the module. 37 void lower(ModuleOp module); 38 39 private: 40 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 41 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 42 43 /// Generate interpreter operations for the tree rooted at the given matcher 44 /// node. 45 Block *generateMatcher(MatcherNode &node); 46 47 /// Get or create an access to the provided positional value within the 48 /// current block. 49 Value getValueAt(Block *cur, Position *pos); 50 51 /// Create an interpreter predicate operation, branching to the provided true 52 /// and false destinations. 53 void generatePredicate(Block *currentBlock, Qualifier *question, 54 Qualifier *answer, Value val, Block *trueDest, 55 Block *falseDest); 56 57 /// Create an interpreter switch predicate operation, with a provided default 58 /// and several case destinations. 59 void generateSwitch(SwitchNode *switchNode, Block *currentBlock, 60 Qualifier *question, Value val, Block *defaultDest); 61 62 /// Create the interpreter operations to record a successful pattern match. 63 void generateRecordMatch(Block *currentBlock, Block *nextBlock, 64 pdl::PatternOp pattern); 65 66 /// Generate a rewriter function for the given pattern operation, and returns 67 /// a reference to that function. 68 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 69 SmallVectorImpl<Position *> &usedMatchValues); 70 71 /// Generate the rewriter code for the given operation. 72 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 73 DenseMap<Value, Value> &rewriteValues, 74 function_ref<Value(Value)> mapRewriteValue); 75 void generateRewriter(pdl::AttributeOp attrOp, 76 DenseMap<Value, Value> &rewriteValues, 77 function_ref<Value(Value)> mapRewriteValue); 78 void generateRewriter(pdl::EraseOp eraseOp, 79 DenseMap<Value, Value> &rewriteValues, 80 function_ref<Value(Value)> mapRewriteValue); 81 void generateRewriter(pdl::OperationOp operationOp, 82 DenseMap<Value, Value> &rewriteValues, 83 function_ref<Value(Value)> mapRewriteValue); 84 void generateRewriter(pdl::ReplaceOp replaceOp, 85 DenseMap<Value, Value> &rewriteValues, 86 function_ref<Value(Value)> mapRewriteValue); 87 void generateRewriter(pdl::ResultOp resultOp, 88 DenseMap<Value, Value> &rewriteValues, 89 function_ref<Value(Value)> mapRewriteValue); 90 void generateRewriter(pdl::ResultsOp resultOp, 91 DenseMap<Value, Value> &rewriteValues, 92 function_ref<Value(Value)> mapRewriteValue); 93 void generateRewriter(pdl::TypeOp typeOp, 94 DenseMap<Value, Value> &rewriteValues, 95 function_ref<Value(Value)> mapRewriteValue); 96 void generateRewriter(pdl::TypesOp typeOp, 97 DenseMap<Value, Value> &rewriteValues, 98 function_ref<Value(Value)> mapRewriteValue); 99 100 /// Generate the values used for resolving the result types of an operation 101 /// created within a dag rewriter region. 102 void generateOperationResultTypeRewriter( 103 pdl::OperationOp op, SmallVectorImpl<Value> &types, 104 DenseMap<Value, Value> &rewriteValues, 105 function_ref<Value(Value)> mapRewriteValue); 106 107 /// A builder to use when generating interpreter operations. 108 OpBuilder builder; 109 110 /// The matcher function used for all match related logic within PDL patterns. 111 FuncOp matcherFunc; 112 113 /// The rewriter module containing the all rewrite related logic within PDL 114 /// patterns. 115 ModuleOp rewriterModule; 116 117 /// The symbol table of the rewriter module used for insertion. 118 SymbolTable rewriterSymbolTable; 119 120 /// A scoped map connecting a position with the corresponding interpreter 121 /// value. 122 ValueMap values; 123 124 /// A stack of blocks used as the failure destination for matcher nodes that 125 /// don't have an explicit failure path. 126 SmallVector<Block *, 8> failureBlockStack; 127 128 /// A mapping between values defined in a pattern match, and the corresponding 129 /// positional value. 130 DenseMap<Value, Position *> valueToPosition; 131 132 /// The set of operation values whose whose location will be used for newly 133 /// generated operations. 134 SetVector<Value> locOps; 135 }; 136 } // end anonymous namespace 137 138 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 139 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 140 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 141 142 void PatternLowering::lower(ModuleOp module) { 143 PredicateUniquer predicateUniquer; 144 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 145 146 // Define top-level scope for the arguments to the matcher function. 147 ValueMapScope topLevelValueScope(values); 148 149 // Insert the root operation, i.e. argument to the matcher, at the root 150 // position. 151 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 152 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 153 154 // Generate a root matcher node from the provided PDL module. 155 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 156 module, predicateBuilder, valueToPosition); 157 Block *firstMatcherBlock = generateMatcher(*root); 158 159 // After generation, merged the first matched block into the entry. 160 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 161 firstMatcherBlock->getOperations()); 162 firstMatcherBlock->erase(); 163 } 164 165 Block *PatternLowering::generateMatcher(MatcherNode &node) { 166 // Push a new scope for the values used by this matcher. 167 Block *block = matcherFunc.addBlock(); 168 ValueMapScope scope(values); 169 170 // If this is the return node, simply insert the corresponding interpreter 171 // finalize. 172 if (isa<ExitNode>(node)) { 173 builder.setInsertionPointToEnd(block); 174 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 175 return block; 176 } 177 178 // If this node contains a position, get the corresponding value for this 179 // block. 180 Position *position = node.getPosition(); 181 Value val = position ? getValueAt(block, position) : Value(); 182 183 // Get the next block in the match sequence. 184 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 185 Block *nextBlock; 186 if (failureNode) { 187 nextBlock = generateMatcher(*failureNode); 188 failureBlockStack.push_back(nextBlock); 189 } else { 190 assert(!failureBlockStack.empty() && "expected valid failure block"); 191 nextBlock = failureBlockStack.back(); 192 } 193 194 // If this value corresponds to an operation, record that we are going to use 195 // its location as part of a fused location. 196 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 197 if (isOperationValue) 198 locOps.insert(val); 199 200 // Generate code for a boolean predicate node. 201 if (auto *boolNode = dyn_cast<BoolNode>(&node)) { 202 auto *child = generateMatcher(*boolNode->getSuccessNode()); 203 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, 204 child, nextBlock); 205 206 // Generate code for a switch node. 207 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) { 208 generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); 209 210 // Generate code for a success node. 211 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) { 212 generateRecordMatch(block, nextBlock, successNode->getPattern()); 213 } 214 215 if (failureNode) 216 failureBlockStack.pop_back(); 217 if (isOperationValue) 218 locOps.remove(val); 219 return block; 220 } 221 222 Value PatternLowering::getValueAt(Block *cur, Position *pos) { 223 if (Value val = values.lookup(pos)) 224 return val; 225 226 // Get the value for the parent position. 227 Value parentVal = getValueAt(cur, pos->getParent()); 228 229 // TODO: Use a location from the position. 230 Location loc = parentVal.getLoc(); 231 builder.setInsertionPointToEnd(cur); 232 Value value; 233 switch (pos->getKind()) { 234 case Predicates::OperationPos: 235 value = builder.create<pdl_interp::GetDefiningOpOp>( 236 loc, builder.getType<pdl::OperationType>(), parentVal); 237 break; 238 case Predicates::OperandPos: { 239 auto *operandPos = cast<OperandPosition>(pos); 240 value = builder.create<pdl_interp::GetOperandOp>( 241 loc, builder.getType<pdl::ValueType>(), parentVal, 242 operandPos->getOperandNumber()); 243 break; 244 } 245 case Predicates::OperandGroupPos: { 246 auto *operandPos = cast<OperandGroupPosition>(pos); 247 Type valueTy = builder.getType<pdl::ValueType>(); 248 value = builder.create<pdl_interp::GetOperandsOp>( 249 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 250 parentVal, operandPos->getOperandGroupNumber()); 251 break; 252 } 253 case Predicates::AttributePos: { 254 auto *attrPos = cast<AttributePosition>(pos); 255 value = builder.create<pdl_interp::GetAttributeOp>( 256 loc, builder.getType<pdl::AttributeType>(), parentVal, 257 attrPos->getName().strref()); 258 break; 259 } 260 case Predicates::TypePos: { 261 if (parentVal.getType().isa<pdl::AttributeType>()) 262 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 263 else 264 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 265 break; 266 } 267 case Predicates::ResultPos: { 268 auto *resPos = cast<ResultPosition>(pos); 269 value = builder.create<pdl_interp::GetResultOp>( 270 loc, builder.getType<pdl::ValueType>(), parentVal, 271 resPos->getResultNumber()); 272 break; 273 } 274 case Predicates::ResultGroupPos: { 275 auto *resPos = cast<ResultGroupPosition>(pos); 276 Type valueTy = builder.getType<pdl::ValueType>(); 277 value = builder.create<pdl_interp::GetResultsOp>( 278 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 279 parentVal, resPos->getResultGroupNumber()); 280 break; 281 } 282 default: 283 llvm_unreachable("Generating unknown Position getter"); 284 break; 285 } 286 values.insert(pos, value); 287 return value; 288 } 289 290 void PatternLowering::generatePredicate(Block *currentBlock, 291 Qualifier *question, Qualifier *answer, 292 Value val, Block *trueDest, 293 Block *falseDest) { 294 builder.setInsertionPointToEnd(currentBlock); 295 Location loc = val.getLoc(); 296 Predicates::Kind kind = question->getKind(); 297 switch (kind) { 298 case Predicates::IsNotNullQuestion: 299 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest); 300 break; 301 case Predicates::OperationNameQuestion: { 302 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 303 builder.create<pdl_interp::CheckOperationNameOp>( 304 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); 305 break; 306 } 307 case Predicates::TypeQuestion: { 308 auto *ans = cast<TypeAnswer>(answer); 309 if (val.getType().isa<pdl::RangeType>()) 310 builder.create<pdl_interp::CheckTypesOp>( 311 loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest); 312 else 313 builder.create<pdl_interp::CheckTypeOp>( 314 loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest); 315 break; 316 } 317 case Predicates::AttributeQuestion: { 318 auto *ans = cast<AttributeAnswer>(answer); 319 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 320 trueDest, falseDest); 321 break; 322 } 323 case Predicates::OperandCountAtLeastQuestion: 324 case Predicates::OperandCountQuestion: 325 builder.create<pdl_interp::CheckOperandCountOp>( 326 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 327 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 328 trueDest, falseDest); 329 break; 330 case Predicates::ResultCountAtLeastQuestion: 331 case Predicates::ResultCountQuestion: 332 builder.create<pdl_interp::CheckResultCountOp>( 333 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 334 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 335 trueDest, falseDest); 336 break; 337 case Predicates::EqualToQuestion: { 338 auto *equalToQuestion = cast<EqualToQuestion>(question); 339 builder.create<pdl_interp::AreEqualOp>( 340 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), 341 trueDest, falseDest); 342 break; 343 } 344 case Predicates::ConstraintQuestion: { 345 auto *cstQuestion = cast<ConstraintQuestion>(question); 346 SmallVector<Value, 2> args; 347 for (Position *position : std::get<1>(cstQuestion->getValue())) 348 args.push_back(getValueAt(currentBlock, position)); 349 builder.create<pdl_interp::ApplyConstraintOp>( 350 loc, std::get<0>(cstQuestion->getValue()), args, 351 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest, 352 falseDest); 353 break; 354 } 355 default: 356 llvm_unreachable("Generating unknown Predicate operation"); 357 } 358 } 359 360 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 361 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 362 llvm::MapVector<Qualifier *, Block *> &dests) { 363 std::vector<ValT> values; 364 std::vector<Block *> blocks; 365 values.reserve(dests.size()); 366 blocks.reserve(dests.size()); 367 for (const auto &it : dests) { 368 blocks.push_back(it.second); 369 values.push_back(cast<PredT>(it.first)->getValue()); 370 } 371 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 372 } 373 374 void PatternLowering::generateSwitch(SwitchNode *switchNode, 375 Block *currentBlock, Qualifier *question, 376 Value val, Block *defaultDest) { 377 // If the switch question is not an exact answer, i.e. for the `at_least` 378 // cases, we generate a special block sequence. 379 Predicates::Kind kind = question->getKind(); 380 if (kind == Predicates::OperandCountAtLeastQuestion || 381 kind == Predicates::ResultCountAtLeastQuestion) { 382 // Order the children such that the cases are in reverse numerical order. 383 SmallVector<unsigned> sortedChildren( 384 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 385 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 386 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 387 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 388 }); 389 390 // Build the destination for each child using the next highest child as a 391 // a failure destination. This essentially creates the following control 392 // flow: 393 // 394 // if (operand_count < 1) 395 // goto failure 396 // if (child1.match()) 397 // ... 398 // 399 // if (operand_count < 2) 400 // goto failure 401 // if (child2.match()) 402 // ... 403 // 404 // failure: 405 // ... 406 // 407 failureBlockStack.push_back(defaultDest); 408 for (unsigned idx : sortedChildren) { 409 auto &child = switchNode->getChild(idx); 410 Block *childBlock = generateMatcher(*child.second); 411 Block *predicateBlock = builder.createBlock(childBlock); 412 generatePredicate(predicateBlock, question, child.first, val, childBlock, 413 defaultDest); 414 failureBlockStack.back() = predicateBlock; 415 } 416 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 417 currentBlock->getOperations().splice(currentBlock->end(), 418 firstPredicateBlock->getOperations()); 419 firstPredicateBlock->erase(); 420 return; 421 } 422 423 // Otherwise, generate each of the children and generate an interpreter 424 // switch. 425 llvm::MapVector<Qualifier *, Block *> children; 426 for (auto &it : switchNode->getChildren()) 427 children.insert({it.first, generateMatcher(*it.second)}); 428 builder.setInsertionPointToEnd(currentBlock); 429 430 switch (question->getKind()) { 431 case Predicates::OperandCountQuestion: 432 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 433 int32_t>(val, defaultDest, builder, children); 434 case Predicates::ResultCountQuestion: 435 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 436 int32_t>(val, defaultDest, builder, children); 437 case Predicates::OperationNameQuestion: 438 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 439 OperationNameAnswer>(val, defaultDest, builder, 440 children); 441 case Predicates::TypeQuestion: 442 if (val.getType().isa<pdl::RangeType>()) { 443 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 444 val, defaultDest, builder, children); 445 } 446 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 447 val, defaultDest, builder, children); 448 case Predicates::AttributeQuestion: 449 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 450 val, defaultDest, builder, children); 451 default: 452 llvm_unreachable("Generating unknown switch predicate."); 453 } 454 } 455 456 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, 457 pdl::PatternOp pattern) { 458 // Generate a rewriter for the pattern this success node represents, and track 459 // any values used from the match region. 460 SmallVector<Position *, 8> usedMatchValues; 461 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 462 463 // Process any values used in the rewrite that are defined in the match. 464 std::vector<Value> mappedMatchValues; 465 mappedMatchValues.reserve(usedMatchValues.size()); 466 for (Position *position : usedMatchValues) 467 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 468 469 // Collect the set of operations generated by the rewriter. 470 SmallVector<StringRef, 4> generatedOps; 471 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 472 generatedOps.push_back(*op.name()); 473 ArrayAttr generatedOpsAttr; 474 if (!generatedOps.empty()) 475 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 476 477 // Grab the root kind if present. 478 StringAttr rootKindAttr; 479 if (Optional<StringRef> rootKind = pattern.getRootKind()) 480 rootKindAttr = builder.getStringAttr(*rootKind); 481 482 builder.setInsertionPointToEnd(currentBlock); 483 builder.create<pdl_interp::RecordMatchOp>( 484 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 485 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 486 nextBlock); 487 } 488 489 SymbolRefAttr PatternLowering::generateRewriter( 490 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 491 FuncOp rewriterFunc = 492 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 493 builder.getFunctionType(llvm::None, llvm::None)); 494 rewriterSymbolTable.insert(rewriterFunc); 495 496 // Generate the rewriter function body. 497 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 498 499 // Map an input operand of the pattern to a generated interpreter value. 500 DenseMap<Value, Value> rewriteValues; 501 auto mapRewriteValue = [&](Value oldValue) { 502 Value &newValue = rewriteValues[oldValue]; 503 if (newValue) 504 return newValue; 505 506 // Prefer materializing constants directly when possible. 507 Operation *oldOp = oldValue.getDefiningOp(); 508 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 509 if (Attribute value = attrOp.valueAttr()) { 510 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 511 attrOp.getLoc(), value); 512 } 513 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 514 if (TypeAttr type = typeOp.typeAttr()) { 515 return newValue = builder.create<pdl_interp::CreateTypeOp>( 516 typeOp.getLoc(), type); 517 } 518 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 519 if (ArrayAttr type = typeOp.typesAttr()) { 520 return newValue = builder.create<pdl_interp::CreateTypesOp>( 521 typeOp.getLoc(), typeOp.getType(), type); 522 } 523 } 524 525 // Otherwise, add this as an input to the rewriter. 526 Position *inputPos = valueToPosition.lookup(oldValue); 527 assert(inputPos && "expected value to be a pattern input"); 528 usedMatchValues.push_back(inputPos); 529 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 530 }; 531 532 // If this is a custom rewriter, simply dispatch to the registered rewrite 533 // method. 534 pdl::RewriteOp rewriter = pattern.getRewriter(); 535 if (StringAttr rewriteName = rewriter.nameAttr()) { 536 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 537 SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root())); 538 args.append(mappedArgs.begin(), mappedArgs.end()); 539 builder.create<pdl_interp::ApplyRewriteOp>( 540 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 541 rewriter.externalConstParamsAttr()); 542 } else { 543 // Otherwise this is a dag rewriter defined using PDL operations. 544 for (Operation &rewriteOp : *rewriter.getBody()) { 545 llvm::TypeSwitch<Operation *>(&rewriteOp) 546 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 547 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 548 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 549 this->generateRewriter(op, rewriteValues, mapRewriteValue); 550 }); 551 } 552 } 553 554 // Update the signature of the rewrite function. 555 rewriterFunc.setType(builder.getFunctionType( 556 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 557 /*results=*/llvm::None)); 558 559 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 560 return builder.getSymbolRefAttr( 561 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 562 builder.getSymbolRefAttr(rewriterFunc)); 563 } 564 565 void PatternLowering::generateRewriter( 566 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 567 function_ref<Value(Value)> mapRewriteValue) { 568 SmallVector<Value, 2> arguments; 569 for (Value argument : rewriteOp.args()) 570 arguments.push_back(mapRewriteValue(argument)); 571 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 572 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 573 arguments, rewriteOp.constParamsAttr()); 574 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 575 rewriteValues[std::get<0>(it)] = std::get<1>(it); 576 } 577 578 void PatternLowering::generateRewriter( 579 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 580 function_ref<Value(Value)> mapRewriteValue) { 581 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 582 attrOp.getLoc(), attrOp.valueAttr()); 583 rewriteValues[attrOp] = newAttr; 584 } 585 586 void PatternLowering::generateRewriter( 587 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 588 function_ref<Value(Value)> mapRewriteValue) { 589 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 590 mapRewriteValue(eraseOp.operation())); 591 } 592 593 void PatternLowering::generateRewriter( 594 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 595 function_ref<Value(Value)> mapRewriteValue) { 596 SmallVector<Value, 4> operands; 597 for (Value operand : operationOp.operands()) 598 operands.push_back(mapRewriteValue(operand)); 599 600 SmallVector<Value, 4> attributes; 601 for (Value attr : operationOp.attributes()) 602 attributes.push_back(mapRewriteValue(attr)); 603 604 SmallVector<Value, 2> types; 605 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 606 mapRewriteValue); 607 608 // Create the new operation. 609 Location loc = operationOp.getLoc(); 610 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 611 loc, *operationOp.name(), types, operands, attributes, 612 operationOp.attributeNames()); 613 rewriteValues[operationOp.op()] = createdOp; 614 615 // Generate accesses for any results that have their types constrained. 616 // Handle the case where there is a single range representing all of the 617 // result types. 618 OperandRange resultTys = operationOp.types(); 619 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 620 Value &type = rewriteValues[resultTys[0]]; 621 if (!type) { 622 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 623 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 624 } 625 return; 626 } 627 628 // Otherwise, populate the individual results. 629 bool seenVariableLength = false; 630 Type valueTy = builder.getType<pdl::ValueType>(); 631 Type valueRangeTy = pdl::RangeType::get(valueTy); 632 for (auto it : llvm::enumerate(resultTys)) { 633 Value &type = rewriteValues[it.value()]; 634 if (type) 635 continue; 636 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 637 seenVariableLength |= isVariadic; 638 639 // After a variable length result has been seen, we need to use result 640 // groups because the exact index of the result is not statically known. 641 Value resultVal; 642 if (seenVariableLength) 643 resultVal = builder.create<pdl_interp::GetResultsOp>( 644 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 645 else 646 resultVal = builder.create<pdl_interp::GetResultOp>( 647 loc, valueTy, createdOp, it.index()); 648 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 649 } 650 } 651 652 void PatternLowering::generateRewriter( 653 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 654 function_ref<Value(Value)> mapRewriteValue) { 655 SmallVector<Value, 4> replOperands; 656 657 // If the replacement was another operation, get its results. `pdl` allows 658 // for using an operation for simplicitly, but the interpreter isn't as 659 // user facing. 660 if (Value replOp = replaceOp.replOperation()) { 661 // Don't use replace if we know the replaced operation has no results. 662 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 663 if (!opOp || !opOp.types().empty()) { 664 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 665 replOp.getLoc(), mapRewriteValue(replOp))); 666 } 667 } else { 668 for (Value operand : replaceOp.replValues()) 669 replOperands.push_back(mapRewriteValue(operand)); 670 } 671 672 // If there are no replacement values, just create an erase instead. 673 if (replOperands.empty()) { 674 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 675 mapRewriteValue(replaceOp.operation())); 676 return; 677 } 678 679 builder.create<pdl_interp::ReplaceOp>( 680 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 681 } 682 683 void PatternLowering::generateRewriter( 684 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 685 function_ref<Value(Value)> mapRewriteValue) { 686 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 687 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 688 mapRewriteValue(resultOp.parent()), resultOp.index()); 689 } 690 691 void PatternLowering::generateRewriter( 692 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 693 function_ref<Value(Value)> mapRewriteValue) { 694 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 695 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 696 resultOp.index()); 697 } 698 699 void PatternLowering::generateRewriter( 700 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 701 function_ref<Value(Value)> mapRewriteValue) { 702 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 703 // type. 704 if (TypeAttr typeAttr = typeOp.typeAttr()) { 705 rewriteValues[typeOp] = 706 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 707 } 708 } 709 710 void PatternLowering::generateRewriter( 711 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 712 function_ref<Value(Value)> mapRewriteValue) { 713 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 714 // type. 715 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 716 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 717 typeOp.getLoc(), typeOp.getType(), typeAttr); 718 } 719 } 720 721 void PatternLowering::generateOperationResultTypeRewriter( 722 pdl::OperationOp op, SmallVectorImpl<Value> &types, 723 DenseMap<Value, Value> &rewriteValues, 724 function_ref<Value(Value)> mapRewriteValue) { 725 // Look for an operation that was replaced by `op`. The result types will be 726 // inferred from the results that were replaced. 727 Block *rewriterBlock = op->getBlock(); 728 Value replacedOp; 729 for (OpOperand &use : op.op().getUses()) { 730 // Check that the use corresponds to a ReplaceOp and that it is the 731 // replacement value, not the operation being replaced. 732 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 733 if (!replOpUser || use.getOperandNumber() == 0) 734 continue; 735 // Make sure the replaced operation was defined before this one. 736 Value replOpVal = replOpUser.operation(); 737 Operation *replacedOp = replOpVal.getDefiningOp(); 738 if (replacedOp->getBlock() == rewriterBlock && 739 !replacedOp->isBeforeInBlock(op)) 740 continue; 741 742 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 743 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 744 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 745 replacedOp->getLoc(), replacedOpResults)); 746 return; 747 } 748 749 // Check if the operation has type inference support. 750 if (op.hasTypeInference()) { 751 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 752 return; 753 } 754 755 // Otherwise, handle inference for each of the result types individually. 756 OperandRange resultTypeValues = op.types(); 757 types.reserve(resultTypeValues.size()); 758 for (auto it : llvm::enumerate(resultTypeValues)) { 759 Value resultType = it.value(); 760 761 // Check for an already translated value. 762 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 763 types.push_back(existingRewriteValue); 764 continue; 765 } 766 767 // Check for an input from the matcher. 768 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 769 types.push_back(mapRewriteValue(resultType)); 770 continue; 771 } 772 773 // The verifier asserts that the result types of each pdl.operation can be 774 // inferred. If we reach here, there is a bug either in the logic above or 775 // in the verifier for pdl.operation. 776 op->emitOpError() << "unable to infer result type for operation"; 777 llvm_unreachable("unable to infer result type for operation"); 778 } 779 } 780 781 //===----------------------------------------------------------------------===// 782 // Conversion Pass 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 struct PDLToPDLInterpPass 787 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 788 void runOnOperation() final; 789 }; 790 } // namespace 791 792 /// Convert the given module containing PDL pattern operations into a PDL 793 /// Interpreter operations. 794 void PDLToPDLInterpPass::runOnOperation() { 795 ModuleOp module = getOperation(); 796 797 // Create the main matcher function This function contains all of the match 798 // related functionality from patterns in the module. 799 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 800 FuncOp matcherFunc = builder.create<FuncOp>( 801 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 802 builder.getFunctionType(builder.getType<pdl::OperationType>(), 803 /*results=*/llvm::None), 804 /*attrs=*/llvm::None); 805 806 // Create a nested module to hold the functions invoked for rewriting the IR 807 // after a successful match. 808 ModuleOp rewriterModule = builder.create<ModuleOp>( 809 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 810 811 // Generate the code for the patterns within the module. 812 PatternLowering generator(matcherFunc, rewriterModule); 813 generator.lower(module); 814 815 // After generation, delete all of the pattern operations. 816 for (pdl::PatternOp pattern : 817 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 818 pattern.erase(); 819 } 820 821 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 822 return std::make_unique<PDLToPDLInterpPass>(); 823 } 824