1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 55 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 56 // This let us check the number of times OpM_Test was called by inspecting 57 // the returned value in the MLIR output. 58 static int64_t opMIncreasingValue = 314159265; 59 static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 60 int64_t i = opMIncreasingValue++; 61 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 62 } 63 64 namespace { 65 #include "TestPatterns.inc" 66 } // end anonymous namespace 67 68 //===----------------------------------------------------------------------===// 69 // Test Reduce Pattern Interface 70 //===----------------------------------------------------------------------===// 71 72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 73 populateWithGenerated(patterns); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Canonicalizer Driver. 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 struct FoldingPattern : public RewritePattern { 82 public: 83 FoldingPattern(MLIRContext *context) 84 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 85 /*benefit=*/1, context) {} 86 87 LogicalResult matchAndRewrite(Operation *op, 88 PatternRewriter &rewriter) const override { 89 // Exercise OperationFolder API for a single-result operation that is folded 90 // upon construction. The operation being created through the folder has an 91 // in-place folder, and it should be still present in the output. 92 // Furthermore, the folder should not crash when attempting to recover the 93 // (unchanged) operation result. 94 OperationFolder folder(op->getContext()); 95 Value result = folder.create<TestOpInPlaceFold>( 96 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 97 rewriter.getI32IntegerAttr(0)); 98 assert(result); 99 rewriter.replaceOp(op, result); 100 return success(); 101 } 102 }; 103 104 /// This pattern creates a foldable operation at the entry point of the block. 105 /// This tests the situation where the operation folder will need to replace an 106 /// operation with a previously created constant that does not initially 107 /// dominate the operation to replace. 108 struct FolderInsertBeforePreviouslyFoldedConstantPattern 109 : public OpRewritePattern<TestCastOp> { 110 public: 111 using OpRewritePattern<TestCastOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(TestCastOp op, 114 PatternRewriter &rewriter) const override { 115 if (!op->hasAttr("test_fold_before_previously_folded_op")) 116 return failure(); 117 rewriter.setInsertionPointToStart(op->getBlock()); 118 119 auto constOp = rewriter.create<arith::ConstantOp>( 120 op.getLoc(), rewriter.getBoolAttr(true)); 121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 122 Value(constOp)); 123 return success(); 124 } 125 }; 126 127 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 128 StringRef getArgument() const final { return "test-patterns"; } 129 StringRef getDescription() const final { return "Run test dialect patterns"; } 130 void runOnFunction() override { 131 mlir::RewritePatternSet patterns(&getContext()); 132 populateWithGenerated(patterns); 133 134 // Verify named pattern is generated with expected name. 135 patterns.add<FoldingPattern, TestNamedPatternRule, 136 FolderInsertBeforePreviouslyFoldedConstantPattern>( 137 &getContext()); 138 139 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 140 } 141 }; 142 } // end anonymous namespace 143 144 //===----------------------------------------------------------------------===// 145 // ReturnType Driver. 146 //===----------------------------------------------------------------------===// 147 148 namespace { 149 // Generate ops for each instance where the type can be successfully inferred. 150 template <typename OpTy> 151 static void invokeCreateWithInferredReturnType(Operation *op) { 152 auto *context = op->getContext(); 153 auto fop = op->getParentOfType<FuncOp>(); 154 auto location = UnknownLoc::get(context); 155 OpBuilder b(op); 156 b.setInsertionPointAfter(op); 157 158 // Use permutations of 2 args as operands. 159 assert(fop.getNumArguments() >= 2); 160 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 161 for (int j = 0; j < e; ++j) { 162 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 163 SmallVector<Type, 2> inferredReturnTypes; 164 if (succeeded(OpTy::inferReturnTypes( 165 context, llvm::None, values, op->getAttrDictionary(), 166 op->getRegions(), inferredReturnTypes))) { 167 OperationState state(location, OpTy::getOperationName()); 168 // TODO: Expand to regions. 169 OpTy::build(b, state, values, op->getAttrs()); 170 (void)b.createOperation(state); 171 } 172 } 173 } 174 } 175 176 static void reifyReturnShape(Operation *op) { 177 OpBuilder b(op); 178 179 // Use permutations of 2 args as operands. 180 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 181 SmallVector<Value, 2> shapes; 182 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 183 !llvm::hasSingleElement(shapes)) 184 return; 185 for (auto it : llvm::enumerate(shapes)) { 186 op->emitRemark() << "value " << it.index() << ": " 187 << it.value().getDefiningOp(); 188 } 189 } 190 191 struct TestReturnTypeDriver 192 : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 193 void getDependentDialects(DialectRegistry ®istry) const override { 194 registry.insert<tensor::TensorDialect>(); 195 } 196 StringRef getArgument() const final { return "test-return-type"; } 197 StringRef getDescription() const final { return "Run return type functions"; } 198 199 void runOnFunction() override { 200 if (getFunction().getName() == "testCreateFunctions") { 201 std::vector<Operation *> ops; 202 // Collect ops to avoid triggering on inserted ops. 203 for (auto &op : getFunction().getBody().front()) 204 ops.push_back(&op); 205 // Generate test patterns for each, but skip terminator. 206 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 207 // Test create method of each of the Op classes below. The resultant 208 // output would be in reverse order underneath `op` from which 209 // the attributes and regions are used. 210 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 211 invokeCreateWithInferredReturnType< 212 OpWithShapedTypeInferTypeInterfaceOp>(op); 213 }; 214 return; 215 } 216 if (getFunction().getName() == "testReifyFunctions") { 217 std::vector<Operation *> ops; 218 // Collect ops to avoid triggering on inserted ops. 219 for (auto &op : getFunction().getBody().front()) 220 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 221 ops.push_back(&op); 222 // Generate test patterns for each, but skip terminator. 223 for (auto *op : ops) 224 reifyReturnShape(op); 225 } 226 } 227 }; 228 } // end anonymous namespace 229 230 namespace { 231 struct TestDerivedAttributeDriver 232 : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 233 StringRef getArgument() const final { return "test-derived-attr"; } 234 StringRef getDescription() const final { 235 return "Run test derived attributes"; 236 } 237 void runOnFunction() override; 238 }; 239 } // end anonymous namespace 240 241 void TestDerivedAttributeDriver::runOnFunction() { 242 getFunction().walk([](DerivedAttributeOpInterface dOp) { 243 auto dAttr = dOp.materializeDerivedAttributes(); 244 if (!dAttr) 245 return; 246 for (auto d : dAttr) 247 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 248 }); 249 } 250 251 //===----------------------------------------------------------------------===// 252 // Legalization Driver. 253 //===----------------------------------------------------------------------===// 254 255 namespace { 256 //===----------------------------------------------------------------------===// 257 // Region-Block Rewrite Testing 258 259 /// This pattern is a simple pattern that inlines the first region of a given 260 /// operation into the parent region. 261 struct TestRegionRewriteBlockMovement : public ConversionPattern { 262 TestRegionRewriteBlockMovement(MLIRContext *ctx) 263 : ConversionPattern("test.region", 1, ctx) {} 264 265 LogicalResult 266 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 267 ConversionPatternRewriter &rewriter) const final { 268 // Inline this region into the parent region. 269 auto &parentRegion = *op->getParentRegion(); 270 auto &opRegion = op->getRegion(0); 271 if (op->getAttr("legalizer.should_clone")) 272 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 273 else 274 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 275 276 if (op->getAttr("legalizer.erase_old_blocks")) { 277 while (!opRegion.empty()) 278 rewriter.eraseBlock(&opRegion.front()); 279 } 280 281 // Drop this operation. 282 rewriter.eraseOp(op); 283 return success(); 284 } 285 }; 286 /// This pattern is a simple pattern that generates a region containing an 287 /// illegal operation. 288 struct TestRegionRewriteUndo : public RewritePattern { 289 TestRegionRewriteUndo(MLIRContext *ctx) 290 : RewritePattern("test.region_builder", 1, ctx) {} 291 292 LogicalResult matchAndRewrite(Operation *op, 293 PatternRewriter &rewriter) const final { 294 // Create the region operation with an entry block containing arguments. 295 OperationState newRegion(op->getLoc(), "test.region"); 296 newRegion.addRegion(); 297 auto *regionOp = rewriter.createOperation(newRegion); 298 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 299 entryBlock->addArgument(rewriter.getIntegerType(64)); 300 301 // Add an explicitly illegal operation to ensure the conversion fails. 302 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 303 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 304 305 // Drop this operation. 306 rewriter.eraseOp(op); 307 return success(); 308 } 309 }; 310 /// A simple pattern that creates a block at the end of the parent region of the 311 /// matched operation. 312 struct TestCreateBlock : public RewritePattern { 313 TestCreateBlock(MLIRContext *ctx) 314 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 315 316 LogicalResult matchAndRewrite(Operation *op, 317 PatternRewriter &rewriter) const final { 318 Region ®ion = *op->getParentRegion(); 319 Type i32Type = rewriter.getIntegerType(32); 320 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 321 rewriter.create<TerminatorOp>(op->getLoc()); 322 rewriter.replaceOp(op, {}); 323 return success(); 324 } 325 }; 326 327 /// A simple pattern that creates a block containing an invalid operation in 328 /// order to trigger the block creation undo mechanism. 329 struct TestCreateIllegalBlock : public RewritePattern { 330 TestCreateIllegalBlock(MLIRContext *ctx) 331 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 332 333 LogicalResult matchAndRewrite(Operation *op, 334 PatternRewriter &rewriter) const final { 335 Region ®ion = *op->getParentRegion(); 336 Type i32Type = rewriter.getIntegerType(32); 337 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 338 // Create an illegal op to ensure the conversion fails. 339 rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 340 rewriter.create<TerminatorOp>(op->getLoc()); 341 rewriter.replaceOp(op, {}); 342 return success(); 343 } 344 }; 345 346 /// A simple pattern that tests the undo mechanism when replacing the uses of a 347 /// block argument. 348 struct TestUndoBlockArgReplace : public ConversionPattern { 349 TestUndoBlockArgReplace(MLIRContext *ctx) 350 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 351 352 LogicalResult 353 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 354 ConversionPatternRewriter &rewriter) const final { 355 auto illegalOp = 356 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 357 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 358 illegalOp); 359 rewriter.updateRootInPlace(op, [] {}); 360 return success(); 361 } 362 }; 363 364 /// A rewrite pattern that tests the undo mechanism when erasing a block. 365 struct TestUndoBlockErase : public ConversionPattern { 366 TestUndoBlockErase(MLIRContext *ctx) 367 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 368 369 LogicalResult 370 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 371 ConversionPatternRewriter &rewriter) const final { 372 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 373 rewriter.setInsertionPointToStart(secondBlock); 374 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 375 rewriter.eraseBlock(secondBlock); 376 rewriter.updateRootInPlace(op, [] {}); 377 return success(); 378 } 379 }; 380 381 //===----------------------------------------------------------------------===// 382 // Type-Conversion Rewrite Testing 383 384 /// This patterns erases a region operation that has had a type conversion. 385 struct TestDropOpSignatureConversion : public ConversionPattern { 386 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 387 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 388 LogicalResult 389 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 390 ConversionPatternRewriter &rewriter) const override { 391 Region ®ion = op->getRegion(0); 392 Block *entry = ®ion.front(); 393 394 // Convert the original entry arguments. 395 TypeConverter &converter = *getTypeConverter(); 396 TypeConverter::SignatureConversion result(entry->getNumArguments()); 397 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 398 result)) || 399 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 400 return failure(); 401 402 // Convert the region signature and just drop the operation. 403 rewriter.eraseOp(op); 404 return success(); 405 } 406 }; 407 /// This pattern simply updates the operands of the given operation. 408 struct TestPassthroughInvalidOp : public ConversionPattern { 409 TestPassthroughInvalidOp(MLIRContext *ctx) 410 : ConversionPattern("test.invalid", 1, ctx) {} 411 LogicalResult 412 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 413 ConversionPatternRewriter &rewriter) const final { 414 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 415 llvm::None); 416 return success(); 417 } 418 }; 419 /// This pattern handles the case of a split return value. 420 struct TestSplitReturnType : public ConversionPattern { 421 TestSplitReturnType(MLIRContext *ctx) 422 : ConversionPattern("test.return", 1, ctx) {} 423 LogicalResult 424 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 425 ConversionPatternRewriter &rewriter) const final { 426 // Check for a return of F32. 427 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 428 return failure(); 429 430 // Check if the first operation is a cast operation, if it is we use the 431 // results directly. 432 auto *defOp = operands[0].getDefiningOp(); 433 if (auto packerOp = 434 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 435 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 436 return success(); 437 } 438 439 // Otherwise, fail to match. 440 return failure(); 441 } 442 }; 443 444 //===----------------------------------------------------------------------===// 445 // Multi-Level Type-Conversion Rewrite Testing 446 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 447 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 448 : ConversionPattern("test.type_producer", 1, ctx) {} 449 LogicalResult 450 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 451 ConversionPatternRewriter &rewriter) const final { 452 // If the type is I32, change the type to F32. 453 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 454 return failure(); 455 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 456 return success(); 457 } 458 }; 459 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 460 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 461 : ConversionPattern("test.type_producer", 1, ctx) {} 462 LogicalResult 463 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 464 ConversionPatternRewriter &rewriter) const final { 465 // If the type is F32, change the type to F64. 466 if (!Type(*op->result_type_begin()).isF32()) 467 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 468 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 469 return success(); 470 } 471 }; 472 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 473 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 474 : ConversionPattern("test.type_producer", 10, ctx) {} 475 LogicalResult 476 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 477 ConversionPatternRewriter &rewriter) const final { 478 // Always convert to B16, even though it is not a legal type. This tests 479 // that values are unmapped correctly. 480 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 481 return success(); 482 } 483 }; 484 struct TestUpdateConsumerType : public ConversionPattern { 485 TestUpdateConsumerType(MLIRContext *ctx) 486 : ConversionPattern("test.type_consumer", 1, ctx) {} 487 LogicalResult 488 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 489 ConversionPatternRewriter &rewriter) const final { 490 // Verify that the incoming operand has been successfully remapped to F64. 491 if (!operands[0].getType().isF64()) 492 return failure(); 493 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 494 return success(); 495 } 496 }; 497 498 //===----------------------------------------------------------------------===// 499 // Non-Root Replacement Rewrite Testing 500 /// This pattern generates an invalid operation, but replaces it before the 501 /// pattern is finished. This checks that we don't need to legalize the 502 /// temporary op. 503 struct TestNonRootReplacement : public RewritePattern { 504 TestNonRootReplacement(MLIRContext *ctx) 505 : RewritePattern("test.replace_non_root", 1, ctx) {} 506 507 LogicalResult matchAndRewrite(Operation *op, 508 PatternRewriter &rewriter) const final { 509 auto resultType = *op->result_type_begin(); 510 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 511 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 512 513 rewriter.replaceOp(illegalOp, {legalOp}); 514 rewriter.replaceOp(op, {illegalOp}); 515 return success(); 516 } 517 }; 518 519 //===----------------------------------------------------------------------===// 520 // Recursive Rewrite Testing 521 /// This pattern is applied to the same operation multiple times, but has a 522 /// bounded recursion. 523 struct TestBoundedRecursiveRewrite 524 : public OpRewritePattern<TestRecursiveRewriteOp> { 525 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 526 527 void initialize() { 528 // The conversion target handles bounding the recursion of this pattern. 529 setHasBoundedRewriteRecursion(); 530 } 531 532 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 533 PatternRewriter &rewriter) const final { 534 // Decrement the depth of the op in-place. 535 rewriter.updateRootInPlace(op, [&] { 536 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 537 }); 538 return success(); 539 } 540 }; 541 542 struct TestNestedOpCreationUndoRewrite 543 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 544 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 545 546 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 547 PatternRewriter &rewriter) const final { 548 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 549 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 550 return success(); 551 }; 552 }; 553 554 // This pattern matches `test.blackhole` and delete this op and its producer. 555 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 556 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 557 558 LogicalResult matchAndRewrite(BlackHoleOp op, 559 PatternRewriter &rewriter) const final { 560 Operation *producer = op.getOperand().getDefiningOp(); 561 // Always erase the user before the producer, the framework should handle 562 // this correctly. 563 rewriter.eraseOp(op); 564 rewriter.eraseOp(producer); 565 return success(); 566 }; 567 }; 568 569 // This pattern replaces explicitly illegal op with explicitly legal op, 570 // but in addition creates unregistered operation. 571 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 572 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 573 574 LogicalResult matchAndRewrite(ILLegalOpG op, 575 PatternRewriter &rewriter) const final { 576 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 577 Value val = rewriter.create<ConstantOp>(op->getLoc(), attr); 578 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 579 return success(); 580 }; 581 }; 582 } // namespace 583 584 namespace { 585 struct TestTypeConverter : public TypeConverter { 586 using TypeConverter::TypeConverter; 587 TestTypeConverter() { 588 addConversion(convertType); 589 addArgumentMaterialization(materializeCast); 590 addSourceMaterialization(materializeCast); 591 } 592 593 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 594 // Drop I16 types. 595 if (t.isSignlessInteger(16)) 596 return success(); 597 598 // Convert I64 to F64. 599 if (t.isSignlessInteger(64)) { 600 results.push_back(FloatType::getF64(t.getContext())); 601 return success(); 602 } 603 604 // Convert I42 to I43. 605 if (t.isInteger(42)) { 606 results.push_back(IntegerType::get(t.getContext(), 43)); 607 return success(); 608 } 609 610 // Split F32 into F16,F16. 611 if (t.isF32()) { 612 results.assign(2, FloatType::getF16(t.getContext())); 613 return success(); 614 } 615 616 // Otherwise, convert the type directly. 617 results.push_back(t); 618 return success(); 619 } 620 621 /// Hook for materializing a conversion. This is necessary because we generate 622 /// 1->N type mappings. 623 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 624 ValueRange inputs, Location loc) { 625 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 626 } 627 }; 628 629 struct TestLegalizePatternDriver 630 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 631 StringRef getArgument() const final { return "test-legalize-patterns"; } 632 StringRef getDescription() const final { 633 return "Run test dialect legalization patterns"; 634 } 635 /// The mode of conversion to use with the driver. 636 enum class ConversionMode { Analysis, Full, Partial }; 637 638 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 639 640 void getDependentDialects(DialectRegistry ®istry) const override { 641 registry.insert<StandardOpsDialect>(); 642 } 643 644 void runOnOperation() override { 645 TestTypeConverter converter; 646 mlir::RewritePatternSet patterns(&getContext()); 647 populateWithGenerated(patterns); 648 patterns 649 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 650 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 651 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 652 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 653 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 654 TestNonRootReplacement, TestBoundedRecursiveRewrite, 655 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 656 TestCreateUnregisteredOp>(&getContext()); 657 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 658 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 659 mlir::populateCallOpTypeConversionPattern(patterns, converter); 660 661 // Define the conversion target used for the test. 662 ConversionTarget target(getContext()); 663 target.addLegalOp<ModuleOp>(); 664 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 665 TerminatorOp>(); 666 target 667 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 668 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 669 // Don't allow F32 operands. 670 return llvm::none_of(op.getOperandTypes(), 671 [](Type type) { return type.isF32(); }); 672 }); 673 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 674 return converter.isSignatureLegal(op.getType()) && 675 converter.isLegal(&op.getBody()); 676 }); 677 target.addDynamicallyLegalOp<CallOp>( 678 [&](CallOp op) { return converter.isLegal(op); }); 679 680 // TestCreateUnregisteredOp creates `arith.constant` operation, 681 // which was not added to target intentionally to test 682 // correct error code from conversion driver. 683 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 684 685 // Expect the type_producer/type_consumer operations to only operate on f64. 686 target.addDynamicallyLegalOp<TestTypeProducerOp>( 687 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 688 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 689 return op.getOperand().getType().isF64(); 690 }); 691 692 // Check support for marking certain operations as recursively legal. 693 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 694 return static_cast<bool>( 695 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 696 }); 697 698 // Mark the bound recursion operation as dynamically legal. 699 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 700 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 701 702 // Handle a partial conversion. 703 if (mode == ConversionMode::Partial) { 704 DenseSet<Operation *> unlegalizedOps; 705 if (failed(applyPartialConversion( 706 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 707 getOperation()->emitRemark() << "applyPartialConversion failed"; 708 } 709 // Emit remarks for each legalizable operation. 710 for (auto *op : unlegalizedOps) 711 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 712 return; 713 } 714 715 // Handle a full conversion. 716 if (mode == ConversionMode::Full) { 717 // Check support for marking unknown operations as dynamically legal. 718 target.markUnknownOpDynamicallyLegal([](Operation *op) { 719 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 720 }); 721 722 if (failed(applyFullConversion(getOperation(), target, 723 std::move(patterns)))) { 724 getOperation()->emitRemark() << "applyFullConversion failed"; 725 } 726 return; 727 } 728 729 // Otherwise, handle an analysis conversion. 730 assert(mode == ConversionMode::Analysis); 731 732 // Analyze the convertible operations. 733 DenseSet<Operation *> legalizedOps; 734 if (failed(applyAnalysisConversion(getOperation(), target, 735 std::move(patterns), legalizedOps))) 736 return signalPassFailure(); 737 738 // Emit remarks for each legalizable operation. 739 for (auto *op : legalizedOps) 740 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 741 } 742 743 /// The mode of conversion to use. 744 ConversionMode mode; 745 }; 746 } // end anonymous namespace 747 748 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 749 legalizerConversionMode( 750 "test-legalize-mode", 751 llvm::cl::desc("The legalization mode to use with the test driver"), 752 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 753 llvm::cl::values( 754 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 755 "analysis", "Perform an analysis conversion"), 756 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 757 "Perform a full conversion"), 758 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 759 "partial", "Perform a partial conversion"))); 760 761 //===----------------------------------------------------------------------===// 762 // ConversionPatternRewriter::getRemappedValue testing. This method is used 763 // to get the remapped value of an original value that was replaced using 764 // ConversionPatternRewriter. 765 namespace { 766 struct TestRemapValueTypeConverter : public TypeConverter { 767 using TypeConverter::TypeConverter; 768 769 TestRemapValueTypeConverter() { 770 addConversion( 771 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 772 addConversion([](Type type) { return type; }); 773 } 774 }; 775 776 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 777 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 778 /// operand twice. 779 /// 780 /// Example: 781 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 782 /// is replaced with: 783 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 784 struct OneVResOneVOperandOp1Converter 785 : public OpConversionPattern<OneVResOneVOperandOp1> { 786 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 787 788 LogicalResult 789 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 790 ConversionPatternRewriter &rewriter) const override { 791 auto origOps = op.getOperands(); 792 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 793 "One operand expected"); 794 Value origOp = *origOps.begin(); 795 SmallVector<Value, 2> remappedOperands; 796 // Replicate the remapped original operand twice. Note that we don't used 797 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 798 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 799 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 800 801 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 802 remappedOperands); 803 return success(); 804 } 805 }; 806 807 /// A rewriter pattern that tests that blocks can be merged. 808 struct TestRemapValueInRegion 809 : public OpConversionPattern<TestRemappedValueRegionOp> { 810 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 811 812 LogicalResult 813 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 814 ConversionPatternRewriter &rewriter) const final { 815 Block &block = op.getBody().front(); 816 Operation *terminator = block.getTerminator(); 817 818 // Merge the block into the parent region. 819 Block *parentBlock = op->getBlock(); 820 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 821 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 822 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 823 824 // Replace the results of this operation with the remapped terminator 825 // values. 826 SmallVector<Value> terminatorOperands; 827 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 828 terminatorOperands))) 829 return failure(); 830 831 rewriter.eraseOp(terminator); 832 rewriter.replaceOp(op, terminatorOperands); 833 return success(); 834 } 835 }; 836 837 struct TestRemappedValue 838 : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 839 StringRef getArgument() const final { return "test-remapped-value"; } 840 StringRef getDescription() const final { 841 return "Test public remapped value mechanism in ConversionPatternRewriter"; 842 } 843 void runOnFunction() override { 844 TestRemapValueTypeConverter typeConverter; 845 846 mlir::RewritePatternSet patterns(&getContext()); 847 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 848 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 849 &getContext()); 850 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 851 852 mlir::ConversionTarget target(getContext()); 853 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 854 855 // Expect the type_producer/type_consumer operations to only operate on f64. 856 target.addDynamicallyLegalOp<TestTypeProducerOp>( 857 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 858 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 859 return op.getOperand().getType().isF64(); 860 }); 861 862 // We make OneVResOneVOperandOp1 legal only when it has more that one 863 // operand. This will trigger the conversion that will replace one-operand 864 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 865 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 866 [](Operation *op) { return op->getNumOperands() > 1; }); 867 868 if (failed(mlir::applyFullConversion(getFunction(), target, 869 std::move(patterns)))) { 870 signalPassFailure(); 871 } 872 } 873 }; 874 } // end anonymous namespace 875 876 //===----------------------------------------------------------------------===// 877 // Test patterns without a specific root operation kind 878 //===----------------------------------------------------------------------===// 879 880 namespace { 881 /// This pattern matches and removes any operation in the test dialect. 882 struct RemoveTestDialectOps : public RewritePattern { 883 RemoveTestDialectOps(MLIRContext *context) 884 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 885 886 LogicalResult matchAndRewrite(Operation *op, 887 PatternRewriter &rewriter) const override { 888 if (!isa<TestDialect>(op->getDialect())) 889 return failure(); 890 rewriter.eraseOp(op); 891 return success(); 892 } 893 }; 894 895 struct TestUnknownRootOpDriver 896 : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 897 StringRef getArgument() const final { 898 return "test-legalize-unknown-root-patterns"; 899 } 900 StringRef getDescription() const final { 901 return "Test public remapped value mechanism in ConversionPatternRewriter"; 902 } 903 void runOnFunction() override { 904 mlir::RewritePatternSet patterns(&getContext()); 905 patterns.add<RemoveTestDialectOps>(&getContext()); 906 907 mlir::ConversionTarget target(getContext()); 908 target.addIllegalDialect<TestDialect>(); 909 if (failed( 910 applyPartialConversion(getFunction(), target, std::move(patterns)))) 911 signalPassFailure(); 912 } 913 }; 914 } // end anonymous namespace 915 916 //===----------------------------------------------------------------------===// 917 // Test type conversions 918 //===----------------------------------------------------------------------===// 919 920 namespace { 921 struct TestTypeConversionProducer 922 : public OpConversionPattern<TestTypeProducerOp> { 923 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 924 LogicalResult 925 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 926 ConversionPatternRewriter &rewriter) const final { 927 Type resultType = op.getType(); 928 Type convertedType = getTypeConverter() 929 ? getTypeConverter()->convertType(resultType) 930 : resultType; 931 if (resultType.isa<FloatType>()) 932 resultType = rewriter.getF64Type(); 933 else if (resultType.isInteger(16)) 934 resultType = rewriter.getIntegerType(64); 935 else if (resultType.isa<test::TestRecursiveType>() && 936 convertedType != resultType) 937 resultType = convertedType; 938 else 939 return failure(); 940 941 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 942 return success(); 943 } 944 }; 945 946 /// Call signature conversion and then fail the rewrite to trigger the undo 947 /// mechanism. 948 struct TestSignatureConversionUndo 949 : public OpConversionPattern<TestSignatureConversionUndoOp> { 950 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 951 952 LogicalResult 953 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 954 ConversionPatternRewriter &rewriter) const final { 955 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 956 return failure(); 957 } 958 }; 959 960 /// Call signature conversion without providing a type converter to handle 961 /// materializations. 962 struct TestTestSignatureConversionNoConverter 963 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 964 TestTestSignatureConversionNoConverter(TypeConverter &converter, 965 MLIRContext *context) 966 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 967 converter(converter) {} 968 969 LogicalResult 970 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 971 ConversionPatternRewriter &rewriter) const final { 972 Region ®ion = op->getRegion(0); 973 Block *entry = ®ion.front(); 974 975 // Convert the original entry arguments. 976 TypeConverter::SignatureConversion result(entry->getNumArguments()); 977 if (failed( 978 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 979 return failure(); 980 rewriter.updateRootInPlace( 981 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 982 return success(); 983 } 984 985 TypeConverter &converter; 986 }; 987 988 /// Just forward the operands to the root op. This is essentially a no-op 989 /// pattern that is used to trigger target materialization. 990 struct TestTypeConsumerForward 991 : public OpConversionPattern<TestTypeConsumerOp> { 992 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 993 994 LogicalResult 995 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 996 ConversionPatternRewriter &rewriter) const final { 997 rewriter.updateRootInPlace(op, 998 [&] { op->setOperands(adaptor.getOperands()); }); 999 return success(); 1000 } 1001 }; 1002 1003 struct TestTypeConversionAnotherProducer 1004 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1005 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1006 1007 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1008 PatternRewriter &rewriter) const final { 1009 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1010 return success(); 1011 } 1012 }; 1013 1014 struct TestTypeConversionDriver 1015 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1016 void getDependentDialects(DialectRegistry ®istry) const override { 1017 registry.insert<TestDialect>(); 1018 } 1019 StringRef getArgument() const final { 1020 return "test-legalize-type-conversion"; 1021 } 1022 StringRef getDescription() const final { 1023 return "Test various type conversion functionalities in DialectConversion"; 1024 } 1025 1026 void runOnOperation() override { 1027 // Initialize the type converter. 1028 TypeConverter converter; 1029 1030 /// Add the legal set of type conversions. 1031 converter.addConversion([](Type type) -> Type { 1032 // Treat F64 as legal. 1033 if (type.isF64()) 1034 return type; 1035 // Allow converting BF16/F16/F32 to F64. 1036 if (type.isBF16() || type.isF16() || type.isF32()) 1037 return FloatType::getF64(type.getContext()); 1038 // Otherwise, the type is illegal. 1039 return nullptr; 1040 }); 1041 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1042 // Drop all integer types. 1043 return success(); 1044 }); 1045 converter.addConversion( 1046 // Convert a recursive self-referring type into a non-self-referring 1047 // type named "outer_converted_type" that contains a SimpleAType. 1048 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1049 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1050 // If the type is already converted, return it to indicate that it is 1051 // legal. 1052 if (type.getName() == "outer_converted_type") { 1053 results.push_back(type); 1054 return success(); 1055 } 1056 1057 // If the type is on the call stack more than once (it is there at 1058 // least once because of the _current_ call, which is always the last 1059 // element on the stack), we've hit the recursive case. Just return 1060 // SimpleAType here to create a non-recursive type as a result. 1061 if (llvm::is_contained(callStack.drop_back(), type)) { 1062 results.push_back(test::SimpleAType::get(type.getContext())); 1063 return success(); 1064 } 1065 1066 // Convert the body recursively. 1067 auto result = test::TestRecursiveType::get(type.getContext(), 1068 "outer_converted_type"); 1069 if (failed(result.setBody(converter.convertType(type.getBody())))) 1070 return failure(); 1071 results.push_back(result); 1072 return success(); 1073 }); 1074 1075 /// Add the legal set of type materializations. 1076 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1077 ValueRange inputs, 1078 Location loc) -> Value { 1079 // Allow casting from F64 back to F32. 1080 if (!resultType.isF16() && inputs.size() == 1 && 1081 inputs[0].getType().isF64()) 1082 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1083 // Allow producing an i32 or i64 from nothing. 1084 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1085 inputs.empty()) 1086 return builder.create<TestTypeProducerOp>(loc, resultType); 1087 // Allow producing an i64 from an integer. 1088 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1089 inputs[0].getType().isa<IntegerType>()) 1090 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1091 // Otherwise, fail. 1092 return nullptr; 1093 }); 1094 1095 // Initialize the conversion target. 1096 mlir::ConversionTarget target(getContext()); 1097 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1098 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1099 return op.getType().isF64() || op.getType().isInteger(64) || 1100 (recursiveType && 1101 recursiveType.getName() == "outer_converted_type"); 1102 }); 1103 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 1104 return converter.isSignatureLegal(op.getType()) && 1105 converter.isLegal(&op.getBody()); 1106 }); 1107 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1108 // Allow casts from F64 to F32. 1109 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1110 }); 1111 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1112 [&](TestSignatureConversionNoConverterOp op) { 1113 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1114 }); 1115 1116 // Initialize the set of rewrite patterns. 1117 RewritePatternSet patterns(&getContext()); 1118 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1119 TestSignatureConversionUndo, 1120 TestTestSignatureConversionNoConverter>(converter, 1121 &getContext()); 1122 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1123 mlir::populateFuncOpTypeConversionPattern(patterns, converter); 1124 1125 if (failed(applyPartialConversion(getOperation(), target, 1126 std::move(patterns)))) 1127 signalPassFailure(); 1128 } 1129 }; 1130 } // end anonymous namespace 1131 1132 //===----------------------------------------------------------------------===// 1133 // Test Block Merging 1134 //===----------------------------------------------------------------------===// 1135 1136 namespace { 1137 /// A rewriter pattern that tests that blocks can be merged. 1138 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1139 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1140 1141 LogicalResult 1142 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1143 ConversionPatternRewriter &rewriter) const final { 1144 Block &firstBlock = op.getBody().front(); 1145 Operation *branchOp = firstBlock.getTerminator(); 1146 Block *secondBlock = &*(std::next(op.getBody().begin())); 1147 auto succOperands = branchOp->getOperands(); 1148 SmallVector<Value, 2> replacements(succOperands); 1149 rewriter.eraseOp(branchOp); 1150 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1151 rewriter.updateRootInPlace(op, [] {}); 1152 return success(); 1153 } 1154 }; 1155 1156 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1157 struct TestUndoBlocksMerge : public ConversionPattern { 1158 TestUndoBlocksMerge(MLIRContext *ctx) 1159 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1160 LogicalResult 1161 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1162 ConversionPatternRewriter &rewriter) const final { 1163 Block &firstBlock = op->getRegion(0).front(); 1164 Operation *branchOp = firstBlock.getTerminator(); 1165 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1166 rewriter.setInsertionPointToStart(secondBlock); 1167 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1168 auto succOperands = branchOp->getOperands(); 1169 SmallVector<Value, 2> replacements(succOperands); 1170 rewriter.eraseOp(branchOp); 1171 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1172 rewriter.updateRootInPlace(op, [] {}); 1173 return success(); 1174 } 1175 }; 1176 1177 /// A rewrite mechanism to inline the body of the op into its parent, when both 1178 /// ops can have a single block. 1179 struct TestMergeSingleBlockOps 1180 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1181 using OpConversionPattern< 1182 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1183 1184 LogicalResult 1185 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1186 ConversionPatternRewriter &rewriter) const final { 1187 SingleBlockImplicitTerminatorOp parentOp = 1188 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1189 if (!parentOp) 1190 return failure(); 1191 Block &innerBlock = op.getRegion().front(); 1192 TerminatorOp innerTerminator = 1193 cast<TerminatorOp>(innerBlock.getTerminator()); 1194 rewriter.mergeBlockBefore(&innerBlock, op); 1195 rewriter.eraseOp(innerTerminator); 1196 rewriter.eraseOp(op); 1197 rewriter.updateRootInPlace(op, [] {}); 1198 return success(); 1199 } 1200 }; 1201 1202 struct TestMergeBlocksPatternDriver 1203 : public PassWrapper<TestMergeBlocksPatternDriver, 1204 OperationPass<ModuleOp>> { 1205 StringRef getArgument() const final { return "test-merge-blocks"; } 1206 StringRef getDescription() const final { 1207 return "Test Merging operation in ConversionPatternRewriter"; 1208 } 1209 void runOnOperation() override { 1210 MLIRContext *context = &getContext(); 1211 mlir::RewritePatternSet patterns(context); 1212 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1213 context); 1214 ConversionTarget target(*context); 1215 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1216 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1217 target.addIllegalOp<ILLegalOpF>(); 1218 1219 /// Expect the op to have a single block after legalization. 1220 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1221 [&](TestMergeBlocksOp op) -> bool { 1222 return llvm::hasSingleElement(op.getBody()); 1223 }); 1224 1225 /// Only allow `test.br` within test.merge_blocks op. 1226 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1227 return op->getParentOfType<TestMergeBlocksOp>(); 1228 }); 1229 1230 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1231 /// inlined. 1232 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1233 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1234 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1235 }); 1236 1237 DenseSet<Operation *> unlegalizedOps; 1238 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1239 &unlegalizedOps); 1240 for (auto *op : unlegalizedOps) 1241 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1242 } 1243 }; 1244 } // namespace 1245 1246 //===----------------------------------------------------------------------===// 1247 // Test Selective Replacement 1248 //===----------------------------------------------------------------------===// 1249 1250 namespace { 1251 /// A rewrite mechanism to inline the body of the op into its parent, when both 1252 /// ops can have a single block. 1253 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1254 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1255 1256 LogicalResult matchAndRewrite(TestCastOp op, 1257 PatternRewriter &rewriter) const final { 1258 if (op.getNumOperands() != 2) 1259 return failure(); 1260 OperandRange operands = op.getOperands(); 1261 1262 // Replace non-terminator uses with the first operand. 1263 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1264 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1265 }); 1266 // Replace everything else with the second operand if the operation isn't 1267 // dead. 1268 rewriter.replaceOp(op, op.getOperand(1)); 1269 return success(); 1270 } 1271 }; 1272 1273 struct TestSelectiveReplacementPatternDriver 1274 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1275 OperationPass<>> { 1276 StringRef getArgument() const final { 1277 return "test-pattern-selective-replacement"; 1278 } 1279 StringRef getDescription() const final { 1280 return "Test selective replacement in the PatternRewriter"; 1281 } 1282 void runOnOperation() override { 1283 MLIRContext *context = &getContext(); 1284 mlir::RewritePatternSet patterns(context); 1285 patterns.add<TestSelectiveOpReplacementPattern>(context); 1286 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1287 std::move(patterns)); 1288 } 1289 }; 1290 } // namespace 1291 1292 //===----------------------------------------------------------------------===// 1293 // PassRegistration 1294 //===----------------------------------------------------------------------===// 1295 1296 namespace mlir { 1297 namespace test { 1298 void registerPatternsTestPass() { 1299 PassRegistration<TestReturnTypeDriver>(); 1300 1301 PassRegistration<TestDerivedAttributeDriver>(); 1302 1303 PassRegistration<TestPatternDriver>(); 1304 1305 PassRegistration<TestLegalizePatternDriver>([] { 1306 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1307 }); 1308 1309 PassRegistration<TestRemappedValue>(); 1310 1311 PassRegistration<TestUnknownRootOpDriver>(); 1312 1313 PassRegistration<TestTypeConversionDriver>(); 1314 1315 PassRegistration<TestMergeBlocksPatternDriver>(); 1316 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1317 } 1318 } // namespace test 1319 } // namespace mlir 1320