1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 
// OpM_Test will return Pi, increased by 1 for each subsequent call.
// This lets us check the number of times OpM_Test was called by inspecting
// the returned value in the MLIR output.
static int64_t opMIncreasingValue = 314159265;

/// Returns the current counter value as an i32 IntegerAttr, then increments
/// the counter. `val` is not read; NOTE(review): presumably it is only there to
/// match the NativeCodeCall signature used by the generated patterns.
static Attribute opMTest(PatternRewriter &rewriter, Value val) {
  int64_t i = opMIncreasingValue++;
  return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}

// Pull in the DRR-generated patterns. They are included after the native
// helpers above, presumably because the generated code calls them.
namespace {
#include "TestPatterns.inc"
} // namespace

//===----------------------------------------------------------------------===//
// Test Reduce Pattern Interface
//===----------------------------------------------------------------------===//

/// Adds all generated test patterns to `patterns`.
void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}

//===----------------------------------------------------------------------===//
// Canonicalizer Driver.
//===----------------------------------------------------------------------===//

namespace {
/// Replaces a TestOpInPlaceFoldAnchor with a TestOpInPlaceFold that is built
/// through an OperationFolder rather than directly through the rewriter.
struct FoldingPattern : public RewritePattern {
public:
  FoldingPattern(MLIRContext *context)
      : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
                       /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Exercise OperationFolder API for a single-result operation that is folded
    // upon construction. The operation being created through the folder has an
    // in-place folder, and it should still be present in the output.
    // Furthermore, the folder should not crash when attempting to recover the
    // (unchanged) operation result.
    OperationFolder folder(op->getContext());
    Value result = folder.create<TestOpInPlaceFold>(
        rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
        rewriter.getI32IntegerAttr(0));
    assert(result);
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// This pattern creates a foldable operation at the entry point of the block.
/// This tests the situation where the operation folder will need to replace an
/// operation with a previously created constant that does not initially
/// dominate the operation to replace.
struct FolderInsertBeforePreviouslyFoldedConstantPattern
    : public OpRewritePattern<TestCastOp> {
public:
  using OpRewritePattern<TestCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TestCastOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire on casts explicitly tagged by the test input.
    if (!op->hasAttr("test_fold_before_previously_folded_op"))
      return failure();
    // Move the insertion point to the start of the block so the constant is
    // created ahead of `op`.
    rewriter.setInsertionPointToStart(op->getBlock());

    auto constOp = rewriter.create<arith::ConstantOp>(
        op.getLoc(), rewriter.getBoolAttr(true));
    rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
                                            Value(constOp));
    return success();
  }
};

