//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file lowers a module of PDL pattern operations into the PDL Interpreter
// dialect: a single matcher function that evaluates a matcher tree built from
// the patterns, plus a nested module of rewriter functions that are referenced
// by the recorded matches.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "../PassDetail.h"
#include "PredicateTree.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;

//===----------------------------------------------------------------------===//
// PatternLowering
//===----------------------------------------------------------------------===//

namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations. Match logic is emitted into
/// `matcherFunc`, and rewriter functions are emitted into `rewriterModule`.
struct PatternLowering {
public:
  PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, and return the newly created block in which evaluation of that
  /// node begins.
  Block *generateMatcher(MatcherNode &node);

  /// Get or create an access to the provided positional value within the
  /// current block. Accesses already recorded in `values` are reused.
  Value getValueAt(Block *cur, Position *pos);

  /// Create an interpreter predicate operation, branching to the provided true
  /// and false destinations.
  void generatePredicate(Block *currentBlock, Qualifier *question,
                         Qualifier *answer, Value val, Block *trueDest,
                         Block *falseDest);

  /// Create an interpreter switch predicate operation, with a provided default
  /// and several case destinations.
  void generateSwitch(Block *currentBlock, Qualifier *question, Value val,
                      Block *defaultDest,
                      ArrayRef<std::pair<Qualifier *, Block *>> dests);

  /// Create the interpreter operations to record a successful pattern match.
  void generateRecordMatch(Block *currentBlock, Block *nextBlock,
                           pdl::PatternOp pattern);

  /// Generate a rewriter function for the given pattern operation, and return
  /// a reference to that function. The positions of the matcher values that
  /// become arguments of the rewriter are appended to `usedMatchValues`, in
  /// argument order.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. All overloads share
  /// the same trailing parameters so that they can be dispatched uniformly
  /// over the operations of a dag rewriter region.
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::CreateNativeOp createNativeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, SmallVectorImpl<Value> &types,
      DenseMap<Value, Value> &rewriteValues,
      function_ref<Value(Value)> mapRewriteValue);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  FuncOp matcherFunc;

  /// The rewriter module containing all of the rewrite related logic within
  /// PDL patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations.
  llvm::SetVector<Value> locOps;
};
} // end anonymous namespace

PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
    : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
      rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}

void PatternLowering::lower(ModuleOp module) {
  PredicateUniquer predicateUniquer;
  PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());

  // Define top-level scope for the arguments to the matcher function.
  ValueMapScope topLevelValueScope(values);

  // Insert the root operation, i.e. argument to the matcher, at the root
  // position. This seeds `values` so that `getValueAt` can resolve every
  // other position relative to the root.
  Block *matcherEntryBlock = matcherFunc.addEntryBlock();
  values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));

  // Generate a root matcher node from the provided PDL module.
  std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
      module, predicateBuilder, valueToPosition);
  Block *firstMatcherBlock = generateMatcher(*root);

  // After generation, merge the first matched block into the entry.
  // `generateMatcher` always creates a fresh block, so the root node's
  // operations are moved into the entry block (which owns the matcher
  // argument) and the now-empty block is erased.
  matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
                                            firstMatcherBlock->getOperations());
  firstMatcherBlock->erase();
}

