1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 using namespace mlir; 20 using namespace test; 21 22 // Native function for testing NativeCodeCall 23 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24 return choice.getValue() ? input1 : input2; 25 } 26 27 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 28 rewriter.create<OpI>(loc, input); 29 } 30 31 static void handleNoResultOp(PatternRewriter &rewriter, 32 OpSymbolBindingNoResult op) { 33 // Turn the no result op to a one-result op. 34 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35 op.operand()); 36 } 37 38 static bool getFirstI32Result(Operation *op, Value &value) { 39 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 40 return false; 41 value = op->getResult(0); 42 return true; 43 } 44 45 static Value bindNativeCodeCallResult(Value value) { return value; } 46 47 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 48 Value input2) { 49 return SmallVector<Value, 2>({input2, input1}); 50 } 51 52 // Test that natives calls are only called once during rewrites. 53 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 54 // This let us check the number of times OpM_Test was called by inspecting 55 // the returned value in the MLIR output. 56 static int64_t opMIncreasingValue = 314159265; 57 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 58 int64_t i = opMIncreasingValue++; 59 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 60 } 61 62 namespace { 63 #include "TestPatterns.inc" 64 } // end anonymous namespace 65 66 //===----------------------------------------------------------------------===// 67 // Test Reduce Pattern Interface 68 //===----------------------------------------------------------------------===// 69 70 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 71 populateWithGenerated(patterns); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Canonicalizer Driver. 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 struct FoldingPattern : public RewritePattern { 80 public: 81 FoldingPattern(MLIRContext *context) 82 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 83 /*benefit=*/1, context) {} 84 85 LogicalResult matchAndRewrite(Operation *op, 86 PatternRewriter &rewriter) const override { 87 // Exercise OperationFolder API for a single-result operation that is folded 88 // upon construction. The operation being created through the folder has an 89 // in-place folder, and it should be still present in the output. 90 // Furthermore, the folder should not crash when attempting to recover the 91 // (unchanged) operation result. 92 OperationFolder folder(op->getContext()); 93 Value result = folder.create<TestOpInPlaceFold>( 94 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 95 rewriter.getI32IntegerAttr(0)); 96 assert(result); 97 rewriter.replaceOp(op, result); 98 return success(); 99 } 100 }; 101 102 /// This pattern creates a foldable operation at the entry point of the block. 103 /// This tests the situation where the operation folder will need to replace an 104 /// operation with a previously created constant that does not initially 105 /// dominate the operation to replace. 106 struct FolderInsertBeforePreviouslyFoldedConstantPattern 107 : public OpRewritePattern<TestCastOp> { 108 public: 109 using OpRewritePattern<TestCastOp>::OpRewritePattern; 110 111 LogicalResult matchAndRewrite(TestCastOp op, 112 PatternRewriter &rewriter) const override { 113 if (!op->hasAttr("test_fold_before_previously_folded_op")) 114 return failure(); 115 rewriter.setInsertionPointToStart(op->getBlock()); 116 117 auto constOp = 118 rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true)); 119 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 120 Value(constOp)); 121 return success(); 122 } 123 }; 124 125 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 126 StringRef getArgument() const final { return "test-patterns"; } 127 StringRef getDescription() const final { return "Run test dialect patterns"; } 128 void runOnFunction() override { 129 mlir::RewritePatternSet patterns(&getContext()); 130 populateWithGenerated(patterns); 131 132 // Verify named pattern is generated with expected name. 133 patterns.add<FoldingPattern, TestNamedPatternRule, 134 FolderInsertBeforePreviouslyFoldedConstantPattern>( 135 &getContext()); 136 137 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 138 } 139 }; 140 } // end anonymous namespace 141 142 //===----------------------------------------------------------------------===// 143 // ReturnType Driver. 144 //===----------------------------------------------------------------------===// 145 146 namespace { 147 // Generate ops for each instance where the type can be successfully inferred. 148 template <typename OpTy> 149 static void invokeCreateWithInferredReturnType(Operation *op) { 150 auto *context = op->getContext(); 151 auto fop = op->getParentOfType<FuncOp>(); 152 auto location = UnknownLoc::get(context); 153 OpBuilder b(op); 154 b.setInsertionPointAfter(op); 155 156 // Use permutations of 2 args as operands. 157 assert(fop.getNumArguments() >= 2); 158 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 159 for (int j = 0; j < e; ++j) { 160 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 161 SmallVector<Type, 2> inferredReturnTypes; 162 if (succeeded(OpTy::inferReturnTypes( 163 context, llvm::None, values, op->getAttrDictionary(), 164 op->getRegions(), inferredReturnTypes))) { 165 OperationState state(location, OpTy::getOperationName()); 166 // TODO: Expand to regions. 167 OpTy::build(b, state, values, op->getAttrs()); 168 (void)b.createOperation(state); 169 } 170 } 171 } 172 } 173 174 static void reifyReturnShape(Operation *op) { 175 OpBuilder b(op); 176 177 // Use permutations of 2 args as operands. 178 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 179 SmallVector<Value, 2> shapes; 180 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 181 !llvm::hasSingleElement(shapes)) 182 return; 183 for (auto it : llvm::enumerate(shapes)) { 184 op->emitRemark() << "value " << it.index() << ": " 185 << it.value().getDefiningOp(); 186 } 187 } 188 189 struct TestReturnTypeDriver 190 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 191 void getDependentDialects(DialectRegistry ®istry) const override { 192 registry.insert<tensor::TensorDialect>(); 193 } 194 StringRef getArgument() const final { return "test-return-type"; } 195 StringRef getDescription() const final { return "Run return type functions"; } 196 197 void runOnFunction() override { 198 if (getFunction().getName() == "testCreateFunctions") { 199 std::vector<Operation *> ops; 200 // Collect ops to avoid triggering on inserted ops. 201 for (auto &op : getFunction().getBody().front()) 202 ops.push_back(&op); 203 // Generate test patterns for each, but skip terminator. 204 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 205 // Test create method of each of the Op classes below. The resultant 206 // output would be in reverse order underneath `op` from which 207 // the attributes and regions are used. 208 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 209 invokeCreateWithInferredReturnType< 210 OpWithShapedTypeInferTypeInterfaceOp>(op); 211 }; 212 return; 213 } 214 if (getFunction().getName() == "testReifyFunctions") { 215 std::vector<Operation *> ops; 216 // Collect ops to avoid triggering on inserted ops. 217 for (auto &op : getFunction().getBody().front()) 218 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 219 ops.push_back(&op); 220 // Generate test patterns for each, but skip terminator. 221 for (auto *op : ops) 222 reifyReturnShape(op); 223 } 224 } 225 }; 226 } // end anonymous namespace 227 228 namespace { 229 struct TestDerivedAttributeDriver 230 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 231 StringRef getArgument() const final { return "test-derived-attr"; } 232 StringRef getDescription() const final { 233 return "Run test derived attributes"; 234 } 235 void runOnFunction() override; 236 }; 237 } // end anonymous namespace 238 239 void TestDerivedAttributeDriver::runOnFunction() { 240 getFunction().walk([](DerivedAttributeOpInterface dOp) { 241 auto dAttr = dOp.materializeDerivedAttributes(); 242 if (!dAttr) 243 return; 244 for (auto d : dAttr) 245 dOp.emitRemark() << d.first << " = " << d.second; 246 }); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // Legalization Driver. 251 //===----------------------------------------------------------------------===// 252 253 namespace { 254 //===----------------------------------------------------------------------===// 255 // Region-Block Rewrite Testing 256 257 /// This pattern is a simple pattern that inlines the first region of a given 258 /// operation into the parent region. 259 struct TestRegionRewriteBlockMovement : public ConversionPattern { 260 TestRegionRewriteBlockMovement(MLIRContext *ctx) 261 : ConversionPattern("test.region", 1, ctx) {} 262 263 LogicalResult 264 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 265 ConversionPatternRewriter &rewriter) const final { 266 // Inline this region into the parent region. 267 auto &parentRegion = *op->getParentRegion(); 268 auto &opRegion = op->getRegion(0); 269 if (op->getAttr("legalizer.should_clone")) 270 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 271 else 272 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 273 274 if (op->getAttr("legalizer.erase_old_blocks")) { 275 while (!opRegion.empty()) 276 rewriter.eraseBlock(&opRegion.front()); 277 } 278 279 // Drop this operation. 280 rewriter.eraseOp(op); 281 return success(); 282 } 283 }; 284 /// This pattern is a simple pattern that generates a region containing an 285 /// illegal operation. 286 struct TestRegionRewriteUndo : public RewritePattern { 287 TestRegionRewriteUndo(MLIRContext *ctx) 288 : RewritePattern("test.region_builder", 1, ctx) {} 289 290 LogicalResult matchAndRewrite(Operation *op, 291 PatternRewriter &rewriter) const final { 292 // Create the region operation with an entry block containing arguments. 293 OperationState newRegion(op->getLoc(), "test.region"); 294 newRegion.addRegion(); 295 auto *regionOp = rewriter.createOperation(newRegion); 296 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 297 entryBlock->addArgument(rewriter.getIntegerType(64)); 298 299 // Add an explicitly illegal operation to ensure the conversion fails. 300 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 301 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 302 303 // Drop this operation. 304 rewriter.eraseOp(op); 305 return success(); 306 } 307 }; 308 /// A simple pattern that creates a block at the end of the parent region of the 309 /// matched operation. 310 struct TestCreateBlock : public RewritePattern { 311 TestCreateBlock(MLIRContext *ctx) 312 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 313 314 LogicalResult matchAndRewrite(Operation *op, 315 PatternRewriter &rewriter) const final { 316 Region ®ion = *op->getParentRegion(); 317 Type i32Type = rewriter.getIntegerType(32); 318 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 319 rewriter.create<TerminatorOp>(op->getLoc()); 320 rewriter.replaceOp(op, {}); 321 return success(); 322 } 323 }; 324 325 /// A simple pattern that creates a block containing an invalid operation in 326 /// order to trigger the block creation undo mechanism. 327 struct TestCreateIllegalBlock : public RewritePattern { 328 TestCreateIllegalBlock(MLIRContext *ctx) 329 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 330 331 LogicalResult matchAndRewrite(Operation *op, 332 PatternRewriter &rewriter) const final { 333 Region ®ion = *op->getParentRegion(); 334 Type i32Type = rewriter.getIntegerType(32); 335 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 336 // Create an illegal op to ensure the conversion fails. 337 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 338 rewriter.create<TerminatorOp>(op->getLoc()); 339 rewriter.replaceOp(op, {}); 340 return success(); 341 } 342 }; 343 344 /// A simple pattern that tests the undo mechanism when replacing the uses of a 345 /// block argument. 346 struct TestUndoBlockArgReplace : public ConversionPattern { 347 TestUndoBlockArgReplace(MLIRContext *ctx) 348 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 349 350 LogicalResult 351 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 352 ConversionPatternRewriter &rewriter) const final { 353 auto illegalOp = 354 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 355 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 356 illegalOp); 357 rewriter.updateRootInPlace(op, [] {}); 358 return success(); 359 } 360 }; 361 362 /// A rewrite pattern that tests the undo mechanism when erasing a block. 363 struct TestUndoBlockErase : public ConversionPattern { 364 TestUndoBlockErase(MLIRContext *ctx) 365 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 366 367 LogicalResult 368 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 369 ConversionPatternRewriter &rewriter) const final { 370 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 371 rewriter.setInsertionPointToStart(secondBlock); 372 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 373 rewriter.eraseBlock(secondBlock); 374 rewriter.updateRootInPlace(op, [] {}); 375 return success(); 376 } 377 }; 378 379 //===----------------------------------------------------------------------===// 380 // Type-Conversion Rewrite Testing 381 382 /// This patterns erases a region operation that has had a type conversion. 383 struct TestDropOpSignatureConversion : public ConversionPattern { 384 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 385 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 386 LogicalResult 387 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 388 ConversionPatternRewriter &rewriter) const override { 389 Region ®ion = op->getRegion(0); 390 Block *entry = ®ion.front(); 391 392 // Convert the original entry arguments. 393 TypeConverter &converter = *getTypeConverter(); 394 TypeConverter::SignatureConversion result(entry->getNumArguments()); 395 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 396 result)) || 397 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 398 return failure(); 399 400 // Convert the region signature and just drop the operation. 401 rewriter.eraseOp(op); 402 return success(); 403 } 404 }; 405 /// This pattern simply updates the operands of the given operation. 406 struct TestPassthroughInvalidOp : public ConversionPattern { 407 TestPassthroughInvalidOp(MLIRContext *ctx) 408 : ConversionPattern("test.invalid", 1, ctx) {} 409 LogicalResult 410 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 411 ConversionPatternRewriter &rewriter) const final { 412 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 413 llvm::None); 414 return success(); 415 } 416 }; 417 /// This pattern handles the case of a split return value. 418 struct TestSplitReturnType : public ConversionPattern { 419 TestSplitReturnType(MLIRContext *ctx) 420 : ConversionPattern("test.return", 1, ctx) {} 421 LogicalResult 422 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423 ConversionPatternRewriter &rewriter) const final { 424 // Check for a return of F32. 425 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 426 return failure(); 427 428 // Check if the first operation is a cast operation, if it is we use the 429 // results directly. 430 auto *defOp = operands[0].getDefiningOp(); 431 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 432 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 433 return success(); 434 } 435 436 // Otherwise, fail to match. 437 return failure(); 438 } 439 }; 440 441 //===----------------------------------------------------------------------===// 442 // Multi-Level Type-Conversion Rewrite Testing 443 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 444 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 445 : ConversionPattern("test.type_producer", 1, ctx) {} 446 LogicalResult 447 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 448 ConversionPatternRewriter &rewriter) const final { 449 // If the type is I32, change the type to F32. 450 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 451 return failure(); 452 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 453 return success(); 454 } 455 }; 456 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 457 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 458 : ConversionPattern("test.type_producer", 1, ctx) {} 459 LogicalResult 460 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 461 ConversionPatternRewriter &rewriter) const final { 462 // If the type is F32, change the type to F64. 463 if (!Type(*op->result_type_begin()).isF32()) 464 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 465 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 466 return success(); 467 } 468 }; 469 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 470 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 471 : ConversionPattern("test.type_producer", 10, ctx) {} 472 LogicalResult 473 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 474 ConversionPatternRewriter &rewriter) const final { 475 // Always convert to B16, even though it is not a legal type. This tests 476 // that values are unmapped correctly. 477 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 478 return success(); 479 } 480 }; 481 struct TestUpdateConsumerType : public ConversionPattern { 482 TestUpdateConsumerType(MLIRContext *ctx) 483 : ConversionPattern("test.type_consumer", 1, ctx) {} 484 LogicalResult 485 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 486 ConversionPatternRewriter &rewriter) const final { 487 // Verify that the incoming operand has been successfully remapped to F64. 488 if (!operands[0].getType().isF64()) 489 return failure(); 490 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 491 return success(); 492 } 493 }; 494 495 //===----------------------------------------------------------------------===// 496 // Non-Root Replacement Rewrite Testing 497 /// This pattern generates an invalid operation, but replaces it before the 498 /// pattern is finished. This checks that we don't need to legalize the 499 /// temporary op. 500 struct TestNonRootReplacement : public RewritePattern { 501 TestNonRootReplacement(MLIRContext *ctx) 502 : RewritePattern("test.replace_non_root", 1, ctx) {} 503 504 LogicalResult matchAndRewrite(Operation *op, 505 PatternRewriter &rewriter) const final { 506 auto resultType = *op->result_type_begin(); 507 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 508 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 509 510 rewriter.replaceOp(illegalOp, {legalOp}); 511 rewriter.replaceOp(op, {illegalOp}); 512 return success(); 513 } 514 }; 515 516 //===----------------------------------------------------------------------===// 517 // Recursive Rewrite Testing 518 /// This pattern is applied to the same operation multiple times, but has a 519 /// bounded recursion. 520 struct TestBoundedRecursiveRewrite 521 : public OpRewritePattern<TestRecursiveRewriteOp> { 522 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 523 524 void initialize() { 525 // The conversion target handles bounding the recursion of this pattern. 526 setHasBoundedRewriteRecursion(); 527 } 528 529 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 530 PatternRewriter &rewriter) const final { 531 // Decrement the depth of the op in-place. 532 rewriter.updateRootInPlace(op, [&] { 533 op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 534 }); 535 return success(); 536 } 537 }; 538 539 struct TestNestedOpCreationUndoRewrite 540 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 541 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 542 543 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 544 PatternRewriter &rewriter) const final { 545 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 546 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 547 return success(); 548 }; 549 }; 550 551 // This pattern matches `test.blackhole` and delete this op and its producer. 552 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 553 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 554 555 LogicalResult matchAndRewrite(BlackHoleOp op, 556 PatternRewriter &rewriter) const final { 557 Operation *producer = op.getOperand().getDefiningOp(); 558 // Always erase the user before the producer, the framework should handle 559 // this correctly. 560 rewriter.eraseOp(op); 561 rewriter.eraseOp(producer); 562 return success(); 563 }; 564 }; 565 566 // This pattern replaces explicitly illegal op with explicitly legal op, 567 // but in addition creates unregistered operation. 568 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 569 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 570 571 LogicalResult matchAndRewrite(ILLegalOpG op, 572 PatternRewriter &rewriter) const final { 573 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 574 Value val = rewriter.create<ConstantOp>(op->getLoc(), attr); 575 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 576 return success(); 577 }; 578 }; 579 } // namespace 580 581 namespace { 582 struct TestTypeConverter : public TypeConverter { 583 using TypeConverter::TypeConverter; 584 TestTypeConverter() { 585 addConversion(convertType); 586 addArgumentMaterialization(materializeCast); 587 addSourceMaterialization(materializeCast); 588 589 /// Materialize the cast for one-to-one conversion from i64 to f64. 590 const auto materializeOneToOneCast = 591 [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 592 Location loc) -> Optional<Value> { 593 if (resultType.getWidth() == 42 && inputs.size() == 1) 594 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 595 return llvm::None; 596 }; 597 addArgumentMaterialization(materializeOneToOneCast); 598 } 599 600 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 601 // Drop I16 types. 602 if (t.isSignlessInteger(16)) 603 return success(); 604 605 // Convert I64 to F64. 606 if (t.isSignlessInteger(64)) { 607 results.push_back(FloatType::getF64(t.getContext())); 608 return success(); 609 } 610 611 // Convert I42 to I43. 612 if (t.isInteger(42)) { 613 results.push_back(IntegerType::get(t.getContext(), 43)); 614 return success(); 615 } 616 617 // Split F32 into F16,F16. 618 if (t.isF32()) { 619 results.assign(2, FloatType::getF16(t.getContext())); 620 return success(); 621 } 622 623 // Otherwise, convert the type directly. 624 results.push_back(t); 625 return success(); 626 } 627 628 /// Hook for materializing a conversion. This is necessary because we generate 629 /// 1->N type mappings. 630 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 631 ValueRange inputs, Location loc) { 632 if (inputs.size() == 1) 633 return inputs[0]; 634 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 635 } 636 }; 637 638 struct TestLegalizePatternDriver 639 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 640 StringRef getArgument() const final { return "test-legalize-patterns"; } 641 StringRef getDescription() const final { 642 return "Run test dialect legalization patterns"; 643 } 644 /// The mode of conversion to use with the driver. 645 enum class ConversionMode { Analysis, Full, Partial }; 646 647 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 648 649 void getDependentDialects(DialectRegistry ®istry) const override { 650 registry.insert<StandardOpsDialect>(); 651 } 652 653 void runOnOperation() override { 654 TestTypeConverter converter; 655 mlir::RewritePatternSet patterns(&getContext()); 656 populateWithGenerated(patterns); 657 patterns 658 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 659 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 660 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 661 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 662 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 663 TestNonRootReplacement, TestBoundedRecursiveRewrite, 664 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 665 TestCreateUnregisteredOp>(&getContext()); 666 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 667 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 668 mlir::populateCallOpTypeConversionPattern(patterns, converter); 669 670 // Define the conversion target used for the test. 671 ConversionTarget target(getContext()); 672 target.addLegalOp<ModuleOp>(); 673 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 674 TerminatorOp>(); 675 target 676 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 677 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 678 // Don't allow F32 operands. 679 return llvm::none_of(op.getOperandTypes(), 680 [](Type type) { return type.isF32(); }); 681 }); 682 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 683 return converter.isSignatureLegal(op.getType()) && 684 converter.isLegal(&op.getBody()); 685 }); 686 687 // TestCreateUnregisteredOp creates `std.constant` operation, 688 // which was not added to target intentionally to test 689 // correct error code from conversion driver. 690 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 691 692 // Expect the type_producer/type_consumer operations to only operate on f64. 693 target.addDynamicallyLegalOp<TestTypeProducerOp>( 694 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 695 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 696 return op.getOperand().getType().isF64(); 697 }); 698 699 // Check support for marking certain operations as recursively legal. 700 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 701 return static_cast<bool>( 702 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 703 }); 704 705 // Mark the bound recursion operation as dynamically legal. 706 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 707 [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 708 709 // Handle a partial conversion. 710 if (mode == ConversionMode::Partial) { 711 DenseSet<Operation *> unlegalizedOps; 712 if (failed(applyPartialConversion( 713 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 714 getOperation()->emitRemark() << "applyPartialConversion failed"; 715 } 716 // Emit remarks for each legalizable operation. 717 for (auto *op : unlegalizedOps) 718 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 719 return; 720 } 721 722 // Handle a full conversion. 723 if (mode == ConversionMode::Full) { 724 // Check support for marking unknown operations as dynamically legal. 725 target.markUnknownOpDynamicallyLegal([](Operation *op) { 726 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 727 }); 728 729 if (failed(applyFullConversion(getOperation(), target, 730 std::move(patterns)))) { 731 getOperation()->emitRemark() << "applyFullConversion failed"; 732 } 733 return; 734 } 735 736 // Otherwise, handle an analysis conversion. 737 assert(mode == ConversionMode::Analysis); 738 739 // Analyze the convertible operations. 740 DenseSet<Operation *> legalizedOps; 741 if (failed(applyAnalysisConversion(getOperation(), target, 742 std::move(patterns), legalizedOps))) 743 return signalPassFailure(); 744 745 // Emit remarks for each legalizable operation. 746 for (auto *op : legalizedOps) 747 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 748 } 749 750 /// The mode of conversion to use. 751 ConversionMode mode; 752 }; 753 } // end anonymous namespace 754 755 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 756 legalizerConversionMode( 757 "test-legalize-mode", 758 llvm::cl::desc("The legalization mode to use with the test driver"), 759 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 760 llvm::cl::values( 761 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 762 "analysis", "Perform an analysis conversion"), 763 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 764 "Perform a full conversion"), 765 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 766 "partial", "Perform a partial conversion"))); 767 768 //===----------------------------------------------------------------------===// 769 // ConversionPatternRewriter::getRemappedValue testing. This method is used 770 // to get the remapped value of an original value that was replaced using 771 // ConversionPatternRewriter. 772 namespace { 773 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 774 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 775 /// operand twice. 776 /// 777 /// Example: 778 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 779 /// is replaced with: 780 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 781 struct OneVResOneVOperandOp1Converter 782 : public OpConversionPattern<OneVResOneVOperandOp1> { 783 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 784 785 LogicalResult 786 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 787 ConversionPatternRewriter &rewriter) const override { 788 auto origOps = op.getOperands(); 789 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 790 "One operand expected"); 791 Value origOp = *origOps.begin(); 792 SmallVector<Value, 2> remappedOperands; 793 // Replicate the remapped original operand twice. Note that we don't used 794 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 795 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 796 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 797 798 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 799 remappedOperands); 800 return success(); 801 } 802 }; 803 804 struct TestRemappedValue 805 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 806 StringRef getArgument() const final { return "test-remapped-value"; } 807 StringRef getDescription() const final { 808 return "Test public remapped value mechanism in ConversionPatternRewriter"; 809 } 810 void runOnFunction() override { 811 mlir::RewritePatternSet patterns(&getContext()); 812 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 813 814 mlir::ConversionTarget target(getContext()); 815 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 816 // We make OneVResOneVOperandOp1 legal only when it has more that one 817 // operand. This will trigger the conversion that will replace one-operand 818 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 819 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 820 [](Operation *op) -> bool { 821 return std::distance(op->operand_begin(), op->operand_end()) > 1; 822 }); 823 824 if (failed(mlir::applyFullConversion(getFunction(), target, 825 std::move(patterns)))) { 826 signalPassFailure(); 827 } 828 } 829 }; 830 } // end anonymous namespace 831 832 //===----------------------------------------------------------------------===// 833 // Test patterns without a specific root operation kind 834 //===----------------------------------------------------------------------===// 835 836 namespace { 837 /// This pattern matches and removes any operation in the test dialect. 838 struct RemoveTestDialectOps : public RewritePattern { 839 RemoveTestDialectOps(MLIRContext *context) 840 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 841 842 LogicalResult matchAndRewrite(Operation *op, 843 PatternRewriter &rewriter) const override { 844 if (!isa<TestDialect>(op->getDialect())) 845 return failure(); 846 rewriter.eraseOp(op); 847 return success(); 848 } 849 }; 850 851 struct TestUnknownRootOpDriver 852 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 853 StringRef getArgument() const final { 854 return "test-legalize-unknown-root-patterns"; 855 } 856 StringRef getDescription() const final { 857 return "Test public remapped value mechanism in ConversionPatternRewriter"; 858 } 859 void runOnFunction() override { 860 mlir::RewritePatternSet patterns(&getContext()); 861 patterns.add<RemoveTestDialectOps>(&getContext()); 862 863 mlir::ConversionTarget target(getContext()); 864 target.addIllegalDialect<TestDialect>(); 865 if (failed( 866 applyPartialConversion(getFunction(), target, std::move(patterns)))) 867 signalPassFailure(); 868 } 869 }; 870 } // end anonymous namespace 871 872 //===----------------------------------------------------------------------===// 873 // Test type conversions 874 //===----------------------------------------------------------------------===// 875 876 namespace { 877 struct TestTypeConversionProducer 878 : public OpConversionPattern<TestTypeProducerOp> { 879 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 880 LogicalResult 881 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 882 ConversionPatternRewriter &rewriter) const final { 883 Type resultType = op.getType(); 884 if (resultType.isa<FloatType>()) 885 resultType = rewriter.getF64Type(); 886 else if (resultType.isInteger(16)) 887 resultType = rewriter.getIntegerType(64); 888 else 889 return failure(); 890 891 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 892 return success(); 893 } 894 }; 895 896 /// Call signature conversion and then fail the rewrite to trigger the undo 897 /// mechanism. 898 struct TestSignatureConversionUndo 899 : public OpConversionPattern<TestSignatureConversionUndoOp> { 900 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 901 902 LogicalResult 903 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 904 ConversionPatternRewriter &rewriter) const final { 905 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 906 return failure(); 907 } 908 }; 909 910 /// Just forward the operands to the root op. This is essentially a no-op 911 /// pattern that is used to trigger target materialization. 912 struct TestTypeConsumerForward 913 : public OpConversionPattern<TestTypeConsumerOp> { 914 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 915 916 LogicalResult 917 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 918 ConversionPatternRewriter &rewriter) const final { 919 rewriter.updateRootInPlace(op, 920 [&] { op->setOperands(adaptor.getOperands()); }); 921 return success(); 922 } 923 }; 924 925 struct TestTypeConversionAnotherProducer 926 : public OpRewritePattern<TestAnotherTypeProducerOp> { 927 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 928 929 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 930 PatternRewriter &rewriter) const final { 931 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 932 return success(); 933 } 934 }; 935 936 struct TestTypeConversionDriver 937 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 938 void getDependentDialects(DialectRegistry ®istry) const override { 939 registry.insert<TestDialect>(); 940 } 941 StringRef getArgument() const final { 942 return "test-legalize-type-conversion"; 943 } 944 StringRef getDescription() const final { 945 return "Test various type conversion functionalities in DialectConversion"; 946 } 947 948 void runOnOperation() override { 949 // Initialize the type converter. 950 TypeConverter converter; 951 952 /// Add the legal set of type conversions. 953 converter.addConversion([](Type type) -> Type { 954 // Treat F64 as legal. 955 if (type.isF64()) 956 return type; 957 // Allow converting BF16/F16/F32 to F64. 958 if (type.isBF16() || type.isF16() || type.isF32()) 959 return FloatType::getF64(type.getContext()); 960 // Otherwise, the type is illegal. 961 return nullptr; 962 }); 963 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 964 // Drop all integer types. 965 return success(); 966 }); 967 968 /// Add the legal set of type materializations. 969 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 970 ValueRange inputs, 971 Location loc) -> Value { 972 // Allow casting from F64 back to F32. 973 if (!resultType.isF16() && inputs.size() == 1 && 974 inputs[0].getType().isF64()) 975 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 976 // Allow producing an i32 or i64 from nothing. 977 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 978 inputs.empty()) 979 return builder.create<TestTypeProducerOp>(loc, resultType); 980 // Allow producing an i64 from an integer. 981 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 982 inputs[0].getType().isa<IntegerType>()) 983 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 984 // Otherwise, fail. 985 return nullptr; 986 }); 987 988 // Initialize the conversion target. 989 mlir::ConversionTarget target(getContext()); 990 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 991 return op.getType().isF64() || op.getType().isInteger(64); 992 }); 993 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 994 return converter.isSignatureLegal(op.getType()) && 995 converter.isLegal(&op.getBody()); 996 }); 997 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 998 // Allow casts from F64 to F32. 999 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1000 }); 1001 1002 // Initialize the set of rewrite patterns. 1003 RewritePatternSet patterns(&getContext()); 1004 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1005 TestSignatureConversionUndo>(converter, &getContext()); 1006 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1007 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 1008 1009 if (failed(applyPartialConversion(getOperation(), target, 1010 std::move(patterns)))) 1011 signalPassFailure(); 1012 } 1013 }; 1014 } // end anonymous namespace 1015 1016 //===----------------------------------------------------------------------===// 1017 // Test Block Merging 1018 //===----------------------------------------------------------------------===// 1019 1020 namespace { 1021 /// A rewriter pattern that tests that blocks can be merged. 1022 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1023 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1024 1025 LogicalResult 1026 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1027 ConversionPatternRewriter &rewriter) const final { 1028 Block &firstBlock = op.body().front(); 1029 Operation *branchOp = firstBlock.getTerminator(); 1030 Block *secondBlock = &*(std::next(op.body().begin())); 1031 auto succOperands = branchOp->getOperands(); 1032 SmallVector<Value, 2> replacements(succOperands); 1033 rewriter.eraseOp(branchOp); 1034 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1035 rewriter.updateRootInPlace(op, [] {}); 1036 return success(); 1037 } 1038 }; 1039 1040 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1041 struct TestUndoBlocksMerge : public ConversionPattern { 1042 TestUndoBlocksMerge(MLIRContext *ctx) 1043 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1044 LogicalResult 1045 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1046 ConversionPatternRewriter &rewriter) const final { 1047 Block &firstBlock = op->getRegion(0).front(); 1048 Operation *branchOp = firstBlock.getTerminator(); 1049 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1050 rewriter.setInsertionPointToStart(secondBlock); 1051 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1052 auto succOperands = branchOp->getOperands(); 1053 SmallVector<Value, 2> replacements(succOperands); 1054 rewriter.eraseOp(branchOp); 1055 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1056 rewriter.updateRootInPlace(op, [] {}); 1057 return success(); 1058 } 1059 }; 1060 1061 /// A rewrite mechanism to inline the body of the op into its parent, when both 1062 /// ops can have a single block. 1063 struct TestMergeSingleBlockOps 1064 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1065 using OpConversionPattern< 1066 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1067 1068 LogicalResult 1069 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1070 ConversionPatternRewriter &rewriter) const final { 1071 SingleBlockImplicitTerminatorOp parentOp = 1072 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1073 if (!parentOp) 1074 return failure(); 1075 Block &innerBlock = op.region().front(); 1076 TerminatorOp innerTerminator = 1077 cast<TerminatorOp>(innerBlock.getTerminator()); 1078 rewriter.mergeBlockBefore(&innerBlock, op); 1079 rewriter.eraseOp(innerTerminator); 1080 rewriter.eraseOp(op); 1081 rewriter.updateRootInPlace(op, [] {}); 1082 return success(); 1083 } 1084 }; 1085 1086 struct TestMergeBlocksPatternDriver 1087 : public PassWrapper<TestMergeBlocksPatternDriver, 1088 OperationPass<ModuleOp>> { 1089 StringRef getArgument() const final { return "test-merge-blocks"; } 1090 StringRef getDescription() const final { 1091 return "Test Merging operation in ConversionPatternRewriter"; 1092 } 1093 void runOnOperation() override { 1094 MLIRContext *context = &getContext(); 1095 mlir::RewritePatternSet patterns(context); 1096 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1097 context); 1098 ConversionTarget target(*context); 1099 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1100 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1101 target.addIllegalOp<ILLegalOpF>(); 1102 1103 /// Expect the op to have a single block after legalization. 1104 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1105 [&](TestMergeBlocksOp op) -> bool { 1106 return llvm::hasSingleElement(op.body()); 1107 }); 1108 1109 /// Only allow `test.br` within test.merge_blocks op. 1110 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1111 return op->getParentOfType<TestMergeBlocksOp>(); 1112 }); 1113 1114 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1115 /// inlined. 1116 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1117 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1118 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1119 }); 1120 1121 DenseSet<Operation *> unlegalizedOps; 1122 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1123 &unlegalizedOps); 1124 for (auto *op : unlegalizedOps) 1125 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1126 } 1127 }; 1128 } // namespace 1129 1130 //===----------------------------------------------------------------------===// 1131 // Test Selective Replacement 1132 //===----------------------------------------------------------------------===// 1133 1134 namespace { 1135 /// A rewrite mechanism to inline the body of the op into its parent, when both 1136 /// ops can have a single block. 1137 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1138 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1139 1140 LogicalResult matchAndRewrite(TestCastOp op, 1141 PatternRewriter &rewriter) const final { 1142 if (op.getNumOperands() != 2) 1143 return failure(); 1144 OperandRange operands = op.getOperands(); 1145 1146 // Replace non-terminator uses with the first operand. 1147 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1148 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1149 }); 1150 // Replace everything else with the second operand if the operation isn't 1151 // dead. 1152 rewriter.replaceOp(op, op.getOperand(1)); 1153 return success(); 1154 } 1155 }; 1156 1157 struct TestSelectiveReplacementPatternDriver 1158 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1159 OperationPass<>> { 1160 StringRef getArgument() const final { 1161 return "test-pattern-selective-replacement"; 1162 } 1163 StringRef getDescription() const final { 1164 return "Test selective replacement in the PatternRewriter"; 1165 } 1166 void runOnOperation() override { 1167 MLIRContext *context = &getContext(); 1168 mlir::RewritePatternSet patterns(context); 1169 patterns.add<TestSelectiveOpReplacementPattern>(context); 1170 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1171 std::move(patterns)); 1172 } 1173 }; 1174 } // namespace 1175 1176 //===----------------------------------------------------------------------===// 1177 // PassRegistration 1178 //===----------------------------------------------------------------------===// 1179 1180 namespace mlir { 1181 namespace test { 1182 void registerPatternsTestPass() { 1183 PassRegistration<TestReturnTypeDriver>(); 1184 1185 PassRegistration<TestDerivedAttributeDriver>(); 1186 1187 PassRegistration<TestPatternDriver>(); 1188 1189 PassRegistration<TestLegalizePatternDriver>([] { 1190 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1191 }); 1192 1193 PassRegistration<TestRemappedValue>(); 1194 1195 PassRegistration<TestUnknownRootOpDriver>(); 1196 1197 PassRegistration<TestTypeConversionDriver>(); 1198 1199 PassRegistration<TestMergeBlocksPatternDriver>(); 1200 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1201 } 1202 } // namespace test 1203 } // namespace mlir 1204