1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 10 #include "../PassDetail.h" 11 #include "PredicateTree.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 15 #include "mlir/Pass/Pass.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/ScopedHashTable.h" 18 #include "llvm/ADT/SetVector.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::pdl_to_pdl_interp; 23 24 //===----------------------------------------------------------------------===// 25 // PatternLowering 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 /// This class generators operations within the PDL Interpreter dialect from a 30 /// given module containing PDL pattern operations. 31 struct PatternLowering { 32 public: 33 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); 34 35 /// Generate code for matching and rewriting based on the pattern operations 36 /// within the module. 37 void lower(ModuleOp module); 38 39 private: 40 using ValueMap = llvm::ScopedHashTable<Position *, Value>; 41 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; 42 43 /// Generate interpreter operations for the tree rooted at the given matcher 44 /// node. 45 Block *generateMatcher(MatcherNode &node); 46 47 /// Get or create an access to the provided positional value within the 48 /// current block. 49 Value getValueAt(Block *cur, Position *pos); 50 51 /// Create an interpreter predicate operation, branching to the provided true 52 /// and false destinations. 53 void generatePredicate(Block *currentBlock, Qualifier *question, 54 Qualifier *answer, Value val, Block *trueDest, 55 Block *falseDest); 56 57 /// Create an interpreter switch predicate operation, with a provided default 58 /// and several case destinations. 59 void generateSwitch(SwitchNode *switchNode, Block *currentBlock, 60 Qualifier *question, Value val, Block *defaultDest); 61 62 /// Create the interpreter operations to record a successful pattern match. 63 void generateRecordMatch(Block *currentBlock, Block *nextBlock, 64 pdl::PatternOp pattern); 65 66 /// Generate a rewriter function for the given pattern operation, and returns 67 /// a reference to that function. 68 SymbolRefAttr generateRewriter(pdl::PatternOp pattern, 69 SmallVectorImpl<Position *> &usedMatchValues); 70 71 /// Generate the rewriter code for the given operation. 72 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, 73 DenseMap<Value, Value> &rewriteValues, 74 function_ref<Value(Value)> mapRewriteValue); 75 void generateRewriter(pdl::AttributeOp attrOp, 76 DenseMap<Value, Value> &rewriteValues, 77 function_ref<Value(Value)> mapRewriteValue); 78 void generateRewriter(pdl::EraseOp eraseOp, 79 DenseMap<Value, Value> &rewriteValues, 80 function_ref<Value(Value)> mapRewriteValue); 81 void generateRewriter(pdl::OperationOp operationOp, 82 DenseMap<Value, Value> &rewriteValues, 83 function_ref<Value(Value)> mapRewriteValue); 84 void generateRewriter(pdl::ReplaceOp replaceOp, 85 DenseMap<Value, Value> &rewriteValues, 86 function_ref<Value(Value)> mapRewriteValue); 87 void generateRewriter(pdl::ResultOp resultOp, 88 DenseMap<Value, Value> &rewriteValues, 89 function_ref<Value(Value)> mapRewriteValue); 90 void generateRewriter(pdl::ResultsOp resultOp, 91 DenseMap<Value, Value> &rewriteValues, 92 function_ref<Value(Value)> mapRewriteValue); 93 void generateRewriter(pdl::TypeOp typeOp, 94 DenseMap<Value, Value> &rewriteValues, 95 function_ref<Value(Value)> mapRewriteValue); 96 void generateRewriter(pdl::TypesOp typeOp, 97 DenseMap<Value, Value> &rewriteValues, 98 function_ref<Value(Value)> mapRewriteValue); 99 100 /// Generate the values used for resolving the result types of an operation 101 /// created within a dag rewriter region. 102 void generateOperationResultTypeRewriter( 103 pdl::OperationOp op, SmallVectorImpl<Value> &types, 104 DenseMap<Value, Value> &rewriteValues, 105 function_ref<Value(Value)> mapRewriteValue); 106 107 /// A builder to use when generating interpreter operations. 108 OpBuilder builder; 109 110 /// The matcher function used for all match related logic within PDL patterns. 111 FuncOp matcherFunc; 112 113 /// The rewriter module containing the all rewrite related logic within PDL 114 /// patterns. 115 ModuleOp rewriterModule; 116 117 /// The symbol table of the rewriter module used for insertion. 118 SymbolTable rewriterSymbolTable; 119 120 /// A scoped map connecting a position with the corresponding interpreter 121 /// value. 122 ValueMap values; 123 124 /// A stack of blocks used as the failure destination for matcher nodes that 125 /// don't have an explicit failure path. 126 SmallVector<Block *, 8> failureBlockStack; 127 128 /// A mapping between values defined in a pattern match, and the corresponding 129 /// positional value. 130 DenseMap<Value, Position *> valueToPosition; 131 132 /// The set of operation values whose whose location will be used for newly 133 /// generated operations. 134 SetVector<Value> locOps; 135 }; 136 } // end anonymous namespace 137 138 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) 139 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), 140 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} 141 142 void PatternLowering::lower(ModuleOp module) { 143 PredicateUniquer predicateUniquer; 144 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); 145 146 // Define top-level scope for the arguments to the matcher function. 147 ValueMapScope topLevelValueScope(values); 148 149 // Insert the root operation, i.e. argument to the matcher, at the root 150 // position. 151 Block *matcherEntryBlock = matcherFunc.addEntryBlock(); 152 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); 153 154 // Generate a root matcher node from the provided PDL module. 155 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( 156 module, predicateBuilder, valueToPosition); 157 Block *firstMatcherBlock = generateMatcher(*root); 158 159 // After generation, merged the first matched block into the entry. 160 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), 161 firstMatcherBlock->getOperations()); 162 firstMatcherBlock->erase(); 163 } 164 165 Block *PatternLowering::generateMatcher(MatcherNode &node) { 166 // Push a new scope for the values used by this matcher. 167 Block *block = matcherFunc.addBlock(); 168 ValueMapScope scope(values); 169 170 // If this is the return node, simply insert the corresponding interpreter 171 // finalize. 172 if (isa<ExitNode>(node)) { 173 builder.setInsertionPointToEnd(block); 174 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); 175 return block; 176 } 177 178 // If this node contains a position, get the corresponding value for this 179 // block. 180 Position *position = node.getPosition(); 181 Value val = position ? getValueAt(block, position) : Value(); 182 183 // Get the next block in the match sequence. 184 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); 185 Block *nextBlock; 186 if (failureNode) { 187 nextBlock = generateMatcher(*failureNode); 188 failureBlockStack.push_back(nextBlock); 189 } else { 190 assert(!failureBlockStack.empty() && "expected valid failure block"); 191 nextBlock = failureBlockStack.back(); 192 } 193 194 // If this value corresponds to an operation, record that we are going to use 195 // its location as part of a fused location. 196 bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); 197 if (isOperationValue) 198 locOps.insert(val); 199 200 // Generate code for a boolean predicate node. 201 if (auto *boolNode = dyn_cast<BoolNode>(&node)) { 202 auto *child = generateMatcher(*boolNode->getSuccessNode()); 203 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, 204 child, nextBlock); 205 206 // Generate code for a switch node. 207 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) { 208 generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); 209 210 // Generate code for a success node. 211 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) { 212 generateRecordMatch(block, nextBlock, successNode->getPattern()); 213 } 214 215 if (failureNode) 216 failureBlockStack.pop_back(); 217 if (isOperationValue) 218 locOps.remove(val); 219 return block; 220 } 221 222 Value PatternLowering::getValueAt(Block *cur, Position *pos) { 223 if (Value val = values.lookup(pos)) 224 return val; 225 226 // Get the value for the parent position. 227 Value parentVal = getValueAt(cur, pos->getParent()); 228 229 // TODO: Use a location from the position. 230 Location loc = parentVal.getLoc(); 231 builder.setInsertionPointToEnd(cur); 232 Value value; 233 switch (pos->getKind()) { 234 case Predicates::OperationPos: 235 value = builder.create<pdl_interp::GetDefiningOpOp>( 236 loc, builder.getType<pdl::OperationType>(), parentVal); 237 break; 238 case Predicates::OperandPos: { 239 auto *operandPos = cast<OperandPosition>(pos); 240 value = builder.create<pdl_interp::GetOperandOp>( 241 loc, builder.getType<pdl::ValueType>(), parentVal, 242 operandPos->getOperandNumber()); 243 break; 244 } 245 case Predicates::OperandGroupPos: { 246 auto *operandPos = cast<OperandGroupPosition>(pos); 247 Type valueTy = builder.getType<pdl::ValueType>(); 248 value = builder.create<pdl_interp::GetOperandsOp>( 249 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 250 parentVal, operandPos->getOperandGroupNumber()); 251 break; 252 } 253 case Predicates::AttributePos: { 254 auto *attrPos = cast<AttributePosition>(pos); 255 value = builder.create<pdl_interp::GetAttributeOp>( 256 loc, builder.getType<pdl::AttributeType>(), parentVal, 257 attrPos->getName().strref()); 258 break; 259 } 260 case Predicates::TypePos: { 261 if (parentVal.getType().isa<pdl::AttributeType>()) 262 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); 263 else 264 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); 265 break; 266 } 267 case Predicates::ResultPos: { 268 auto *resPos = cast<ResultPosition>(pos); 269 value = builder.create<pdl_interp::GetResultOp>( 270 loc, builder.getType<pdl::ValueType>(), parentVal, 271 resPos->getResultNumber()); 272 break; 273 } 274 case Predicates::ResultGroupPos: { 275 auto *resPos = cast<ResultGroupPosition>(pos); 276 Type valueTy = builder.getType<pdl::ValueType>(); 277 value = builder.create<pdl_interp::GetResultsOp>( 278 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, 279 parentVal, resPos->getResultGroupNumber()); 280 break; 281 } 282 default: 283 llvm_unreachable("Generating unknown Position getter"); 284 break; 285 } 286 values.insert(pos, value); 287 return value; 288 } 289 290 void PatternLowering::generatePredicate(Block *currentBlock, 291 Qualifier *question, Qualifier *answer, 292 Value val, Block *trueDest, 293 Block *falseDest) { 294 builder.setInsertionPointToEnd(currentBlock); 295 Location loc = val.getLoc(); 296 Predicates::Kind kind = question->getKind(); 297 switch (kind) { 298 case Predicates::IsNotNullQuestion: 299 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest); 300 break; 301 case Predicates::OperationNameQuestion: { 302 auto *opNameAnswer = cast<OperationNameAnswer>(answer); 303 builder.create<pdl_interp::CheckOperationNameOp>( 304 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); 305 break; 306 } 307 case Predicates::TypeQuestion: { 308 auto *ans = cast<TypeAnswer>(answer); 309 if (val.getType().isa<pdl::RangeType>()) 310 builder.create<pdl_interp::CheckTypesOp>( 311 loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest); 312 else 313 builder.create<pdl_interp::CheckTypeOp>( 314 loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest); 315 break; 316 } 317 case Predicates::AttributeQuestion: { 318 auto *ans = cast<AttributeAnswer>(answer); 319 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), 320 trueDest, falseDest); 321 break; 322 } 323 case Predicates::OperandCountAtLeastQuestion: 324 case Predicates::OperandCountQuestion: 325 builder.create<pdl_interp::CheckOperandCountOp>( 326 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 327 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, 328 trueDest, falseDest); 329 break; 330 case Predicates::ResultCountAtLeastQuestion: 331 case Predicates::ResultCountQuestion: 332 builder.create<pdl_interp::CheckResultCountOp>( 333 loc, val, cast<UnsignedAnswer>(answer)->getValue(), 334 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, 335 trueDest, falseDest); 336 break; 337 case Predicates::EqualToQuestion: { 338 auto *equalToQuestion = cast<EqualToQuestion>(question); 339 builder.create<pdl_interp::AreEqualOp>( 340 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), 341 trueDest, falseDest); 342 break; 343 } 344 case Predicates::ConstraintQuestion: { 345 auto *cstQuestion = cast<ConstraintQuestion>(question); 346 SmallVector<Value, 2> args; 347 for (Position *position : std::get<1>(cstQuestion->getValue())) 348 args.push_back(getValueAt(currentBlock, position)); 349 builder.create<pdl_interp::ApplyConstraintOp>( 350 loc, std::get<0>(cstQuestion->getValue()), args, 351 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest, 352 falseDest); 353 break; 354 } 355 default: 356 llvm_unreachable("Generating unknown Predicate operation"); 357 } 358 } 359 360 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> 361 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, 362 llvm::MapVector<Qualifier *, Block *> &dests) { 363 std::vector<ValT> values; 364 std::vector<Block *> blocks; 365 values.reserve(dests.size()); 366 blocks.reserve(dests.size()); 367 for (const auto &it : dests) { 368 blocks.push_back(it.second); 369 values.push_back(cast<PredT>(it.first)->getValue()); 370 } 371 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); 372 } 373 374 void PatternLowering::generateSwitch(SwitchNode *switchNode, 375 Block *currentBlock, Qualifier *question, 376 Value val, Block *defaultDest) { 377 // If the switch question is not an exact answer, i.e. for the `at_least` 378 // cases, we generate a special block sequence. 379 Predicates::Kind kind = question->getKind(); 380 if (kind == Predicates::OperandCountAtLeastQuestion || 381 kind == Predicates::ResultCountAtLeastQuestion) { 382 // Order the children such that the cases are in reverse numerical order. 383 SmallVector<unsigned> sortedChildren = 384 llvm::seq<unsigned>(0, switchNode->getChildren().size()) 385 .asSmallVector(); 386 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { 387 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > 388 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); 389 }); 390 391 // Build the destination for each child using the next highest child as a 392 // a failure destination. This essentially creates the following control 393 // flow: 394 // 395 // if (operand_count < 1) 396 // goto failure 397 // if (child1.match()) 398 // ... 399 // 400 // if (operand_count < 2) 401 // goto failure 402 // if (child2.match()) 403 // ... 404 // 405 // failure: 406 // ... 407 // 408 failureBlockStack.push_back(defaultDest); 409 for (unsigned idx : sortedChildren) { 410 auto &child = switchNode->getChild(idx); 411 Block *childBlock = generateMatcher(*child.second); 412 Block *predicateBlock = builder.createBlock(childBlock); 413 generatePredicate(predicateBlock, question, child.first, val, childBlock, 414 defaultDest); 415 failureBlockStack.back() = predicateBlock; 416 } 417 Block *firstPredicateBlock = failureBlockStack.pop_back_val(); 418 currentBlock->getOperations().splice(currentBlock->end(), 419 firstPredicateBlock->getOperations()); 420 firstPredicateBlock->erase(); 421 return; 422 } 423 424 // Otherwise, generate each of the children and generate an interpreter 425 // switch. 426 llvm::MapVector<Qualifier *, Block *> children; 427 for (auto &it : switchNode->getChildren()) 428 children.insert({it.first, generateMatcher(*it.second)}); 429 builder.setInsertionPointToEnd(currentBlock); 430 431 switch (question->getKind()) { 432 case Predicates::OperandCountQuestion: 433 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, 434 int32_t>(val, defaultDest, builder, children); 435 case Predicates::ResultCountQuestion: 436 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, 437 int32_t>(val, defaultDest, builder, children); 438 case Predicates::OperationNameQuestion: 439 return createSwitchOp<pdl_interp::SwitchOperationNameOp, 440 OperationNameAnswer>(val, defaultDest, builder, 441 children); 442 case Predicates::TypeQuestion: 443 if (val.getType().isa<pdl::RangeType>()) { 444 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( 445 val, defaultDest, builder, children); 446 } 447 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( 448 val, defaultDest, builder, children); 449 case Predicates::AttributeQuestion: 450 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( 451 val, defaultDest, builder, children); 452 default: 453 llvm_unreachable("Generating unknown switch predicate."); 454 } 455 } 456 457 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, 458 pdl::PatternOp pattern) { 459 // Generate a rewriter for the pattern this success node represents, and track 460 // any values used from the match region. 461 SmallVector<Position *, 8> usedMatchValues; 462 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); 463 464 // Process any values used in the rewrite that are defined in the match. 465 std::vector<Value> mappedMatchValues; 466 mappedMatchValues.reserve(usedMatchValues.size()); 467 for (Position *position : usedMatchValues) 468 mappedMatchValues.push_back(getValueAt(currentBlock, position)); 469 470 // Collect the set of operations generated by the rewriter. 471 SmallVector<StringRef, 4> generatedOps; 472 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) 473 generatedOps.push_back(*op.name()); 474 ArrayAttr generatedOpsAttr; 475 if (!generatedOps.empty()) 476 generatedOpsAttr = builder.getStrArrayAttr(generatedOps); 477 478 // Grab the root kind if present. 479 StringAttr rootKindAttr; 480 if (Optional<StringRef> rootKind = pattern.getRootKind()) 481 rootKindAttr = builder.getStringAttr(*rootKind); 482 483 builder.setInsertionPointToEnd(currentBlock); 484 builder.create<pdl_interp::RecordMatchOp>( 485 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), 486 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), 487 nextBlock); 488 } 489 490 SymbolRefAttr PatternLowering::generateRewriter( 491 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { 492 FuncOp rewriterFunc = 493 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", 494 builder.getFunctionType(llvm::None, llvm::None)); 495 rewriterSymbolTable.insert(rewriterFunc); 496 497 // Generate the rewriter function body. 498 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); 499 500 // Map an input operand of the pattern to a generated interpreter value. 501 DenseMap<Value, Value> rewriteValues; 502 auto mapRewriteValue = [&](Value oldValue) { 503 Value &newValue = rewriteValues[oldValue]; 504 if (newValue) 505 return newValue; 506 507 // Prefer materializing constants directly when possible. 508 Operation *oldOp = oldValue.getDefiningOp(); 509 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { 510 if (Attribute value = attrOp.valueAttr()) { 511 return newValue = builder.create<pdl_interp::CreateAttributeOp>( 512 attrOp.getLoc(), value); 513 } 514 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { 515 if (TypeAttr type = typeOp.typeAttr()) { 516 return newValue = builder.create<pdl_interp::CreateTypeOp>( 517 typeOp.getLoc(), type); 518 } 519 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { 520 if (ArrayAttr type = typeOp.typesAttr()) { 521 return newValue = builder.create<pdl_interp::CreateTypesOp>( 522 typeOp.getLoc(), typeOp.getType(), type); 523 } 524 } 525 526 // Otherwise, add this as an input to the rewriter. 527 Position *inputPos = valueToPosition.lookup(oldValue); 528 assert(inputPos && "expected value to be a pattern input"); 529 usedMatchValues.push_back(inputPos); 530 return newValue = rewriterFunc.front().addArgument(oldValue.getType()); 531 }; 532 533 // If this is a custom rewriter, simply dispatch to the registered rewrite 534 // method. 535 pdl::RewriteOp rewriter = pattern.getRewriter(); 536 if (StringAttr rewriteName = rewriter.nameAttr()) { 537 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); 538 SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root())); 539 args.append(mappedArgs.begin(), mappedArgs.end()); 540 builder.create<pdl_interp::ApplyRewriteOp>( 541 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, 542 rewriter.externalConstParamsAttr()); 543 } else { 544 // Otherwise this is a dag rewriter defined using PDL operations. 545 for (Operation &rewriteOp : *rewriter.getBody()) { 546 llvm::TypeSwitch<Operation *>(&rewriteOp) 547 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, 548 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, 549 pdl::TypeOp, pdl::TypesOp>([&](auto op) { 550 this->generateRewriter(op, rewriteValues, mapRewriteValue); 551 }); 552 } 553 } 554 555 // Update the signature of the rewrite function. 556 rewriterFunc.setType(builder.getFunctionType( 557 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), 558 /*results=*/llvm::None)); 559 560 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); 561 return builder.getSymbolRefAttr( 562 pdl_interp::PDLInterpDialect::getRewriterModuleName(), 563 builder.getSymbolRefAttr(rewriterFunc)); 564 } 565 566 void PatternLowering::generateRewriter( 567 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, 568 function_ref<Value(Value)> mapRewriteValue) { 569 SmallVector<Value, 2> arguments; 570 for (Value argument : rewriteOp.args()) 571 arguments.push_back(mapRewriteValue(argument)); 572 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( 573 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), 574 arguments, rewriteOp.constParamsAttr()); 575 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) 576 rewriteValues[std::get<0>(it)] = std::get<1>(it); 577 } 578 579 void PatternLowering::generateRewriter( 580 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, 581 function_ref<Value(Value)> mapRewriteValue) { 582 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( 583 attrOp.getLoc(), attrOp.valueAttr()); 584 rewriteValues[attrOp] = newAttr; 585 } 586 587 void PatternLowering::generateRewriter( 588 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, 589 function_ref<Value(Value)> mapRewriteValue) { 590 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), 591 mapRewriteValue(eraseOp.operation())); 592 } 593 594 void PatternLowering::generateRewriter( 595 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, 596 function_ref<Value(Value)> mapRewriteValue) { 597 SmallVector<Value, 4> operands; 598 for (Value operand : operationOp.operands()) 599 operands.push_back(mapRewriteValue(operand)); 600 601 SmallVector<Value, 4> attributes; 602 for (Value attr : operationOp.attributes()) 603 attributes.push_back(mapRewriteValue(attr)); 604 605 SmallVector<Value, 2> types; 606 generateOperationResultTypeRewriter(operationOp, types, rewriteValues, 607 mapRewriteValue); 608 609 // Create the new operation. 610 Location loc = operationOp.getLoc(); 611 Value createdOp = builder.create<pdl_interp::CreateOperationOp>( 612 loc, *operationOp.name(), types, operands, attributes, 613 operationOp.attributeNames()); 614 rewriteValues[operationOp.op()] = createdOp; 615 616 // Generate accesses for any results that have their types constrained. 617 // Handle the case where there is a single range representing all of the 618 // result types. 619 OperandRange resultTys = operationOp.types(); 620 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { 621 Value &type = rewriteValues[resultTys[0]]; 622 if (!type) { 623 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); 624 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); 625 } 626 return; 627 } 628 629 // Otherwise, populate the individual results. 630 bool seenVariableLength = false; 631 Type valueTy = builder.getType<pdl::ValueType>(); 632 Type valueRangeTy = pdl::RangeType::get(valueTy); 633 for (auto it : llvm::enumerate(resultTys)) { 634 Value &type = rewriteValues[it.value()]; 635 if (type) 636 continue; 637 bool isVariadic = it.value().getType().isa<pdl::RangeType>(); 638 seenVariableLength |= isVariadic; 639 640 // After a variable length result has been seen, we need to use result 641 // groups because the exact index of the result is not statically known. 642 Value resultVal; 643 if (seenVariableLength) 644 resultVal = builder.create<pdl_interp::GetResultsOp>( 645 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); 646 else 647 resultVal = builder.create<pdl_interp::GetResultOp>( 648 loc, valueTy, createdOp, it.index()); 649 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); 650 } 651 } 652 653 void PatternLowering::generateRewriter( 654 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, 655 function_ref<Value(Value)> mapRewriteValue) { 656 SmallVector<Value, 4> replOperands; 657 658 // If the replacement was another operation, get its results. `pdl` allows 659 // for using an operation for simplicitly, but the interpreter isn't as 660 // user facing. 661 if (Value replOp = replaceOp.replOperation()) { 662 // Don't use replace if we know the replaced operation has no results. 663 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); 664 if (!opOp || !opOp.types().empty()) { 665 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( 666 replOp.getLoc(), mapRewriteValue(replOp))); 667 } 668 } else { 669 for (Value operand : replaceOp.replValues()) 670 replOperands.push_back(mapRewriteValue(operand)); 671 } 672 673 // If there are no replacement values, just create an erase instead. 674 if (replOperands.empty()) { 675 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), 676 mapRewriteValue(replaceOp.operation())); 677 return; 678 } 679 680 builder.create<pdl_interp::ReplaceOp>( 681 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); 682 } 683 684 void PatternLowering::generateRewriter( 685 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, 686 function_ref<Value(Value)> mapRewriteValue) { 687 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( 688 resultOp.getLoc(), builder.getType<pdl::ValueType>(), 689 mapRewriteValue(resultOp.parent()), resultOp.index()); 690 } 691 692 void PatternLowering::generateRewriter( 693 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, 694 function_ref<Value(Value)> mapRewriteValue) { 695 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( 696 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), 697 resultOp.index()); 698 } 699 700 void PatternLowering::generateRewriter( 701 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, 702 function_ref<Value(Value)> mapRewriteValue) { 703 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 704 // type. 705 if (TypeAttr typeAttr = typeOp.typeAttr()) { 706 rewriteValues[typeOp] = 707 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); 708 } 709 } 710 711 void PatternLowering::generateRewriter( 712 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, 713 function_ref<Value(Value)> mapRewriteValue) { 714 // If the type isn't constant, the users (e.g. OperationOp) will resolve this 715 // type. 716 if (ArrayAttr typeAttr = typeOp.typesAttr()) { 717 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( 718 typeOp.getLoc(), typeOp.getType(), typeAttr); 719 } 720 } 721 722 void PatternLowering::generateOperationResultTypeRewriter( 723 pdl::OperationOp op, SmallVectorImpl<Value> &types, 724 DenseMap<Value, Value> &rewriteValues, 725 function_ref<Value(Value)> mapRewriteValue) { 726 // Look for an operation that was replaced by `op`. The result types will be 727 // inferred from the results that were replaced. 728 Block *rewriterBlock = op->getBlock(); 729 Value replacedOp; 730 for (OpOperand &use : op.op().getUses()) { 731 // Check that the use corresponds to a ReplaceOp and that it is the 732 // replacement value, not the operation being replaced. 733 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); 734 if (!replOpUser || use.getOperandNumber() == 0) 735 continue; 736 // Make sure the replaced operation was defined before this one. 737 Value replOpVal = replOpUser.operation(); 738 Operation *replacedOp = replOpVal.getDefiningOp(); 739 if (replacedOp->getBlock() == rewriterBlock && 740 !replacedOp->isBeforeInBlock(op)) 741 continue; 742 743 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( 744 replacedOp->getLoc(), mapRewriteValue(replOpVal)); 745 types.push_back(builder.create<pdl_interp::GetValueTypeOp>( 746 replacedOp->getLoc(), replacedOpResults)); 747 return; 748 } 749 750 // Check if the operation has type inference support. 751 if (op.hasTypeInference()) { 752 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); 753 return; 754 } 755 756 // Otherwise, handle inference for each of the result types individually. 757 OperandRange resultTypeValues = op.types(); 758 types.reserve(resultTypeValues.size()); 759 for (auto it : llvm::enumerate(resultTypeValues)) { 760 Value resultType = it.value(); 761 762 // Check for an already translated value. 763 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { 764 types.push_back(existingRewriteValue); 765 continue; 766 } 767 768 // Check for an input from the matcher. 769 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { 770 types.push_back(mapRewriteValue(resultType)); 771 continue; 772 } 773 774 // The verifier asserts that the result types of each pdl.operation can be 775 // inferred. If we reach here, there is a bug either in the logic above or 776 // in the verifier for pdl.operation. 777 op->emitOpError() << "unable to infer result type for operation"; 778 llvm_unreachable("unable to infer result type for operation"); 779 } 780 } 781 782 //===----------------------------------------------------------------------===// 783 // Conversion Pass 784 //===----------------------------------------------------------------------===// 785 786 namespace { 787 struct PDLToPDLInterpPass 788 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { 789 void runOnOperation() final; 790 }; 791 } // namespace 792 793 /// Convert the given module containing PDL pattern operations into a PDL 794 /// Interpreter operations. 795 void PDLToPDLInterpPass::runOnOperation() { 796 ModuleOp module = getOperation(); 797 798 // Create the main matcher function This function contains all of the match 799 // related functionality from patterns in the module. 800 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); 801 FuncOp matcherFunc = builder.create<FuncOp>( 802 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), 803 builder.getFunctionType(builder.getType<pdl::OperationType>(), 804 /*results=*/llvm::None), 805 /*attrs=*/llvm::None); 806 807 // Create a nested module to hold the functions invoked for rewriting the IR 808 // after a successful match. 809 ModuleOp rewriterModule = builder.create<ModuleOp>( 810 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); 811 812 // Generate the code for the patterns within the module. 813 PatternLowering generator(matcherFunc, rewriterModule); 814 generator.lower(module); 815 816 // After generation, delete all of the pattern operations. 817 for (pdl::PatternOp pattern : 818 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) 819 pattern.erase(); 820 } 821 822 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { 823 return std::make_unique<PDLToPDLInterpPass>(); 824 } 825