/// Each matcher node gets its own block and its own `values` scope. Positional
/// values materialized in this node's block stay visible while its failure
/// matcher and child matchers are generated, and are dropped when the scope is
/// popped on return.
Block *PatternLowering::generateMatcher(MatcherNode &node) {
  // Push a new scope for the values used by this matcher.
  Block *block = matcherFunc.addBlock();
  ValueMapScope scope(values);

  // If this is an exit node, simply insert the corresponding interpreter
  // finalize.
  if (isa<ExitNode>(node)) {
    builder.setInsertionPointToEnd(block);
    builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
    return block;
  }

  // If this node contains a position, get the corresponding value for this
  // block.
  Position *position = node.getPosition();
  Value val = position ? getValueAt(block, position) : Value();

  // Get the next block in the match sequence. An explicit failure matcher is
  // generated before it is pushed, so it falls back to the enclosing failure
  // destination; once pushed, it becomes the default failure destination for
  // this node's children. Without an explicit failure node, the innermost
  // enclosing failure destination is reused.
  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
  Block *nextBlock;
  if (failureNode) {
    nextBlock = generateMatcher(*failureNode);
    failureBlockStack.push_back(nextBlock);
  } else {
    assert(!failureBlockStack.empty() && "expected valid failure block");
    nextBlock = failureBlockStack.back();
  }

  // If this value corresponds to an operation, record that we are going to use
  // its location as part of a fused location. This happens after the failure
  // matcher was generated and is undone before returning, so only the matches
  // recorded on this node's success paths see it.
  bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
  if (isOperationValue)
    locOps.insert(val);

  // Generate code for a boolean predicate node.
  if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
    auto *child = generateMatcher(*boolNode->getSuccessNode());
    generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
                      child, nextBlock);

    // Generate code for a switch node.
  } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
    // Collect the next blocks for all of the children and generate a switch.
    // A MapVector keeps the case order identical to the child order.
    llvm::MapVector<Qualifier *, Block *> children;
    for (auto &it : switchNode->getChildren())
      children.insert({it.first, generateMatcher(*it.second)});
    generateSwitch(block, node.getQuestion(), val, nextBlock,
                   children.takeVector());

    // Generate code for a success node.
  } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
    generateRecordMatch(block, nextBlock, successNode->getPattern());
  }

  // Restore the state that was pushed for this node.
  if (failureNode)
    failureBlockStack.pop_back();
  if (isOperationValue)
    locOps.remove(val);
  return block;
}

/// Resolve `pos` to an interpreter value. On a cache miss, the parent position
/// is resolved first (recursively), then the access operation is appended to
/// the end of `cur` and cached in the innermost scope of `values`. The
/// recursion relies on the root position having been seeded into `values`
/// (see `lower`).
Value PatternLowering::getValueAt(Block *cur, Position *pos) {
  if (Value val = values.lookup(pos))
    return val;

  // Get the value for the parent position.
  Value parentVal = getValueAt(cur, pos->getParent());

  // TODO: Use a location from the position.
  Location loc = parentVal.getLoc();
  builder.setInsertionPointToEnd(cur);
  Value value;
  switch (pos->getKind()) {
  case Predicates::OperationPos:
    value = builder.create<pdl_interp::GetDefiningOpOp>(
        loc, builder.getType<pdl::OperationType>(), parentVal);
    break;
  case Predicates::OperandPos: {
    auto *operandPos = cast<OperandPosition>(pos);
    value = builder.create<pdl_interp::GetOperandOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        operandPos->getOperandNumber());
    break;
  }
  case Predicates::AttributePos: {
    auto *attrPos = cast<AttributePosition>(pos);
    value = builder.create<pdl_interp::GetAttributeOp>(
        loc, builder.getType<pdl::AttributeType>(), parentVal,
        attrPos->getName().strref());
    break;
  }
  case Predicates::TypePos: {
    // The parent is either a value (take its type) or, otherwise, treated as
    // an attribute (take the attribute's type).
    if (parentVal.getType().isa<pdl::ValueType>())
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
    else
      value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
    break;
  }
  case Predicates::ResultPos: {
    auto *resPos = cast<ResultPosition>(pos);
    value = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        resPos->getResultNumber());
    break;
  }
  default:
    llvm_unreachable("Generating unknown Position getter");
    break;
  }
  values.insert(pos, value);
  return value;
}