/// This pattern matches test.op_commutative2 whose first operand is another
/// test.op_commutative2 with a constant on the right-hand side, and folds the
/// outer op away by forwarding that constant as its result. This is intended to
/// check that patterns are applied after the commutative property moves
/// constants to the right.
131 struct FolderCommutativeOp2WithConstant 132 : public OpRewritePattern<TestCommutative2Op> { 133 public: 134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 135 136 LogicalResult matchAndRewrite(TestCommutative2Op op, 137 PatternRewriter &rewriter) const override { 138 auto operand = 139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 140 if (!operand) 141 return failure(); 142 Attribute constInput; 143 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 144 return failure(); 145 rewriter.replaceOp(op, operand->getOperand(1)); 146 return success(); 147 } 148 }; 149 150 struct TestPatternDriver 151 : public PassWrapper<TestPatternDriver, OperationPass<FuncOp>> { 152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 153 154 StringRef getArgument() const final { return "test-patterns"; } 155 StringRef getDescription() const final { return "Run test dialect patterns"; } 156 void runOnOperation() override { 157 mlir::RewritePatternSet patterns(&getContext()); 158 populateWithGenerated(patterns); 159 160 // Verify named pattern is generated with expected name. 161 patterns.add<FoldingPattern, TestNamedPatternRule, 162 FolderInsertBeforePreviouslyFoldedConstantPattern, 163 FolderCommutativeOp2WithConstant>(&getContext()); 164 165 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 166 } 167 }; 168 } // namespace 169 170 //===----------------------------------------------------------------------===// 171 // ReturnType Driver. 172 //===----------------------------------------------------------------------===// 173 174 namespace { 175 // Generate ops for each instance where the type can be successfully inferred. 
176 template <typename OpTy> 177 static void invokeCreateWithInferredReturnType(Operation *op) { 178 auto *context = op->getContext(); 179 auto fop = op->getParentOfType<FuncOp>(); 180 auto location = UnknownLoc::get(context); 181 OpBuilder b(op); 182 b.setInsertionPointAfter(op); 183 184 // Use permutations of 2 args as operands. 185 assert(fop.getNumArguments() >= 2); 186 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 187 for (int j = 0; j < e; ++j) { 188 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 189 SmallVector<Type, 2> inferredReturnTypes; 190 if (succeeded(OpTy::inferReturnTypes( 191 context, llvm::None, values, op->getAttrDictionary(), 192 op->getRegions(), inferredReturnTypes))) { 193 OperationState state(location, OpTy::getOperationName()); 194 // TODO: Expand to regions. 195 OpTy::build(b, state, values, op->getAttrs()); 196 (void)b.create(state); 197 } 198 } 199 } 200 } 201 202 static void reifyReturnShape(Operation *op) { 203 OpBuilder b(op); 204 205 // Use permutations of 2 args as operands. 
206 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 207 SmallVector<Value, 2> shapes; 208 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 209 !llvm::hasSingleElement(shapes)) 210 return; 211 for (const auto &it : llvm::enumerate(shapes)) { 212 op->emitRemark() << "value " << it.index() << ": " 213 << it.value().getDefiningOp(); 214 } 215 } 216 217 struct TestReturnTypeDriver 218 : public PassWrapper<TestReturnTypeDriver, OperationPass<FuncOp>> { 219 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 220 221 void getDependentDialects(DialectRegistry ®istry) const override { 222 registry.insert<tensor::TensorDialect>(); 223 } 224 StringRef getArgument() const final { return "test-return-type"; } 225 StringRef getDescription() const final { return "Run return type functions"; } 226 227 void runOnOperation() override { 228 if (getOperation().getName() == "testCreateFunctions") { 229 std::vector<Operation *> ops; 230 // Collect ops to avoid triggering on inserted ops. 231 for (auto &op : getOperation().getBody().front()) 232 ops.push_back(&op); 233 // Generate test patterns for each, but skip terminator. 234 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 235 // Test create method of each of the Op classes below. The resultant 236 // output would be in reverse order underneath `op` from which 237 // the attributes and regions are used. 238 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 239 invokeCreateWithInferredReturnType< 240 OpWithShapedTypeInferTypeInterfaceOp>(op); 241 }; 242 return; 243 } 244 if (getOperation().getName() == "testReifyFunctions") { 245 std::vector<Operation *> ops; 246 // Collect ops to avoid triggering on inserted ops. 247 for (auto &op : getOperation().getBody().front()) 248 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 249 ops.push_back(&op); 250 // Generate test patterns for each, but skip terminator. 
251 for (auto *op : ops) 252 reifyReturnShape(op); 253 } 254 } 255 }; 256 } // namespace 257 258 namespace { 259 struct TestDerivedAttributeDriver 260 : public PassWrapper<TestDerivedAttributeDriver, OperationPass<FuncOp>> { 261 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 262 263 StringRef getArgument() const final { return "test-derived-attr"; } 264 StringRef getDescription() const final { 265 return "Run test derived attributes"; 266 } 267 void runOnOperation() override; 268 }; 269 } // namespace 270 271 void TestDerivedAttributeDriver::runOnOperation() { 272 getOperation().walk([](DerivedAttributeOpInterface dOp) { 273 auto dAttr = dOp.materializeDerivedAttributes(); 274 if (!dAttr) 275 return; 276 for (auto d : dAttr) 277 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 278 }); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // Legalization Driver. 283 //===----------------------------------------------------------------------===// 284 285 namespace { 286 //===----------------------------------------------------------------------===// 287 // Region-Block Rewrite Testing 288 289 /// This pattern is a simple pattern that inlines the first region of a given 290 /// operation into the parent region. 291 struct TestRegionRewriteBlockMovement : public ConversionPattern { 292 TestRegionRewriteBlockMovement(MLIRContext *ctx) 293 : ConversionPattern("test.region", 1, ctx) {} 294 295 LogicalResult 296 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 297 ConversionPatternRewriter &rewriter) const final { 298 // Inline this region into the parent region. 
299 auto &parentRegion = *op->getParentRegion(); 300 auto &opRegion = op->getRegion(0); 301 if (op->getAttr("legalizer.should_clone")) 302 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 303 else 304 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 305 306 if (op->getAttr("legalizer.erase_old_blocks")) { 307 while (!opRegion.empty()) 308 rewriter.eraseBlock(&opRegion.front()); 309 } 310 311 // Drop this operation. 312 rewriter.eraseOp(op); 313 return success(); 314 } 315 }; 316 /// This pattern is a simple pattern that generates a region containing an 317 /// illegal operation. 318 struct TestRegionRewriteUndo : public RewritePattern { 319 TestRegionRewriteUndo(MLIRContext *ctx) 320 : RewritePattern("test.region_builder", 1, ctx) {} 321 322 LogicalResult matchAndRewrite(Operation *op, 323 PatternRewriter &rewriter) const final { 324 // Create the region operation with an entry block containing arguments. 325 OperationState newRegion(op->getLoc(), "test.region"); 326 newRegion.addRegion(); 327 auto *regionOp = rewriter.create(newRegion); 328 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 329 entryBlock->addArgument(rewriter.getIntegerType(64), 330 rewriter.getUnknownLoc()); 331 332 // Add an explicitly illegal operation to ensure the conversion fails. 333 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 334 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 335 336 // Drop this operation. 337 rewriter.eraseOp(op); 338 return success(); 339 } 340 }; 341 /// A simple pattern that creates a block at the end of the parent region of the 342 /// matched operation. 
343 struct TestCreateBlock : public RewritePattern { 344 TestCreateBlock(MLIRContext *ctx) 345 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 346 347 LogicalResult matchAndRewrite(Operation *op, 348 PatternRewriter &rewriter) const final { 349 Region ®ion = *op->getParentRegion(); 350 Type i32Type = rewriter.getIntegerType(32); 351 Location loc = op->getLoc(); 352 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 353 rewriter.create<TerminatorOp>(loc); 354 rewriter.replaceOp(op, {}); 355 return success(); 356 } 357 }; 358 359 /// A simple pattern that creates a block containing an invalid operation in 360 /// order to trigger the block creation undo mechanism. 361 struct TestCreateIllegalBlock : public RewritePattern { 362 TestCreateIllegalBlock(MLIRContext *ctx) 363 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 364 365 LogicalResult matchAndRewrite(Operation *op, 366 PatternRewriter &rewriter) const final { 367 Region ®ion = *op->getParentRegion(); 368 Type i32Type = rewriter.getIntegerType(32); 369 Location loc = op->getLoc(); 370 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 371 // Create an illegal op to ensure the conversion fails. 372 rewriter.create<ILLegalOpF>(loc, i32Type); 373 rewriter.create<TerminatorOp>(loc); 374 rewriter.replaceOp(op, {}); 375 return success(); 376 } 377 }; 378 379 /// A simple pattern that tests the undo mechanism when replacing the uses of a 380 /// block argument. 
381 struct TestUndoBlockArgReplace : public ConversionPattern { 382 TestUndoBlockArgReplace(MLIRContext *ctx) 383 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 384 385 LogicalResult 386 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 387 ConversionPatternRewriter &rewriter) const final { 388 auto illegalOp = 389 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 390 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 391 illegalOp); 392 rewriter.updateRootInPlace(op, [] {}); 393 return success(); 394 } 395 }; 396 397 /// A rewrite pattern that tests the undo mechanism when erasing a block. 398 struct TestUndoBlockErase : public ConversionPattern { 399 TestUndoBlockErase(MLIRContext *ctx) 400 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 401 402 LogicalResult 403 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 404 ConversionPatternRewriter &rewriter) const final { 405 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 406 rewriter.setInsertionPointToStart(secondBlock); 407 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 408 rewriter.eraseBlock(secondBlock); 409 rewriter.updateRootInPlace(op, [] {}); 410 return success(); 411 } 412 }; 413 414 //===----------------------------------------------------------------------===// 415 // Type-Conversion Rewrite Testing 416 417 /// This patterns erases a region operation that has had a type conversion. 418 struct TestDropOpSignatureConversion : public ConversionPattern { 419 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 420 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 421 LogicalResult 422 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423 ConversionPatternRewriter &rewriter) const override { 424 Region ®ion = op->getRegion(0); 425 Block *entry = ®ion.front(); 426 427 // Convert the original entry arguments. 
428 TypeConverter &converter = *getTypeConverter(); 429 TypeConverter::SignatureConversion result(entry->getNumArguments()); 430 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 431 result)) || 432 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 433 return failure(); 434 435 // Convert the region signature and just drop the operation. 436 rewriter.eraseOp(op); 437 return success(); 438 } 439 }; 440 /// This pattern simply updates the operands of the given operation. 441 struct TestPassthroughInvalidOp : public ConversionPattern { 442 TestPassthroughInvalidOp(MLIRContext *ctx) 443 : ConversionPattern("test.invalid", 1, ctx) {} 444 LogicalResult 445 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 446 ConversionPatternRewriter &rewriter) const final { 447 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 448 llvm::None); 449 return success(); 450 } 451 }; 452 /// This pattern handles the case of a split return value. 453 struct TestSplitReturnType : public ConversionPattern { 454 TestSplitReturnType(MLIRContext *ctx) 455 : ConversionPattern("test.return", 1, ctx) {} 456 LogicalResult 457 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 458 ConversionPatternRewriter &rewriter) const final { 459 // Check for a return of F32. 460 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 461 return failure(); 462 463 // Check if the first operation is a cast operation, if it is we use the 464 // results directly. 465 auto *defOp = operands[0].getDefiningOp(); 466 if (auto packerOp = 467 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 468 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 469 return success(); 470 } 471 472 // Otherwise, fail to match. 
473 return failure(); 474 } 475 }; 476 477 //===----------------------------------------------------------------------===// 478 // Multi-Level Type-Conversion Rewrite Testing 479 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 480 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 481 : ConversionPattern("test.type_producer", 1, ctx) {} 482 LogicalResult 483 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 484 ConversionPatternRewriter &rewriter) const final { 485 // If the type is I32, change the type to F32. 486 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 487 return failure(); 488 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 489 return success(); 490 } 491 }; 492 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 493 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 494 : ConversionPattern("test.type_producer", 1, ctx) {} 495 LogicalResult 496 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 497 ConversionPatternRewriter &rewriter) const final { 498 // If the type is F32, change the type to F64. 499 if (!Type(*op->result_type_begin()).isF32()) 500 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 501 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 502 return success(); 503 } 504 }; 505 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 506 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 507 : ConversionPattern("test.type_producer", 10, ctx) {} 508 LogicalResult 509 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 510 ConversionPatternRewriter &rewriter) const final { 511 // Always convert to B16, even though it is not a legal type. This tests 512 // that values are unmapped correctly. 
513 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 514 return success(); 515 } 516 }; 517 struct TestUpdateConsumerType : public ConversionPattern { 518 TestUpdateConsumerType(MLIRContext *ctx) 519 : ConversionPattern("test.type_consumer", 1, ctx) {} 520 LogicalResult 521 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 522 ConversionPatternRewriter &rewriter) const final { 523 // Verify that the incoming operand has been successfully remapped to F64. 524 if (!operands[0].getType().isF64()) 525 return failure(); 526 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 527 return success(); 528 } 529 }; 530 531 //===----------------------------------------------------------------------===// 532 // Non-Root Replacement Rewrite Testing 533 /// This pattern generates an invalid operation, but replaces it before the 534 /// pattern is finished. This checks that we don't need to legalize the 535 /// temporary op. 536 struct TestNonRootReplacement : public RewritePattern { 537 TestNonRootReplacement(MLIRContext *ctx) 538 : RewritePattern("test.replace_non_root", 1, ctx) {} 539 540 LogicalResult matchAndRewrite(Operation *op, 541 PatternRewriter &rewriter) const final { 542 auto resultType = *op->result_type_begin(); 543 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 544 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 545 546 rewriter.replaceOp(illegalOp, {legalOp}); 547 rewriter.replaceOp(op, {illegalOp}); 548 return success(); 549 } 550 }; 551 552 //===----------------------------------------------------------------------===// 553 // Recursive Rewrite Testing 554 /// This pattern is applied to the same operation multiple times, but has a 555 /// bounded recursion. 
556 struct TestBoundedRecursiveRewrite 557 : public OpRewritePattern<TestRecursiveRewriteOp> { 558 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 559 560 void initialize() { 561 // The conversion target handles bounding the recursion of this pattern. 562 setHasBoundedRewriteRecursion(); 563 } 564 565 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 566 PatternRewriter &rewriter) const final { 567 // Decrement the depth of the op in-place. 568 rewriter.updateRootInPlace(op, [&] { 569 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 570 }); 571 return success(); 572 } 573 }; 574 575 struct TestNestedOpCreationUndoRewrite 576 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 577 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 578 579 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 580 PatternRewriter &rewriter) const final { 581 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 582 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 583 return success(); 584 }; 585 }; 586 587 // This pattern matches `test.blackhole` and delete this op and its producer. 588 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 589 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 590 591 LogicalResult matchAndRewrite(BlackHoleOp op, 592 PatternRewriter &rewriter) const final { 593 Operation *producer = op.getOperand().getDefiningOp(); 594 // Always erase the user before the producer, the framework should handle 595 // this correctly. 596 rewriter.eraseOp(op); 597 rewriter.eraseOp(producer); 598 return success(); 599 }; 600 }; 601 602 // This pattern replaces explicitly illegal op with explicitly legal op, 603 // but in addition creates unregistered operation. 
604 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 605 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 606 607 LogicalResult matchAndRewrite(ILLegalOpG op, 608 PatternRewriter &rewriter) const final { 609 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 610 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 611 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 612 return success(); 613 }; 614 }; 615 } // namespace 616 617 namespace { 618 struct TestTypeConverter : public TypeConverter { 619 using TypeConverter::TypeConverter; 620 TestTypeConverter() { 621 addConversion(convertType); 622 addArgumentMaterialization(materializeCast); 623 addSourceMaterialization(materializeCast); 624 } 625 626 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 627 // Drop I16 types. 628 if (t.isSignlessInteger(16)) 629 return success(); 630 631 // Convert I64 to F64. 632 if (t.isSignlessInteger(64)) { 633 results.push_back(FloatType::getF64(t.getContext())); 634 return success(); 635 } 636 637 // Convert I42 to I43. 638 if (t.isInteger(42)) { 639 results.push_back(IntegerType::get(t.getContext(), 43)); 640 return success(); 641 } 642 643 // Split F32 into F16,F16. 644 if (t.isF32()) { 645 results.assign(2, FloatType::getF16(t.getContext())); 646 return success(); 647 } 648 649 // Otherwise, convert the type directly. 650 results.push_back(t); 651 return success(); 652 } 653 654 /// Hook for materializing a conversion. This is necessary because we generate 655 /// 1->N type mappings. 
656 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 657 ValueRange inputs, Location loc) { 658 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 659 } 660 }; 661 662 struct TestLegalizePatternDriver 663 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 664 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 665 666 StringRef getArgument() const final { return "test-legalize-patterns"; } 667 StringRef getDescription() const final { 668 return "Run test dialect legalization patterns"; 669 } 670 /// The mode of conversion to use with the driver. 671 enum class ConversionMode { Analysis, Full, Partial }; 672 673 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 674 675 void getDependentDialects(DialectRegistry ®istry) const override { 676 registry.insert<func::FuncDialect>(); 677 } 678 679 void runOnOperation() override { 680 TestTypeConverter converter; 681 mlir::RewritePatternSet patterns(&getContext()); 682 populateWithGenerated(patterns); 683 patterns 684 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 685 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 686 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 687 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 688 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 689 TestNonRootReplacement, TestBoundedRecursiveRewrite, 690 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 691 TestCreateUnregisteredOp>(&getContext()); 692 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 693 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 694 converter); 695 mlir::populateCallOpTypeConversionPattern(patterns, converter); 696 697 // Define the conversion target used for the test. 
698 ConversionTarget target(getContext()); 699 target.addLegalOp<ModuleOp>(); 700 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 701 TerminatorOp>(); 702 target 703 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 704 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 705 // Don't allow F32 operands. 706 return llvm::none_of(op.getOperandTypes(), 707 [](Type type) { return type.isF32(); }); 708 }); 709 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 710 return converter.isSignatureLegal(op.getFunctionType()) && 711 converter.isLegal(&op.getBody()); 712 }); 713 target.addDynamicallyLegalOp<func::CallOp>( 714 [&](func::CallOp op) { return converter.isLegal(op); }); 715 716 // TestCreateUnregisteredOp creates `arith.constant` operation, 717 // which was not added to target intentionally to test 718 // correct error code from conversion driver. 719 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 720 721 // Expect the type_producer/type_consumer operations to only operate on f64. 722 target.addDynamicallyLegalOp<TestTypeProducerOp>( 723 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 724 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 725 return op.getOperand().getType().isF64(); 726 }); 727 728 // Check support for marking certain operations as recursively legal. 729 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 730 return static_cast<bool>( 731 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 732 }); 733 734 // Mark the bound recursion operation as dynamically legal. 735 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 736 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 737 738 // Handle a partial conversion. 
739 if (mode == ConversionMode::Partial) { 740 DenseSet<Operation *> unlegalizedOps; 741 if (failed(applyPartialConversion( 742 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 743 getOperation()->emitRemark() << "applyPartialConversion failed"; 744 } 745 // Emit remarks for each legalizable operation. 746 for (auto *op : unlegalizedOps) 747 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 748 return; 749 } 750 751 // Handle a full conversion. 752 if (mode == ConversionMode::Full) { 753 // Check support for marking unknown operations as dynamically legal. 754 target.markUnknownOpDynamicallyLegal([](Operation *op) { 755 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 756 }); 757 758 if (failed(applyFullConversion(getOperation(), target, 759 std::move(patterns)))) { 760 getOperation()->emitRemark() << "applyFullConversion failed"; 761 } 762 return; 763 } 764 765 // Otherwise, handle an analysis conversion. 766 assert(mode == ConversionMode::Analysis); 767 768 // Analyze the convertible operations. 769 DenseSet<Operation *> legalizedOps; 770 if (failed(applyAnalysisConversion(getOperation(), target, 771 std::move(patterns), legalizedOps))) 772 return signalPassFailure(); 773 774 // Emit remarks for each legalizable operation. 775 for (auto *op : legalizedOps) 776 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 777 } 778 779 /// The mode of conversion to use. 
780 ConversionMode mode; 781 }; 782 } // namespace 783 784 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 785 legalizerConversionMode( 786 "test-legalize-mode", 787 llvm::cl::desc("The legalization mode to use with the test driver"), 788 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 789 llvm::cl::values( 790 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 791 "analysis", "Perform an analysis conversion"), 792 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 793 "Perform a full conversion"), 794 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 795 "partial", "Perform a partial conversion"))); 796 797 //===----------------------------------------------------------------------===// 798 // ConversionPatternRewriter::getRemappedValue testing. This method is used 799 // to get the remapped value of an original value that was replaced using 800 // ConversionPatternRewriter. 801 namespace { 802 struct TestRemapValueTypeConverter : public TypeConverter { 803 using TypeConverter::TypeConverter; 804 805 TestRemapValueTypeConverter() { 806 addConversion( 807 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 808 addConversion([](Type type) { return type; }); 809 } 810 }; 811 812 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 813 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 814 /// operand twice. 
815 /// 816 /// Example: 817 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 818 /// is replaced with: 819 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 820 struct OneVResOneVOperandOp1Converter 821 : public OpConversionPattern<OneVResOneVOperandOp1> { 822 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 823 824 LogicalResult 825 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 826 ConversionPatternRewriter &rewriter) const override { 827 auto origOps = op.getOperands(); 828 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 829 "One operand expected"); 830 Value origOp = *origOps.begin(); 831 SmallVector<Value, 2> remappedOperands; 832 // Replicate the remapped original operand twice. Note that we don't used 833 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 834 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 835 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 836 837 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 838 remappedOperands); 839 return success(); 840 } 841 }; 842 843 /// A rewriter pattern that tests that blocks can be merged. 844 struct TestRemapValueInRegion 845 : public OpConversionPattern<TestRemappedValueRegionOp> { 846 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 847 848 LogicalResult 849 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 850 ConversionPatternRewriter &rewriter) const final { 851 Block &block = op.getBody().front(); 852 Operation *terminator = block.getTerminator(); 853 854 // Merge the block into the parent region. 
855 Block *parentBlock = op->getBlock(); 856 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 857 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 858 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 859 860 // Replace the results of this operation with the remapped terminator 861 // values. 862 SmallVector<Value> terminatorOperands; 863 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 864 terminatorOperands))) 865 return failure(); 866 867 rewriter.eraseOp(terminator); 868 rewriter.replaceOp(op, terminatorOperands); 869 return success(); 870 } 871 }; 872 873 struct TestRemappedValue 874 : public mlir::PassWrapper<TestRemappedValue, OperationPass<FuncOp>> { 875 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 876 877 StringRef getArgument() const final { return "test-remapped-value"; } 878 StringRef getDescription() const final { 879 return "Test public remapped value mechanism in ConversionPatternRewriter"; 880 } 881 void runOnOperation() override { 882 TestRemapValueTypeConverter typeConverter; 883 884 mlir::RewritePatternSet patterns(&getContext()); 885 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 886 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 887 &getContext()); 888 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 889 890 mlir::ConversionTarget target(getContext()); 891 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 892 893 // Expect the type_producer/type_consumer operations to only operate on f64. 894 target.addDynamicallyLegalOp<TestTypeProducerOp>( 895 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 896 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 897 return op.getOperand().getType().isF64(); 898 }); 899 900 // We make OneVResOneVOperandOp1 legal only when it has more that one 901 // operand. 
This will trigger the conversion that will replace one-operand 902 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 903 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 904 [](Operation *op) { return op->getNumOperands() > 1; }); 905 906 if (failed(mlir::applyFullConversion(getOperation(), target, 907 std::move(patterns)))) { 908 signalPassFailure(); 909 } 910 } 911 }; 912 } // namespace 913 914 //===----------------------------------------------------------------------===// 915 // Test patterns without a specific root operation kind 916 //===----------------------------------------------------------------------===// 917 918 namespace { 919 /// This pattern matches and removes any operation in the test dialect. 920 struct RemoveTestDialectOps : public RewritePattern { 921 RemoveTestDialectOps(MLIRContext *context) 922 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 923 924 LogicalResult matchAndRewrite(Operation *op, 925 PatternRewriter &rewriter) const override { 926 if (!isa<TestDialect>(op->getDialect())) 927 return failure(); 928 rewriter.eraseOp(op); 929 return success(); 930 } 931 }; 932 933 struct TestUnknownRootOpDriver 934 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<FuncOp>> { 935 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 936 937 StringRef getArgument() const final { 938 return "test-legalize-unknown-root-patterns"; 939 } 940 StringRef getDescription() const final { 941 return "Test public remapped value mechanism in ConversionPatternRewriter"; 942 } 943 void runOnOperation() override { 944 mlir::RewritePatternSet patterns(&getContext()); 945 patterns.add<RemoveTestDialectOps>(&getContext()); 946 947 mlir::ConversionTarget target(getContext()); 948 target.addIllegalDialect<TestDialect>(); 949 if (failed(applyPartialConversion(getOperation(), target, 950 std::move(patterns)))) 951 signalPassFailure(); 952 } 953 }; 954 } // namespace 955 956 
//===----------------------------------------------------------------------===// 957 // Test type conversions 958 //===----------------------------------------------------------------------===// 959 960 namespace { 961 struct TestTypeConversionProducer 962 : public OpConversionPattern<TestTypeProducerOp> { 963 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 964 LogicalResult 965 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 966 ConversionPatternRewriter &rewriter) const final { 967 Type resultType = op.getType(); 968 Type convertedType = getTypeConverter() 969 ? getTypeConverter()->convertType(resultType) 970 : resultType; 971 if (resultType.isa<FloatType>()) 972 resultType = rewriter.getF64Type(); 973 else if (resultType.isInteger(16)) 974 resultType = rewriter.getIntegerType(64); 975 else if (resultType.isa<test::TestRecursiveType>() && 976 convertedType != resultType) 977 resultType = convertedType; 978 else 979 return failure(); 980 981 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 982 return success(); 983 } 984 }; 985 986 /// Call signature conversion and then fail the rewrite to trigger the undo 987 /// mechanism. 988 struct TestSignatureConversionUndo 989 : public OpConversionPattern<TestSignatureConversionUndoOp> { 990 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 991 992 LogicalResult 993 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 994 ConversionPatternRewriter &rewriter) const final { 995 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 996 return failure(); 997 } 998 }; 999 1000 /// Call signature conversion without providing a type converter to handle 1001 /// materializations. 
1002 struct TestTestSignatureConversionNoConverter 1003 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1004 TestTestSignatureConversionNoConverter(TypeConverter &converter, 1005 MLIRContext *context) 1006 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1007 converter(converter) {} 1008 1009 LogicalResult 1010 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1011 ConversionPatternRewriter &rewriter) const final { 1012 Region ®ion = op->getRegion(0); 1013 Block *entry = ®ion.front(); 1014 1015 // Convert the original entry arguments. 1016 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1017 if (failed( 1018 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1019 return failure(); 1020 rewriter.updateRootInPlace( 1021 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1022 return success(); 1023 } 1024 1025 TypeConverter &converter; 1026 }; 1027 1028 /// Just forward the operands to the root op. This is essentially a no-op 1029 /// pattern that is used to trigger target materialization. 
1030 struct TestTypeConsumerForward 1031 : public OpConversionPattern<TestTypeConsumerOp> { 1032 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1033 1034 LogicalResult 1035 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1036 ConversionPatternRewriter &rewriter) const final { 1037 rewriter.updateRootInPlace(op, 1038 [&] { op->setOperands(adaptor.getOperands()); }); 1039 return success(); 1040 } 1041 }; 1042 1043 struct TestTypeConversionAnotherProducer 1044 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1045 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1046 1047 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1048 PatternRewriter &rewriter) const final { 1049 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1050 return success(); 1051 } 1052 }; 1053 1054 struct TestTypeConversionDriver 1055 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1056 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1057 1058 void getDependentDialects(DialectRegistry ®istry) const override { 1059 registry.insert<TestDialect>(); 1060 } 1061 StringRef getArgument() const final { 1062 return "test-legalize-type-conversion"; 1063 } 1064 StringRef getDescription() const final { 1065 return "Test various type conversion functionalities in DialectConversion"; 1066 } 1067 1068 void runOnOperation() override { 1069 // Initialize the type converter. 1070 TypeConverter converter; 1071 1072 /// Add the legal set of type conversions. 1073 converter.addConversion([](Type type) -> Type { 1074 // Treat F64 as legal. 1075 if (type.isF64()) 1076 return type; 1077 // Allow converting BF16/F16/F32 to F64. 1078 if (type.isBF16() || type.isF16() || type.isF32()) 1079 return FloatType::getF64(type.getContext()); 1080 // Otherwise, the type is illegal. 
1081 return nullptr; 1082 }); 1083 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1084 // Drop all integer types. 1085 return success(); 1086 }); 1087 converter.addConversion( 1088 // Convert a recursive self-referring type into a non-self-referring 1089 // type named "outer_converted_type" that contains a SimpleAType. 1090 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1091 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1092 // If the type is already converted, return it to indicate that it is 1093 // legal. 1094 if (type.getName() == "outer_converted_type") { 1095 results.push_back(type); 1096 return success(); 1097 } 1098 1099 // If the type is on the call stack more than once (it is there at 1100 // least once because of the _current_ call, which is always the last 1101 // element on the stack), we've hit the recursive case. Just return 1102 // SimpleAType here to create a non-recursive type as a result. 1103 if (llvm::is_contained(callStack.drop_back(), type)) { 1104 results.push_back(test::SimpleAType::get(type.getContext())); 1105 return success(); 1106 } 1107 1108 // Convert the body recursively. 1109 auto result = test::TestRecursiveType::get(type.getContext(), 1110 "outer_converted_type"); 1111 if (failed(result.setBody(converter.convertType(type.getBody())))) 1112 return failure(); 1113 results.push_back(result); 1114 return success(); 1115 }); 1116 1117 /// Add the legal set of type materializations. 1118 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1119 ValueRange inputs, 1120 Location loc) -> Value { 1121 // Allow casting from F64 back to F32. 1122 if (!resultType.isF16() && inputs.size() == 1 && 1123 inputs[0].getType().isF64()) 1124 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1125 // Allow producing an i32 or i64 from nothing. 
1126 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1127 inputs.empty()) 1128 return builder.create<TestTypeProducerOp>(loc, resultType); 1129 // Allow producing an i64 from an integer. 1130 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1131 inputs[0].getType().isa<IntegerType>()) 1132 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1133 // Otherwise, fail. 1134 return nullptr; 1135 }); 1136 1137 // Initialize the conversion target. 1138 mlir::ConversionTarget target(getContext()); 1139 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1140 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1141 return op.getType().isF64() || op.getType().isInteger(64) || 1142 (recursiveType && 1143 recursiveType.getName() == "outer_converted_type"); 1144 }); 1145 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 1146 return converter.isSignatureLegal(op.getFunctionType()) && 1147 converter.isLegal(&op.getBody()); 1148 }); 1149 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1150 // Allow casts from F64 to F32. 1151 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1152 }); 1153 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1154 [&](TestSignatureConversionNoConverterOp op) { 1155 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1156 }); 1157 1158 // Initialize the set of rewrite patterns. 
1159 RewritePatternSet patterns(&getContext()); 1160 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1161 TestSignatureConversionUndo, 1162 TestTestSignatureConversionNoConverter>(converter, 1163 &getContext()); 1164 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1165 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 1166 converter); 1167 1168 if (failed(applyPartialConversion(getOperation(), target, 1169 std::move(patterns)))) 1170 signalPassFailure(); 1171 } 1172 }; 1173 } // namespace 1174 1175 //===----------------------------------------------------------------------===// 1176 // Test Target Materialization With No Uses 1177 //===----------------------------------------------------------------------===// 1178 1179 namespace { 1180 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1181 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1182 1183 LogicalResult 1184 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1185 ConversionPatternRewriter &rewriter) const final { 1186 rewriter.replaceOp(op, adaptor.getOperands()); 1187 return success(); 1188 } 1189 }; 1190 1191 struct TestTargetMaterializationWithNoUses 1192 : public PassWrapper<TestTargetMaterializationWithNoUses, 1193 OperationPass<ModuleOp>> { 1194 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1195 TestTargetMaterializationWithNoUses) 1196 1197 StringRef getArgument() const final { 1198 return "test-target-materialization-with-no-uses"; 1199 } 1200 StringRef getDescription() const final { 1201 return "Test a special case of target materialization in DialectConversion"; 1202 } 1203 1204 void runOnOperation() override { 1205 TypeConverter converter; 1206 converter.addConversion([](Type t) { return t; }); 1207 converter.addConversion([](IntegerType intTy) -> Type { 1208 if (intTy.getWidth() == 16) 1209 return IntegerType::get(intTy.getContext(), 64); 1210 return intTy; 1211 }); 1212 
converter.addTargetMaterialization( 1213 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1214 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1215 }); 1216 1217 ConversionTarget target(getContext()); 1218 target.addIllegalOp<TestTypeChangerOp>(); 1219 1220 RewritePatternSet patterns(&getContext()); 1221 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1222 1223 if (failed(applyPartialConversion(getOperation(), target, 1224 std::move(patterns)))) 1225 signalPassFailure(); 1226 } 1227 }; 1228 } // namespace 1229 1230 //===----------------------------------------------------------------------===// 1231 // Test Block Merging 1232 //===----------------------------------------------------------------------===// 1233 1234 namespace { 1235 /// A rewriter pattern that tests that blocks can be merged. 1236 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1237 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1238 1239 LogicalResult 1240 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1241 ConversionPatternRewriter &rewriter) const final { 1242 Block &firstBlock = op.getBody().front(); 1243 Operation *branchOp = firstBlock.getTerminator(); 1244 Block *secondBlock = &*(std::next(op.getBody().begin())); 1245 auto succOperands = branchOp->getOperands(); 1246 SmallVector<Value, 2> replacements(succOperands); 1247 rewriter.eraseOp(branchOp); 1248 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1249 rewriter.updateRootInPlace(op, [] {}); 1250 return success(); 1251 } 1252 }; 1253 1254 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 
1255 struct TestUndoBlocksMerge : public ConversionPattern { 1256 TestUndoBlocksMerge(MLIRContext *ctx) 1257 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1258 LogicalResult 1259 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1260 ConversionPatternRewriter &rewriter) const final { 1261 Block &firstBlock = op->getRegion(0).front(); 1262 Operation *branchOp = firstBlock.getTerminator(); 1263 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1264 rewriter.setInsertionPointToStart(secondBlock); 1265 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1266 auto succOperands = branchOp->getOperands(); 1267 SmallVector<Value, 2> replacements(succOperands); 1268 rewriter.eraseOp(branchOp); 1269 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1270 rewriter.updateRootInPlace(op, [] {}); 1271 return success(); 1272 } 1273 }; 1274 1275 /// A rewrite mechanism to inline the body of the op into its parent, when both 1276 /// ops can have a single block. 
1277 struct TestMergeSingleBlockOps 1278 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1279 using OpConversionPattern< 1280 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1281 1282 LogicalResult 1283 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1284 ConversionPatternRewriter &rewriter) const final { 1285 SingleBlockImplicitTerminatorOp parentOp = 1286 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1287 if (!parentOp) 1288 return failure(); 1289 Block &innerBlock = op.getRegion().front(); 1290 TerminatorOp innerTerminator = 1291 cast<TerminatorOp>(innerBlock.getTerminator()); 1292 rewriter.mergeBlockBefore(&innerBlock, op); 1293 rewriter.eraseOp(innerTerminator); 1294 rewriter.eraseOp(op); 1295 rewriter.updateRootInPlace(op, [] {}); 1296 return success(); 1297 } 1298 }; 1299 1300 struct TestMergeBlocksPatternDriver 1301 : public PassWrapper<TestMergeBlocksPatternDriver, 1302 OperationPass<ModuleOp>> { 1303 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1304 1305 StringRef getArgument() const final { return "test-merge-blocks"; } 1306 StringRef getDescription() const final { 1307 return "Test Merging operation in ConversionPatternRewriter"; 1308 } 1309 void runOnOperation() override { 1310 MLIRContext *context = &getContext(); 1311 mlir::RewritePatternSet patterns(context); 1312 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1313 context); 1314 ConversionTarget target(*context); 1315 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1316 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1317 target.addIllegalOp<ILLegalOpF>(); 1318 1319 /// Expect the op to have a single block after legalization. 1320 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1321 [&](TestMergeBlocksOp op) -> bool { 1322 return llvm::hasSingleElement(op.getBody()); 1323 }); 1324 1325 /// Only allow `test.br` within test.merge_blocks op. 
1326 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1327 return op->getParentOfType<TestMergeBlocksOp>(); 1328 }); 1329 1330 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1331 /// inlined. 1332 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1333 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1334 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1335 }); 1336 1337 DenseSet<Operation *> unlegalizedOps; 1338 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1339 &unlegalizedOps); 1340 for (auto *op : unlegalizedOps) 1341 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1342 } 1343 }; 1344 } // namespace 1345 1346 //===----------------------------------------------------------------------===// 1347 // Test Selective Replacement 1348 //===----------------------------------------------------------------------===// 1349 1350 namespace { 1351 /// A rewrite mechanism to inline the body of the op into its parent, when both 1352 /// ops can have a single block. 1353 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1354 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1355 1356 LogicalResult matchAndRewrite(TestCastOp op, 1357 PatternRewriter &rewriter) const final { 1358 if (op.getNumOperands() != 2) 1359 return failure(); 1360 OperandRange operands = op.getOperands(); 1361 1362 // Replace non-terminator uses with the first operand. 1363 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1364 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1365 }); 1366 // Replace everything else with the second operand if the operation isn't 1367 // dead. 
1368 rewriter.replaceOp(op, op.getOperand(1)); 1369 return success(); 1370 } 1371 }; 1372 1373 struct TestSelectiveReplacementPatternDriver 1374 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1375 OperationPass<>> { 1376 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1377 TestSelectiveReplacementPatternDriver) 1378 1379 StringRef getArgument() const final { 1380 return "test-pattern-selective-replacement"; 1381 } 1382 StringRef getDescription() const final { 1383 return "Test selective replacement in the PatternRewriter"; 1384 } 1385 void runOnOperation() override { 1386 MLIRContext *context = &getContext(); 1387 mlir::RewritePatternSet patterns(context); 1388 patterns.add<TestSelectiveOpReplacementPattern>(context); 1389 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1390 std::move(patterns)); 1391 } 1392 }; 1393 } // namespace 1394 1395 //===----------------------------------------------------------------------===// 1396 // PassRegistration 1397 //===----------------------------------------------------------------------===// 1398 1399 namespace mlir { 1400 namespace test { 1401 void registerPatternsTestPass() { 1402 PassRegistration<TestReturnTypeDriver>(); 1403 1404 PassRegistration<TestDerivedAttributeDriver>(); 1405 1406 PassRegistration<TestPatternDriver>(); 1407 1408 PassRegistration<TestLegalizePatternDriver>([] { 1409 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1410 }); 1411 1412 PassRegistration<TestRemappedValue>(); 1413 1414 PassRegistration<TestUnknownRootOpDriver>(); 1415 1416 PassRegistration<TestTypeConversionDriver>(); 1417 PassRegistration<TestTargetMaterializationWithNoUses>(); 1418 1419 PassRegistration<TestMergeBlocksPatternDriver>(); 1420 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1421 } 1422 } // namespace test 1423 } // namespace mlir 1424