1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::pdl_to_pdl_interp; 25 26 //===----------------------------------------------------------------------===// 27 // PatternLowering 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// This class generators operations within the PDL Interpreter dialect from a 32 /// given module containing PDL pattern operations. 33 struct PatternLowering { 34 public: 35 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 36 37 /// Generate code for matching and rewriting based on the pattern operations 38 /// within the module. 39 void lower(ModuleOp module); 40 41 private: 42 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 44 45 /// Generate interpreter operations for the tree rooted at the given matcher 46 /// node. 47 Block *generateMatcher(MatcherNode &node); 48 49 /// Get or create an access to the provided positional value within the 50 /// current block. 51 Value getValueAt(Block *cur, Position *pos); 52 53 /// Create an interpreter predicate operation, branching to the provided true 54 /// and false destinations. 55 void generatePredicate(Block *currentBlock, Qualifier *question, 56 Qualifier *answer, Value val, Block *trueDest, 57 Block *falseDest); 58 59 /// Create an interpreter switch predicate operation, with a provided default 60 /// and several case destinations. 61 void generateSwitch(SwitchNode *switchNode, Block *currentBlock, 62 Qualifier *question, Value val, Block *defaultDest); 63 64 /// Create the interpreter operations to record a successful pattern match. 65 void generateRecordMatch(Block *currentBlock, Block *nextBlock, 66 pdl::PatternOp pattern); 67 68 /// Generate a rewriter function for the given pattern operation, and returns 69 /// a reference to that function. 70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 71 SmallVectorImpl<Position *> &usedMatchValues); 72 73 /// Generate the rewriter code for the given operation. 74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 75 DenseMap<Value, Value> &rewriteValues, 76 function_ref<Value(Value)> mapRewriteValue); 77 void generateRewriter(pdl::AttributeOp attrOp, 78 DenseMap<Value, Value> &rewriteValues, 79 function_ref<Value(Value)> mapRewriteValue); 80 void generateRewriter(pdl::EraseOp eraseOp, 81 DenseMap<Value, Value> &rewriteValues, 82 function_ref<Value(Value)> mapRewriteValue); 83 void generateRewriter(pdl::OperationOp operationOp, 84 DenseMap<Value, Value> &rewriteValues, 85 function_ref<Value(Value)> mapRewriteValue); 86 void generateRewriter(pdl::ReplaceOp replaceOp, 87 DenseMap<Value, Value> &rewriteValues, 88 function_ref<Value(Value)> mapRewriteValue); 89 void generateRewriter(pdl::ResultOp resultOp, 90 DenseMap<Value, Value> &rewriteValues, 91 function_ref<Value(Value)> mapRewriteValue); 92 void generateRewriter(pdl::ResultsOp resultOp, 93 DenseMap<Value, Value> &rewriteValues, 94 function_ref<Value(Value)> mapRewriteValue); 95 void generateRewriter(pdl::TypeOp typeOp, 96 DenseMap<Value, Value> &rewriteValues, 97 function_ref<Value(Value)> mapRewriteValue); 98 void generateRewriter(pdl::TypesOp typeOp, 99 DenseMap<Value, Value> &rewriteValues, 100 function_ref<Value(Value)> mapRewriteValue); 101 102 /// Generate the values used for resolving the result types of an operation 103 /// created within a dag rewriter region. 104 void generateOperationResultTypeRewriter( 105 pdl::OperationOp op, SmallVectorImpl<Value> &types, 106 DenseMap<Value, Value> &rewriteValues, 107 function_ref<Value(Value)> mapRewriteValue); 108 109 /// A builder to use when generating interpreter operations. 110 OpBuilder builder; 111 112 /// The matcher function used for all match related logic within PDL patterns. 113 FuncOp matcherFunc; 114 115 /// The rewriter module containing the all rewrite related logic within PDL 116 /// patterns. 117 ModuleOp rewriterModule; 118 119 /// The symbol table of the rewriter module used for insertion. 120 SymbolTable rewriterSymbolTable; 121 122 /// A scoped map connecting a position with the corresponding interpreter 123 /// value. 124 ValueMap values; 125 126 /// A stack of blocks used as the failure destination for matcher nodes that 127 /// don't have an explicit failure path. 128 SmallVector<Block *, 8> failureBlockStack; 129 130 /// A mapping between values defined in a pattern match, and the corresponding 131 /// positional value. 132 DenseMap<Value, Position *> valueToPosition; 133 134 /// The set of operation values whose whose location will be used for newly 135 /// generated operations. 136 SetVector<Value> locOps; 137 }; 138 } // end anonymous namespace 139 140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 141 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 142 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 143 144 void PatternLowering::lower(ModuleOp module) { 145 PredicateUniquer predicateUniquer; 146 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 147 148 // Define top-level scope for the arguments to the matcher function. 149 ValueMapScope topLevelValueScope(values); 150 151 // Insert the root operation, i.e. argument to the matcher, at the root 152 // position. 153 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 154 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 155 156 // Generate a root matcher node from the provided PDL module. 157 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 158 module, predicateBuilder, valueToPosition); 159 Block *firstMatcherBlock = generateMatcher(*root); 160 161 // After generation, merged the first matched block into the entry. 162 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 163 firstMatcherBlock->getOperations()); 164 firstMatcherBlock->erase(); 165 } 166 167 Block *PatternLowering::generateMatcher(MatcherNode &node) { 168 // Push a new scope for the values used by this matcher. 169 Block *block = matcherFunc.addBlock(); 170 ValueMapScope scope(values); 171 172 // If this is the return node, simply insert the corresponding interpreter 173 // finalize. 174 if (isa<ExitNode>(node)) { 175 builder.setInsertionPointToEnd(block); 176 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 177 return block; 178 } 179 180 // If this node contains a position, get the corresponding value for this 181 // block. 182 Position *position = node.getPosition(); 183 Value val = position ? getValueAt(block, position) : Value(); 184 185 // Get the next block in the match sequence. 186 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 187 Block *nextBlock; 188 if (failureNode) { 189 nextBlock = generateMatcher(*failureNode); 190 failureBlockStack.push_back(nextBlock); 191 } else { 192 assert(!failureBlockStack.empty() && "expected valid failure block"); 193 nextBlock = failureBlockStack.back(); 194 } 195 196 // If this value corresponds to an operation, record that we are going to use 197 // its location as part of a fused location. 198 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 199 if (isOperationValue) 200 locOps.insert(val); 201 202 // Generate code for a boolean predicate node. 203 if (auto *boolNode = dyn_cast<BoolNode>(&node)) { 204 auto *child = generateMatcher(*boolNode->getSuccessNode()); 205 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, 206 child, nextBlock); 207 208 // Generate code for a switch node. 209 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) { 210 generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); 211 212 // Generate code for a success node. 213 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) { 214 generateRecordMatch(block, nextBlock, successNode->getPattern()); 215 } 216 217 if (failureNode) 218 failureBlockStack.pop_back(); 219 if (isOperationValue) 220 locOps.remove(val); 221 return block; 222 } 223 224 Value PatternLowering::getValueAt(Block *cur, Position *pos) { 225 if (Value val = values.lookup(pos)) 226 return val; 227 228 // Get the value for the parent position. 229 Value parentVal = getValueAt(cur, pos->getParent()); 230 231 // TODO: Use a location from the position. 232 Location loc = parentVal.getLoc(); 233 builder.setInsertionPointToEnd(cur); 234 Value value; 235 switch (pos->getKind()) { 236 case Predicates::OperationPos: 237 value = builder.create<pdl_interp::GetDefiningOpOp>( 238 loc, builder.getType<pdl::OperationType>(), parentVal); 239 break; 240 case Predicates::OperandPos: { 241 auto *operandPos = cast<OperandPosition>(pos); 242 value = builder.create<pdl_interp::GetOperandOp>( 243 loc, builder.getType<pdl::ValueType>(), parentVal, 244 operandPos->getOperandNumber()); 245 break; 246 } 247 case Predicates::OperandGroupPos: { 248 auto *operandPos = cast<OperandGroupPosition>(pos); 249 Type valueTy = builder.getType<pdl::ValueType>(); 250 value = builder.create<pdl_interp::GetOperandsOp>( 251 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 252 parentVal, operandPos->getOperandGroupNumber()); 253 break; 254 } 255 case Predicates::AttributePos: { 256 auto *attrPos = cast<AttributePosition>(pos); 257 value = builder.create<pdl_interp::GetAttributeOp>( 258 loc, builder.getType<pdl::AttributeType>(), parentVal, 259 attrPos->getName().strref()); 260 break; 261 } 262 case Predicates::TypePos: { 263 if (parentVal.getType().isa<pdl::AttributeType>()) 264 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 265 else 266 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 267 break; 268 } 269 case Predicates::ResultPos: { 270 auto *resPos = cast<ResultPosition>(pos); 271 value = builder.create<pdl_interp::GetResultOp>( 272 loc, builder.getType<pdl::ValueType>(), parentVal, 273 resPos->getResultNumber()); 274 break; 275 } 276 case Predicates::ResultGroupPos: { 277 auto *resPos = cast<ResultGroupPosition>(pos); 278 Type valueTy = builder.getType<pdl::ValueType>(); 279 value = builder.create<pdl_interp::GetResultsOp>( 280 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 281 parentVal, resPos->getResultGroupNumber()); 282 break; 283 } 284 default: 285 llvm_unreachable("Generating unknown Position getter"); 286 break; 287 } 288 values.insert(pos, value); 289 return value; 290 } 291 292 void PatternLowering::generatePredicate(Block *currentBlock, 293 Qualifier *question, Qualifier *answer, 294 Value val, Block *trueDest, 295 Block *falseDest) { 296 builder.setInsertionPointToEnd(currentBlock); 297 Location loc = val.getLoc(); 298 Predicates::Kind kind = question->getKind(); 299 switch (kind) { 300 case Predicates::IsNotNullQuestion: 301 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest); 302 break; 303 case Predicates::OperationNameQuestion: { 304 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 305 builder.create<pdl_interp::CheckOperationNameOp>( 306 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); 307 break; 308 } 309 case Predicates::TypeQuestion: { 310 auto *ans = cast<TypeAnswer>(answer); 311 if (val.getType().isa<pdl::RangeType>()) 312 builder.create<pdl_interp::CheckTypesOp>( 313 loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest); 314 else 315 builder.create<pdl_interp::CheckTypeOp>( 316 loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest); 317 break; 318 } 319 case Predicates::AttributeQuestion: { 320 auto *ans = cast<AttributeAnswer>(answer); 321 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 322 trueDest, falseDest); 323 break; 324 } 325 case Predicates::OperandCountAtLeastQuestion: 326 case Predicates::OperandCountQuestion: 327 builder.create<pdl_interp::CheckOperandCountOp>( 328 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 329 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 330 trueDest, falseDest); 331 break; 332 case Predicates::ResultCountAtLeastQuestion: 333 case Predicates::ResultCountQuestion: 334 builder.create<pdl_interp::CheckResultCountOp>( 335 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 336 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 337 trueDest, falseDest); 338 break; 339 case Predicates::EqualToQuestion: { 340 auto *equalToQuestion = cast<EqualToQuestion>(question); 341 builder.create<pdl_interp::AreEqualOp>( 342 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), 343 trueDest, falseDest); 344 break; 345 } 346 case Predicates::ConstraintQuestion: { 347 auto *cstQuestion = cast<ConstraintQuestion>(question); 348 SmallVector<Value, 2> args; 349 for (Position *position : std::get<1>(cstQuestion->getValue())) 350 args.push_back(getValueAt(currentBlock, position)); 351 builder.create<pdl_interp::ApplyConstraintOp>( 352 loc, std::get<0>(cstQuestion->getValue()), args, 353 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest, 354 falseDest); 355 break; 356 } 357 default: 358 llvm_unreachable("Generating unknown Predicate operation"); 359 } 360 } 361 362 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 363 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 364 llvm::MapVector<Qualifier *, Block *> &dests) { 365 std::vector<ValT> values; 366 std::vector<Block *> blocks; 367 values.reserve(dests.size()); 368 blocks.reserve(dests.size()); 369 for (const auto &it : dests) { 370 blocks.push_back(it.second); 371 values.push_back(cast<PredT>(it.first)->getValue()); 372 } 373 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 374 } 375 376 void PatternLowering::generateSwitch(SwitchNode *switchNode, 377 Block *currentBlock, Qualifier *question, 378 Value val, Block *defaultDest) { 379 // If the switch question is not an exact answer, i.e. for the `at_least` 380 // cases, we generate a special block sequence. 381 Predicates::Kind kind = question->getKind(); 382 if (kind == Predicates::OperandCountAtLeastQuestion || 383 kind == Predicates::ResultCountAtLeastQuestion) { 384 // Order the children such that the cases are in reverse numerical order. 385 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( 386 llvm::seq<unsigned>(0, switchNode->getChildren().size())); 387 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 388 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 389 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 390 }); 391 392 // Build the destination for each child using the next highest child as a 393 // a failure destination. This essentially creates the following control 394 // flow: 395 // 396 // if (operand_count < 1) 397 // goto failure 398 // if (child1.match()) 399 // ... 400 // 401 // if (operand_count < 2) 402 // goto failure 403 // if (child2.match()) 404 // ... 405 // 406 // failure: 407 // ... 408 // 409 failureBlockStack.push_back(defaultDest); 410 for (unsigned idx : sortedChildren) { 411 auto &child = switchNode->getChild(idx); 412 Block *childBlock = generateMatcher(*child.second); 413 Block *predicateBlock = builder.createBlock(childBlock); 414 generatePredicate(predicateBlock, question, child.first, val, childBlock, 415 defaultDest); 416 failureBlockStack.back() = predicateBlock; 417 } 418 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 419 currentBlock->getOperations().splice(currentBlock->end(), 420 firstPredicateBlock->getOperations()); 421 firstPredicateBlock->erase(); 422 return; 423 } 424 425 // Otherwise, generate each of the children and generate an interpreter 426 // switch. 427 llvm::MapVector<Qualifier *, Block *> children; 428 for (auto &it : switchNode->getChildren()) 429 children.insert({it.first, generateMatcher(*it.second)}); 430 builder.setInsertionPointToEnd(currentBlock); 431 432 switch (question->getKind()) { 433 case Predicates::OperandCountQuestion: 434 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 435 int32_t>(val, defaultDest, builder, children); 436 case Predicates::ResultCountQuestion: 437 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 438 int32_t>(val, defaultDest, builder, children); 439 case Predicates::OperationNameQuestion: 440 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 441 OperationNameAnswer>(val, defaultDest, builder, 442 children); 443 case Predicates::TypeQuestion: 444 if (val.getType().isa<pdl::RangeType>()) { 445 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 446 val, defaultDest, builder, children); 447 } 448 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 449 val, defaultDest, builder, children); 450 case Predicates::AttributeQuestion: 451 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 452 val, defaultDest, builder, children); 453 default: 454 llvm_unreachable("Generating unknown switch predicate."); 455 } 456 } 457 458 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, 459 pdl::PatternOp pattern) { 460 // Generate a rewriter for the pattern this success node represents, and track 461 // any values used from the match region. 462 SmallVector<Position *, 8> usedMatchValues; 463 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 464 465 // Process any values used in the rewrite that are defined in the match. 466 std::vector<Value> mappedMatchValues; 467 mappedMatchValues.reserve(usedMatchValues.size()); 468 for (Position *position : usedMatchValues) 469 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 470 471 // Collect the set of operations generated by the rewriter. 472 SmallVector<StringRef, 4> generatedOps; 473 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 474 generatedOps.push_back(*op.name()); 475 ArrayAttr generatedOpsAttr; 476 if (!generatedOps.empty()) 477 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 478 479 // Grab the root kind if present. 480 StringAttr rootKindAttr; 481 if (Optional<StringRef> rootKind = pattern.getRootKind()) 482 rootKindAttr = builder.getStringAttr(*rootKind); 483 484 builder.setInsertionPointToEnd(currentBlock); 485 builder.create<pdl_interp::RecordMatchOp>( 486 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 487 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 488 nextBlock); 489 } 490 491 SymbolRefAttr PatternLowering::generateRewriter( 492 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 493 FuncOp rewriterFunc = 494 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 495 builder.getFunctionType(llvm::None, llvm::None)); 496 rewriterSymbolTable.insert(rewriterFunc); 497 498 // Generate the rewriter function body. 499 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 500 501 // Map an input operand of the pattern to a generated interpreter value. 502 DenseMap<Value, Value> rewriteValues; 503 auto mapRewriteValue = [&](Value oldValue) { 504 Value &newValue = rewriteValues[oldValue]; 505 if (newValue) 506 return newValue; 507 508 // Prefer materializing constants directly when possible. 509 Operation *oldOp = oldValue.getDefiningOp(); 510 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 511 if (Attribute value = attrOp.valueAttr()) { 512 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 513 attrOp.getLoc(), value); 514 } 515 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 516 if (TypeAttr type = typeOp.typeAttr()) { 517 return newValue = builder.create<pdl_interp::CreateTypeOp>( 518 typeOp.getLoc(), type); 519 } 520 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 521 if (ArrayAttr type = typeOp.typesAttr()) { 522 return newValue = builder.create<pdl_interp::CreateTypesOp>( 523 typeOp.getLoc(), typeOp.getType(), type); 524 } 525 } 526 527 // Otherwise, add this as an input to the rewriter. 528 Position *inputPos = valueToPosition.lookup(oldValue); 529 assert(inputPos && "expected value to be a pattern input"); 530 usedMatchValues.push_back(inputPos); 531 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 532 }; 533 534 // If this is a custom rewriter, simply dispatch to the registered rewrite 535 // method. 536 pdl::RewriteOp rewriter = pattern.getRewriter(); 537 if (StringAttr rewriteName = rewriter.nameAttr()) { 538 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 539 SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root())); 540 args.append(mappedArgs.begin(), mappedArgs.end()); 541 builder.create<pdl_interp::ApplyRewriteOp>( 542 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 543 rewriter.externalConstParamsAttr()); 544 } else { 545 // Otherwise this is a dag rewriter defined using PDL operations. 546 for (Operation &rewriteOp : *rewriter.getBody()) { 547 llvm::TypeSwitch<Operation *>(&rewriteOp) 548 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 549 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 550 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 551 this->generateRewriter(op, rewriteValues, mapRewriteValue); 552 }); 553 } 554 } 555 556 // Update the signature of the rewrite function. 557 rewriterFunc.setType(builder.getFunctionType( 558 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 559 /*results=*/llvm::None)); 560 561 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 562 return SymbolRefAttr::get( 563 builder.getContext(), 564 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 565 SymbolRefAttr::get(rewriterFunc)); 566 } 567 568 void PatternLowering::generateRewriter( 569 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 570 function_ref<Value(Value)> mapRewriteValue) { 571 SmallVector<Value, 2> arguments; 572 for (Value argument : rewriteOp.args()) 573 arguments.push_back(mapRewriteValue(argument)); 574 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 575 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 576 arguments, rewriteOp.constParamsAttr()); 577 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 578 rewriteValues[std::get<0>(it)] = std::get<1>(it); 579 } 580 581 void PatternLowering::generateRewriter( 582 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 583 function_ref<Value(Value)> mapRewriteValue) { 584 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 585 attrOp.getLoc(), attrOp.valueAttr()); 586 rewriteValues[attrOp] = newAttr; 587 } 588 589 void PatternLowering::generateRewriter( 590 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 591 function_ref<Value(Value)> mapRewriteValue) { 592 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 593 mapRewriteValue(eraseOp.operation())); 594 } 595 596 void PatternLowering::generateRewriter( 597 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 598 function_ref<Value(Value)> mapRewriteValue) { 599 SmallVector<Value, 4> operands; 600 for (Value operand : operationOp.operands()) 601 operands.push_back(mapRewriteValue(operand)); 602 603 SmallVector<Value, 4> attributes; 604 for (Value attr : operationOp.attributes()) 605 attributes.push_back(mapRewriteValue(attr)); 606 607 SmallVector<Value, 2> types; 608 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 609 mapRewriteValue); 610 611 // Create the new operation. 612 Location loc = operationOp.getLoc(); 613 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 614 loc, *operationOp.name(), types, operands, attributes, 615 operationOp.attributeNames()); 616 rewriteValues[operationOp.op()] = createdOp; 617 618 // Generate accesses for any results that have their types constrained. 619 // Handle the case where there is a single range representing all of the 620 // result types. 621 OperandRange resultTys = operationOp.types(); 622 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 623 Value &type = rewriteValues[resultTys[0]]; 624 if (!type) { 625 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 626 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 627 } 628 return; 629 } 630 631 // Otherwise, populate the individual results. 632 bool seenVariableLength = false; 633 Type valueTy = builder.getType<pdl::ValueType>(); 634 Type valueRangeTy = pdl::RangeType::get(valueTy); 635 for (auto it : llvm::enumerate(resultTys)) { 636 Value &type = rewriteValues[it.value()]; 637 if (type) 638 continue; 639 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 640 seenVariableLength |= isVariadic; 641 642 // After a variable length result has been seen, we need to use result 643 // groups because the exact index of the result is not statically known. 644 Value resultVal; 645 if (seenVariableLength) 646 resultVal = builder.create<pdl_interp::GetResultsOp>( 647 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 648 else 649 resultVal = builder.create<pdl_interp::GetResultOp>( 650 loc, valueTy, createdOp, it.index()); 651 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 652 } 653 } 654 655 void PatternLowering::generateRewriter( 656 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 657 function_ref<Value(Value)> mapRewriteValue) { 658 SmallVector<Value, 4> replOperands; 659 660 // If the replacement was another operation, get its results. `pdl` allows 661 // for using an operation for simplicitly, but the interpreter isn't as 662 // user facing. 663 if (Value replOp = replaceOp.replOperation()) { 664 // Don't use replace if we know the replaced operation has no results. 665 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 666 if (!opOp || !opOp.types().empty()) { 667 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 668 replOp.getLoc(), mapRewriteValue(replOp))); 669 } 670 } else { 671 for (Value operand : replaceOp.replValues()) 672 replOperands.push_back(mapRewriteValue(operand)); 673 } 674 675 // If there are no replacement values, just create an erase instead. 676 if (replOperands.empty()) { 677 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 678 mapRewriteValue(replaceOp.operation())); 679 return; 680 } 681 682 builder.create<pdl_interp::ReplaceOp>( 683 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 684 } 685 686 void PatternLowering::generateRewriter( 687 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 688 function_ref<Value(Value)> mapRewriteValue) { 689 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 690 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 691 mapRewriteValue(resultOp.parent()), resultOp.index()); 692 } 693 694 void PatternLowering::generateRewriter( 695 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 696 function_ref<Value(Value)> mapRewriteValue) { 697 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 698 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 699 resultOp.index()); 700 } 701 702 void PatternLowering::generateRewriter( 703 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 704 function_ref<Value(Value)> mapRewriteValue) { 705 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 706 // type. 707 if (TypeAttr typeAttr = typeOp.typeAttr()) { 708 rewriteValues[typeOp] = 709 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 710 } 711 } 712 713 void PatternLowering::generateRewriter( 714 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 715 function_ref<Value(Value)> mapRewriteValue) { 716 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 717 // type. 718 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 719 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 720 typeOp.getLoc(), typeOp.getType(), typeAttr); 721 } 722 } 723 724 void PatternLowering::generateOperationResultTypeRewriter( 725 pdl::OperationOp op, SmallVectorImpl<Value> &types, 726 DenseMap<Value, Value> &rewriteValues, 727 function_ref<Value(Value)> mapRewriteValue) { 728 // Look for an operation that was replaced by `op`. The result types will be 729 // inferred from the results that were replaced. 730 Block *rewriterBlock = op->getBlock(); 731 Value replacedOp; 732 for (OpOperand &use : op.op().getUses()) { 733 // Check that the use corresponds to a ReplaceOp and that it is the 734 // replacement value, not the operation being replaced. 735 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 736 if (!replOpUser || use.getOperandNumber() == 0) 737 continue; 738 // Make sure the replaced operation was defined before this one. 739 Value replOpVal = replOpUser.operation(); 740 Operation *replacedOp = replOpVal.getDefiningOp(); 741 if (replacedOp->getBlock() == rewriterBlock && 742 !replacedOp->isBeforeInBlock(op)) 743 continue; 744 745 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 746 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 747 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 748 replacedOp->getLoc(), replacedOpResults)); 749 return; 750 } 751 752 // Check if the operation has type inference support. 753 if (op.hasTypeInference()) { 754 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 755 return; 756 } 757 758 // Otherwise, handle inference for each of the result types individually. 759 OperandRange resultTypeValues = op.types(); 760 types.reserve(resultTypeValues.size()); 761 for (auto it : llvm::enumerate(resultTypeValues)) { 762 Value resultType = it.value(); 763 764 // Check for an already translated value. 765 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 766 types.push_back(existingRewriteValue); 767 continue; 768 } 769 770 // Check for an input from the matcher. 771 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 772 types.push_back(mapRewriteValue(resultType)); 773 continue; 774 } 775 776 // The verifier asserts that the result types of each pdl.operation can be 777 // inferred. If we reach here, there is a bug either in the logic above or 778 // in the verifier for pdl.operation. 779 op->emitOpError() << "unable to infer result type for operation"; 780 llvm_unreachable("unable to infer result type for operation"); 781 } 782 } 783 784 //===----------------------------------------------------------------------===// 785 // Conversion Pass 786 //===----------------------------------------------------------------------===// 787 788 namespace { 789 struct PDLToPDLInterpPass 790 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 791 void runOnOperation() final; 792 }; 793 } // namespace 794 795 /// Convert the given module containing PDL pattern operations into a PDL 796 /// Interpreter operations. 797 void PDLToPDLInterpPass::runOnOperation() { 798 ModuleOp module = getOperation(); 799 800 // Create the main matcher function This function contains all of the match 801 // related functionality from patterns in the module. 802 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 803 FuncOp matcherFunc = builder.create<FuncOp>( 804 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 805 builder.getFunctionType(builder.getType<pdl::OperationType>(), 806 /*results=*/llvm::None), 807 /*attrs=*/llvm::None); 808 809 // Create a nested module to hold the functions invoked for rewriting the IR 810 // after a successful match. 811 ModuleOp rewriterModule = builder.create<ModuleOp>( 812 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 813 814 // Generate the code for the patterns within the module. 815 PatternLowering generator(matcherFunc, rewriterModule); 816 generator.lower(module); 817 818 // After generation, delete all of the pattern operations. 819 for (pdl::PatternOp pattern : 820 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 821 pattern.erase(); 822 } 823 824 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 825 return std::make_unique<PDLToPDLInterpPass>(); 826 } 827