/// Emit the terminator of `currentBlock` as a predicate check on `val`, which
/// branches to `trueDest` on success and to `falseDest` otherwise.
/// NOTE(review): `val.getLoc()` assumes `val` is non-null, i.e. the node
/// carried a position — confirm for every question kind.
void PatternLowering::generatePredicate(Block *currentBlock,
                                        Qualifier *question, Qualifier *answer,
                                        Value val, Block *trueDest,
                                        Block *falseDest) {
  builder.setInsertionPointToEnd(currentBlock);
  Location loc = val.getLoc();
  switch (question->getKind()) {
  case Predicates::IsNotNullQuestion:
    builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
    break;
  case Predicates::OperationNameQuestion: {
    auto *opNameAnswer = cast<OperationNameAnswer>(answer);
    builder.create<pdl_interp::CheckOperationNameOp>(
        loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
    break;
  }
  case Predicates::TypeQuestion: {
    auto *ans = cast<TypeAnswer>(answer);
    builder.create<pdl_interp::CheckTypeOp>(
        loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest);
    break;
  }
  case Predicates::AttributeQuestion: {
    auto *ans = cast<AttributeAnswer>(answer);
    builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
                                                 trueDest, falseDest);
    break;
  }
  case Predicates::OperandCountQuestion: {
    auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
    builder.create<pdl_interp::CheckOperandCountOp>(
        loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
    break;
  }
  case Predicates::ResultCountQuestion: {
    auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
    builder.create<pdl_interp::CheckResultCountOp>(
        loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
    break;
  }
  case Predicates::EqualToQuestion: {
    // The other operand is resolved through `getValueAt`, which may append
    // access operations to `currentBlock`. Because the argument is evaluated
    // before `create` runs, those accesses land before the branch terminator.
    auto *equalToQuestion = cast<EqualToQuestion>(question);
    builder.create<pdl_interp::AreEqualOp>(
        loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
        trueDest, falseDest);
    break;
  }
  case Predicates::ConstraintQuestion: {
    // The constraint's value is a tuple of (name, argument positions, constant
    // parameters). The arguments are materialized before the constraint op is
    // created.
    auto *cstQuestion = cast<ConstraintQuestion>(question);
    SmallVector<Value, 2> args;
    for (Position *position : std::get<1>(cstQuestion->getValue()))
      args.push_back(getValueAt(currentBlock, position));
    builder.create<pdl_interp::ApplyConstraintOp>(
        loc, std::get<0>(cstQuestion->getValue()), args,
        std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
        falseDest);
    break;
  }
  default:
    llvm_unreachable("Generating unknown Predicate operation");
  }
}

/// Create a switch operation of type `OpT` on `val`. Each case value is the
/// answer (cast to `PredT`) of the corresponding entry in `dests`, stored as
/// `ValT`, and is paired in order with that entry's destination block.
/// `defaultDest` is passed as the default destination.
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
                           ArrayRef<std::pair<Qualifier *, Block *>> dests) {
  std::vector<ValT> values;
  std::vector<Block *> blocks;
  values.reserve(dests.size());
  blocks.reserve(dests.size());
  for (const auto &it : dests) {
    blocks.push_back(it.second);
    values.push_back(cast<PredT>(it.first)->getValue());
  }
  builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
}

void PatternLowering::generateSwitch(
    Block *currentBlock, Qualifier *question, Value val, Block *defaultDest,
    ArrayRef<std::pair<Qualifier *, Block *>> dests) {
  builder.setInsertionPointToEnd(currentBlock);
  switch (question->getKind()) {
  case Predicates::OperandCountQuestion:
    return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, dests);
  case Predicates::ResultCountQuestion:
    return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, dests);
  case Predicates::OperationNameQuestion:
    return createSwitchOp<pdl_interp::SwitchOperationNameOp,
                          OperationNameAnswer>(val, defaultDest, builder,
                                               dests);
  case Predicates::TypeQuestion:
    return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
        val, defaultDest, builder, dests);
  case Predicates::AttributeQuestion:
    return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
        val, defaultDest, builder, dests);
  default:
    llvm_unreachable("Generating unknown switch predicate.");
  }
}

void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
                                          pdl::PatternOp pattern) {
  // Generate a rewriter for the pattern this success node represents, and track
  // any values used from the match region.
  SmallVector<Position *, 8> usedMatchValues;
  SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);

  // Process any values used in the rewrite that are defined in the match.
  // `usedMatchValues` is in the same order as the rewriter's arguments, so the
  // mapped values line up with the rewriter's signature.
  std::vector<Value> mappedMatchValues;
  mappedMatchValues.reserve(usedMatchValues.size());
  for (Position *position : usedMatchValues)
    mappedMatchValues.push_back(getValueAt(currentBlock, position));

  // Collect the set of operations generated by the rewriter.
  SmallVector<StringRef, 4> generatedOps;
  for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
    generatedOps.push_back(*op.name());
  ArrayAttr generatedOpsAttr;
  if (!generatedOps.empty())
    generatedOpsAttr = builder.getStrArrayAttr(generatedOps);

  // Grab the root kind if present.
  StringAttr rootKindAttr;
  if (Optional<StringRef> rootKind = pattern.getRootKind())
    rootKindAttr = builder.getStringAttr(*rootKind);

  // Reset the insertion point explicitly: `generateRewriter` moved it into the
  // rewriter function, and `getValueAt` only resets it when it is called.
  builder.setInsertionPointToEnd(currentBlock);
  builder.create<pdl_interp::RecordMatchOp>(
      pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
      rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
      nextBlock);
}

SymbolRefAttr PatternLowering::generateRewriter(
    pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
  // NOTE(review): every rewriter is created with the same base name; this
  // relies on `SymbolTable::insert` uniquing it. The returned reference is
  // built from the function after insertion.
  FuncOp rewriterFunc =
      FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
                     builder.getFunctionType(llvm::None, llvm::None));
  rewriterSymbolTable.insert(rewriterFunc);

  // Generate the rewriter function body.
  builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());

  // Map an input operand of the pattern to a generated interpreter value.
  // Results are cached in `rewriteValues`. Constant attributes and types are
  // re-materialized inside the rewriter. Any other value becomes a new
  // argument of the rewriter function, and its matcher position is appended
  // to `usedMatchValues`.
  DenseMap<Value, Value> rewriteValues;
  auto mapRewriteValue = [&](Value oldValue) {
    Value &newValue = rewriteValues[oldValue];
    if (newValue)
      return newValue;

    // Prefer materializing constants directly when possible.
    Operation *oldOp = oldValue.getDefiningOp();
    if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
      if (Attribute value = attrOp.valueAttr()) {
        return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                   attrOp.getLoc(), value);
      }
    } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
      if (TypeAttr type = typeOp.typeAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypeOp>(
                   typeOp.getLoc(), type);
      }
    }

    // Otherwise, add this as an input to the rewriter.
    Position *inputPos = valueToPosition.lookup(oldValue);
    assert(inputPos && "expected value to be a pattern input");
    usedMatchValues.push_back(inputPos);
    return newValue = rewriterFunc.front().addArgument(oldValue.getType());
  };

  // If this is a custom rewriter, simply dispatch to the registered rewrite
  // method.
  pdl::RewriteOp rewriter = pattern.getRewriter();
  if (StringAttr rewriteName = rewriter.nameAttr()) {
    Value root = mapRewriteValue(rewriter.root());
    SmallVector<Value, 4> args = llvm::to_vector<4>(
        llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
    builder.create<pdl_interp::ApplyRewriteOp>(
        rewriter.getLoc(), rewriteName, root, args,
        rewriter.externalConstParamsAttr());
  } else {
    // Otherwise this is a dag rewriter defined using PDL operations. Each
    // listed operation kind is lowered in program order; operations of any
    // other kind are ignored here (there is no `Default` case).
    for (Operation &rewriteOp : *rewriter.getBody()) {
      llvm::TypeSwitch<Operation *>(&rewriteOp)
          .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
                pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
            this->generateRewriter(op, rewriteValues, mapRewriteValue);
          });
    }
  }

  // Update the signature of the rewrite function. The argument list is only
  // known now, after `mapRewriteValue` has added every matcher input.
  rewriterFunc.setType(builder.getFunctionType(
      llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
      /*results=*/llvm::None));

  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());

  // Return a nested reference: @<rewriter module name>::@<rewriter function>.
  return builder.getSymbolRefAttr(
      pdl_interp::PDLInterpDialect::getRewriterModuleName(),
      builder.getSymbolRefAttr(rewriterFunc));
}

/// Materialize the attribute value held by `attrOp` inside the rewriter.
void PatternLowering::generateRewriter(
    pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
      attrOp.getLoc(), attrOp.valueAttr());
  rewriteValues[attrOp] = newAttr;
}

/// Lower `pdl.erase` to an interpreter erase of the mapped operation.
void PatternLowering::generateRewriter(
    pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
                                      mapRewriteValue(eraseOp.operation()));
}

/// Lower `pdl.operation` to an interpreter create-operation, and expose the new
/// operation's results (and any still-unresolved result types) to later
/// rewriter operations through `rewriteValues`.
void PatternLowering::generateRewriter(
    pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  SmallVector<Value, 4> operands;
  for (Value operand : operationOp.operands())
    operands.push_back(mapRewriteValue(operand));

  SmallVector<Value, 4> attributes;
  for (Value attr : operationOp.attributes())
    attributes.push_back(mapRewriteValue(attr));

  SmallVector<Value, 2> types;
  generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
                                      mapRewriteValue);

  // Create the new operation.
  Location loc = operationOp.getLoc();
  Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
      loc, *operationOp.name(), types, operands, attributes,
      operationOp.attributeNames());
  rewriteValues[operationOp.op()] = createdOp;

  // Make all of the new operation results available.
  OperandRange resultTypes = operationOp.types();
  for (auto it : llvm::enumerate(operationOp.results())) {
    Value getResultVal = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
    rewriteValues[it.value()] = getResultVal;

    // If any of the types have not been resolved, make those available as well.
    // Such types (e.g. ones that were inferred) are recovered from the created
    // result itself.
    Value &type = rewriteValues[resultTypes[it.index()]];
    if (!type)
      type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
  }
}

/// Lower `pdl.create_native` to an interpreter native call with mapped
/// arguments.
void PatternLowering::generateRewriter(
    pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  SmallVector<Value, 2> arguments;
  for (Value argument : createNativeOp.args())
    arguments.push_back(mapRewriteValue(argument));
  Value result = builder.create<pdl_interp::CreateNativeOp>(
      createNativeOp.getLoc(), createNativeOp.result().getType(),
      createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
  rewriteValues[createNativeOp] = result;
}

/// Lower `pdl.replace` to an interpreter replace, or to an erase when there are
/// no replacement values.
void PatternLowering::generateRewriter(
    pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  // If the replacement was another operation, get its results. `pdl` allows
  // for using an operation for simplicity, but the interpreter isn't as
  // user facing.
  ValueRange origOperands;
  if (Value replOp = replaceOp.replOperation())
    origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
  else
    origOperands = replaceOp.replValues();

  // If there are no replacement values, just create an erase instead.
  if (origOperands.empty()) {
    builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
                                        mapRewriteValue(replaceOp.operation()));
    return;
  }

  SmallVector<Value, 4> replOperands;
  for (Value operand : origOperands)
    replOperands.push_back(mapRewriteValue(operand));
  builder.create<pdl_interp::ReplaceOp>(
      replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
}

/// Materialize a constant `pdl.type`; non-constant types are left unresolved.
void PatternLowering::generateRewriter(
    pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
  // type.
  if (TypeAttr typeAttr = typeOp.typeAttr()) {
    Value newType =
        builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
    rewriteValues[typeOp] = newType;
  }
}

/// Append to `types` one interpreter value per result of `op`. Each result
/// type is resolved, in order of preference, from: an existing rewrite value,
/// a matcher input, interpreter type inference, or the corresponding result
/// type of an operation that `op` (or that result) replaces.
void PatternLowering::generateOperationResultTypeRewriter(
    pdl::OperationOp op, SmallVectorImpl<Value> &types,
    DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  // Functor that, given a use, returns the operation being replaced when the
  // use is a replacement value of a `pdl.replace` whose replaced operation can
  // be used to infer a type. Otherwise it returns null.
  Block *rewriterBlock = op.getOperation()->getBlock();
  auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * {
    // Check that the use corresponds to a ReplaceOp and that it is the
    // replacement value, not the operation being replaced.
    pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      return nullptr;
    // Make sure the replaced operation was defined before this one: either
    // outside the rewriter block (i.e. in the matcher), or earlier within it.
    Operation *replacedOp = replOpUser.operation().getDefiningOp();
    if (replacedOp->getBlock() != rewriterBlock ||
        replacedOp->isBeforeInBlock(op))
      return replacedOp;
    return nullptr;
  };

  // If non-None/non-Null, this is an operation that is replaced by `op`.
  // If Null, there is no full replacement operation for `op`.
  // If None, a replacement operation hasn't been searched for.
  Optional<Operation *> fullReplacedOperation;
  bool hasTypeInference = op.hasTypeInference();
  auto resultTypeValues = op.types();
  types.reserve(resultTypeValues.size());
  for (auto it : llvm::enumerate(op.results())) {
    Value result = it.value(), resultType = resultTypeValues[it.index()];

    // Check for an already translated value.
    if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
      types.push_back(existingRewriteValue);
      continue;
    }

    // Check for an input from the matcher.
    if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
      types.push_back(mapRewriteValue(resultType));
      continue;
    }

    // Check if the operation has type inference support.
    if (hasTypeInference) {
      types.push_back(builder.create<pdl_interp::InferredTypeOp>(op.getLoc()));
      continue;
    }

    // Look for an operation that was replaced by `op`. The result type will be
    // inferred from the result that was replaced. There is guaranteed to be a
    // replacement for either the op, or this specific result. Note that this is
    // guaranteed by the verifier of `pdl::OperationOp`.
    // The search over the uses of `op` is done at most once and cached.
    Operation *replacedOp = nullptr;
    if (!fullReplacedOperation.hasValue()) {
      for (OpOperand &use : op.op().getUses())
        if ((replacedOp = getReplacedOperationFrom(use)))
          break;
      fullReplacedOperation = replacedOp;
    } else {
      replacedOp = fullReplacedOperation.getValue();
    }
    // Infer from the result, as there was no fully replaced op.
    if (!replacedOp) {
      for (OpOperand &use : result.getUses())
        if ((replacedOp = getReplacedOperationFrom(use)))
          break;
      assert(replacedOp && "expected replaced op to infer a result type from");
    }

    auto replOpOp = cast<pdl::OperationOp>(replacedOp);
    types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));
  }
}

//===----------------------------------------------------------------------===//
// Conversion Pass
//===----------------------------------------------------------------------===//

namespace {
struct PDLToPDLInterpPass
    : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
  void runOnOperation() final;
};
} // namespace

/// Convert the given module containing PDL pattern operations into PDL
/// Interpreter operations.
void PDLToPDLInterpPass::runOnOperation() {
  ModuleOp module = getOperation();

  // Create the main matcher function. This function contains all of the match
  // related functionality from patterns in the module.
  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
  FuncOp matcherFunc = builder.create<FuncOp>(
      module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
      builder.getFunctionType(builder.getType<pdl::OperationType>(),
                              /*results=*/llvm::None),
      /*attrs=*/llvm::None);

  // Create a nested module to hold the functions invoked for rewriting the IR
  // after a successful match.
  ModuleOp rewriterModule = builder.create<ModuleOp>(
      module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());

  // Generate the code for the patterns within the module.
  PatternLowering generator(matcherFunc, rewriterModule);
  generator.lower(module);

  // After generation, delete all of the pattern operations. This must happen
  // after `lower`, which reads the pattern bodies. The early-increment range
  // allows erasing the current pattern while iterating.
  for (pdl::PatternOp pattern :
       llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
    pattern.erase();
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
  return std::make_unique<PDLToPDLInterpPass>();
}