1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 55 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 56 // This let us check the number of times OpM_Test was called by inspecting 57 // the returned value in the MLIR output. 58 static int64_t opMIncreasingValue = 314159265; 59 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 60 int64_t i = opMIncreasingValue++; 61 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 62 } 63 64 namespace { 65 #include "TestPatterns.inc" 66 } // namespace 67 68 //===----------------------------------------------------------------------===// 69 // Test Reduce Pattern Interface 70 //===----------------------------------------------------------------------===// 71 72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 73 populateWithGenerated(patterns); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Canonicalizer Driver. 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 struct FoldingPattern : public RewritePattern { 82 public: 83 FoldingPattern(MLIRContext *context) 84 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 85 /*benefit=*/1, context) {} 86 87 LogicalResult matchAndRewrite(Operation *op, 88 PatternRewriter &rewriter) const override { 89 // Exercise OperationFolder API for a single-result operation that is folded 90 // upon construction. The operation being created through the folder has an 91 // in-place folder, and it should be still present in the output. 92 // Furthermore, the folder should not crash when attempting to recover the 93 // (unchanged) operation result. 94 OperationFolder folder(op->getContext()); 95 Value result = folder.create<TestOpInPlaceFold>( 96 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 97 rewriter.getI32IntegerAttr(0)); 98 assert(result); 99 rewriter.replaceOp(op, result); 100 return success(); 101 } 102 }; 103 104 /// This pattern creates a foldable operation at the entry point of the block. 105 /// This tests the situation where the operation folder will need to replace an 106 /// operation with a previously created constant that does not initially 107 /// dominate the operation to replace. 108 struct FolderInsertBeforePreviouslyFoldedConstantPattern 109 : public OpRewritePattern<TestCastOp> { 110 public: 111 using OpRewritePattern<TestCastOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(TestCastOp op, 114 PatternRewriter &rewriter) const override { 115 if (!op->hasAttr("test_fold_before_previously_folded_op")) 116 return failure(); 117 rewriter.setInsertionPointToStart(op->getBlock()); 118 119 auto constOp = rewriter.create<arith::ConstantOp>( 120 op.getLoc(), rewriter.getBoolAttr(true)); 121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 122 Value(constOp)); 123 return success(); 124 } 125 }; 126 127 /// This pattern matches test.op_commutative2 with the first operand being 128 /// another test.op_commutative2 with a constant on the right side and fold it 129 /// away by propagating it as its result. This is intend to check that patterns 130 /// are applied after the commutative property moves constant to the right. 131 struct FolderCommutativeOp2WithConstant 132 : public OpRewritePattern<TestCommutative2Op> { 133 public: 134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 135 136 LogicalResult matchAndRewrite(TestCommutative2Op op, 137 PatternRewriter &rewriter) const override { 138 auto operand = 139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 140 if (!operand) 141 return failure(); 142 Attribute constInput; 143 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 144 return failure(); 145 rewriter.replaceOp(op, operand->getOperand(1)); 146 return success(); 147 } 148 }; 149 150 struct TestPatternDriver 151 : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> { 152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 153 154 StringRef getArgument() const final { return "test-patterns"; } 155 StringRef getDescription() const final { return "Run test dialect patterns"; } 156 void runOnOperation() override { 157 mlir::RewritePatternSet patterns(&getContext()); 158 populateWithGenerated(patterns); 159 160 // Verify named pattern is generated with expected name. 161 patterns.add<FoldingPattern, TestNamedPatternRule, 162 FolderInsertBeforePreviouslyFoldedConstantPattern, 163 FolderCommutativeOp2WithConstant>(&getContext()); 164 165 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 166 } 167 }; 168 } // namespace 169 170 //===----------------------------------------------------------------------===// 171 // ReturnType Driver. 172 //===----------------------------------------------------------------------===// 173 174 namespace { 175 // Generate ops for each instance where the type can be successfully inferred. 176 template <typename OpTy> 177 static void invokeCreateWithInferredReturnType(Operation *op) { 178 auto *context = op->getContext(); 179 auto fop = op->getParentOfType<func::FuncOp>(); 180 auto location = UnknownLoc::get(context); 181 OpBuilder b(op); 182 b.setInsertionPointAfter(op); 183 184 // Use permutations of 2 args as operands. 185 assert(fop.getNumArguments() >= 2); 186 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 187 for (int j = 0; j < e; ++j) { 188 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 189 SmallVector<Type, 2> inferredReturnTypes; 190 if (succeeded(OpTy::inferReturnTypes( 191 context, llvm::None, values, op->getAttrDictionary(), 192 op->getRegions(), inferredReturnTypes))) { 193 OperationState state(location, OpTy::getOperationName()); 194 // TODO: Expand to regions. 195 OpTy::build(b, state, values, op->getAttrs()); 196 (void)b.create(state); 197 } 198 } 199 } 200 } 201 202 static void reifyReturnShape(Operation *op) { 203 OpBuilder b(op); 204 205 // Use permutations of 2 args as operands. 206 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 207 SmallVector<Value, 2> shapes; 208 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 209 !llvm::hasSingleElement(shapes)) 210 return; 211 for (const auto &it : llvm::enumerate(shapes)) { 212 op->emitRemark() << "value " << it.index() << ": " 213 << it.value().getDefiningOp(); 214 } 215 } 216 217 struct TestReturnTypeDriver 218 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 219 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 220 221 void getDependentDialects(DialectRegistry ®istry) const override { 222 registry.insert<tensor::TensorDialect>(); 223 } 224 StringRef getArgument() const final { return "test-return-type"; } 225 StringRef getDescription() const final { return "Run return type functions"; } 226 227 void runOnOperation() override { 228 if (getOperation().getName() == "testCreateFunctions") { 229 std::vector<Operation *> ops; 230 // Collect ops to avoid triggering on inserted ops. 231 for (auto &op : getOperation().getBody().front()) 232 ops.push_back(&op); 233 // Generate test patterns for each, but skip terminator. 234 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 235 // Test create method of each of the Op classes below. The resultant 236 // output would be in reverse order underneath `op` from which 237 // the attributes and regions are used. 238 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 239 invokeCreateWithInferredReturnType< 240 OpWithShapedTypeInferTypeInterfaceOp>(op); 241 }; 242 return; 243 } 244 if (getOperation().getName() == "testReifyFunctions") { 245 std::vector<Operation *> ops; 246 // Collect ops to avoid triggering on inserted ops. 247 for (auto &op : getOperation().getBody().front()) 248 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 249 ops.push_back(&op); 250 // Generate test patterns for each, but skip terminator. 251 for (auto *op : ops) 252 reifyReturnShape(op); 253 } 254 } 255 }; 256 } // namespace 257 258 namespace { 259 struct TestDerivedAttributeDriver 260 : public PassWrapper<TestDerivedAttributeDriver, 261 OperationPass<func::FuncOp>> { 262 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 263 264 StringRef getArgument() const final { return "test-derived-attr"; } 265 StringRef getDescription() const final { 266 return "Run test derived attributes"; 267 } 268 void runOnOperation() override; 269 }; 270 } // namespace 271 272 void TestDerivedAttributeDriver::runOnOperation() { 273 getOperation().walk([](DerivedAttributeOpInterface dOp) { 274 auto dAttr = dOp.materializeDerivedAttributes(); 275 if (!dAttr) 276 return; 277 for (auto d : dAttr) 278 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 279 }); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // Legalization Driver. 284 //===----------------------------------------------------------------------===// 285 286 namespace { 287 //===----------------------------------------------------------------------===// 288 // Region-Block Rewrite Testing 289 290 /// This pattern is a simple pattern that inlines the first region of a given 291 /// operation into the parent region. 292 struct TestRegionRewriteBlockMovement : public ConversionPattern { 293 TestRegionRewriteBlockMovement(MLIRContext *ctx) 294 : ConversionPattern("test.region", 1, ctx) {} 295 296 LogicalResult 297 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 298 ConversionPatternRewriter &rewriter) const final { 299 // Inline this region into the parent region. 300 auto &parentRegion = *op->getParentRegion(); 301 auto &opRegion = op->getRegion(0); 302 if (op->getAttr("legalizer.should_clone")) 303 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 304 else 305 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 306 307 if (op->getAttr("legalizer.erase_old_blocks")) { 308 while (!opRegion.empty()) 309 rewriter.eraseBlock(&opRegion.front()); 310 } 311 312 // Drop this operation. 313 rewriter.eraseOp(op); 314 return success(); 315 } 316 }; 317 /// This pattern is a simple pattern that generates a region containing an 318 /// illegal operation. 319 struct TestRegionRewriteUndo : public RewritePattern { 320 TestRegionRewriteUndo(MLIRContext *ctx) 321 : RewritePattern("test.region_builder", 1, ctx) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const final { 325 // Create the region operation with an entry block containing arguments. 326 OperationState newRegion(op->getLoc(), "test.region"); 327 newRegion.addRegion(); 328 auto *regionOp = rewriter.create(newRegion); 329 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 330 entryBlock->addArgument(rewriter.getIntegerType(64), 331 rewriter.getUnknownLoc()); 332 333 // Add an explicitly illegal operation to ensure the conversion fails. 334 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 335 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 336 337 // Drop this operation. 338 rewriter.eraseOp(op); 339 return success(); 340 } 341 }; 342 /// A simple pattern that creates a block at the end of the parent region of the 343 /// matched operation. 344 struct TestCreateBlock : public RewritePattern { 345 TestCreateBlock(MLIRContext *ctx) 346 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 347 348 LogicalResult matchAndRewrite(Operation *op, 349 PatternRewriter &rewriter) const final { 350 Region ®ion = *op->getParentRegion(); 351 Type i32Type = rewriter.getIntegerType(32); 352 Location loc = op->getLoc(); 353 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 354 rewriter.create<TerminatorOp>(loc); 355 rewriter.replaceOp(op, {}); 356 return success(); 357 } 358 }; 359 360 /// A simple pattern that creates a block containing an invalid operation in 361 /// order to trigger the block creation undo mechanism. 362 struct TestCreateIllegalBlock : public RewritePattern { 363 TestCreateIllegalBlock(MLIRContext *ctx) 364 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 365 366 LogicalResult matchAndRewrite(Operation *op, 367 PatternRewriter &rewriter) const final { 368 Region ®ion = *op->getParentRegion(); 369 Type i32Type = rewriter.getIntegerType(32); 370 Location loc = op->getLoc(); 371 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 372 // Create an illegal op to ensure the conversion fails. 373 rewriter.create<ILLegalOpF>(loc, i32Type); 374 rewriter.create<TerminatorOp>(loc); 375 rewriter.replaceOp(op, {}); 376 return success(); 377 } 378 }; 379 380 /// A simple pattern that tests the undo mechanism when replacing the uses of a 381 /// block argument. 382 struct TestUndoBlockArgReplace : public ConversionPattern { 383 TestUndoBlockArgReplace(MLIRContext *ctx) 384 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 385 386 LogicalResult 387 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 388 ConversionPatternRewriter &rewriter) const final { 389 auto illegalOp = 390 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 391 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 392 illegalOp); 393 rewriter.updateRootInPlace(op, [] {}); 394 return success(); 395 } 396 }; 397 398 /// A rewrite pattern that tests the undo mechanism when erasing a block. 399 struct TestUndoBlockErase : public ConversionPattern { 400 TestUndoBlockErase(MLIRContext *ctx) 401 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 402 403 LogicalResult 404 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 405 ConversionPatternRewriter &rewriter) const final { 406 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 407 rewriter.setInsertionPointToStart(secondBlock); 408 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 409 rewriter.eraseBlock(secondBlock); 410 rewriter.updateRootInPlace(op, [] {}); 411 return success(); 412 } 413 }; 414 415 //===----------------------------------------------------------------------===// 416 // Type-Conversion Rewrite Testing 417 418 /// This patterns erases a region operation that has had a type conversion. 419 struct TestDropOpSignatureConversion : public ConversionPattern { 420 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 421 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 422 LogicalResult 423 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 424 ConversionPatternRewriter &rewriter) const override { 425 Region ®ion = op->getRegion(0); 426 Block *entry = ®ion.front(); 427 428 // Convert the original entry arguments. 429 TypeConverter &converter = *getTypeConverter(); 430 TypeConverter::SignatureConversion result(entry->getNumArguments()); 431 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 432 result)) || 433 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 434 return failure(); 435 436 // Convert the region signature and just drop the operation. 437 rewriter.eraseOp(op); 438 return success(); 439 } 440 }; 441 /// This pattern simply updates the operands of the given operation. 442 struct TestPassthroughInvalidOp : public ConversionPattern { 443 TestPassthroughInvalidOp(MLIRContext *ctx) 444 : ConversionPattern("test.invalid", 1, ctx) {} 445 LogicalResult 446 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 447 ConversionPatternRewriter &rewriter) const final { 448 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 449 llvm::None); 450 return success(); 451 } 452 }; 453 /// This pattern handles the case of a split return value. 454 struct TestSplitReturnType : public ConversionPattern { 455 TestSplitReturnType(MLIRContext *ctx) 456 : ConversionPattern("test.return", 1, ctx) {} 457 LogicalResult 458 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 459 ConversionPatternRewriter &rewriter) const final { 460 // Check for a return of F32. 461 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 462 return failure(); 463 464 // Check if the first operation is a cast operation, if it is we use the 465 // results directly. 466 auto *defOp = operands[0].getDefiningOp(); 467 if (auto packerOp = 468 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 469 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 470 return success(); 471 } 472 473 // Otherwise, fail to match. 474 return failure(); 475 } 476 }; 477 478 //===----------------------------------------------------------------------===// 479 // Multi-Level Type-Conversion Rewrite Testing 480 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 481 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 482 : ConversionPattern("test.type_producer", 1, ctx) {} 483 LogicalResult 484 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 485 ConversionPatternRewriter &rewriter) const final { 486 // If the type is I32, change the type to F32. 487 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 488 return failure(); 489 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 490 return success(); 491 } 492 }; 493 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 494 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 495 : ConversionPattern("test.type_producer", 1, ctx) {} 496 LogicalResult 497 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 498 ConversionPatternRewriter &rewriter) const final { 499 // If the type is F32, change the type to F64. 500 if (!Type(*op->result_type_begin()).isF32()) 501 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 502 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 503 return success(); 504 } 505 }; 506 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 507 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 508 : ConversionPattern("test.type_producer", 10, ctx) {} 509 LogicalResult 510 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 511 ConversionPatternRewriter &rewriter) const final { 512 // Always convert to B16, even though it is not a legal type. This tests 513 // that values are unmapped correctly. 514 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 515 return success(); 516 } 517 }; 518 struct TestUpdateConsumerType : public ConversionPattern { 519 TestUpdateConsumerType(MLIRContext *ctx) 520 : ConversionPattern("test.type_consumer", 1, ctx) {} 521 LogicalResult 522 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 523 ConversionPatternRewriter &rewriter) const final { 524 // Verify that the incoming operand has been successfully remapped to F64. 525 if (!operands[0].getType().isF64()) 526 return failure(); 527 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 528 return success(); 529 } 530 }; 531 532 //===----------------------------------------------------------------------===// 533 // Non-Root Replacement Rewrite Testing 534 /// This pattern generates an invalid operation, but replaces it before the 535 /// pattern is finished. This checks that we don't need to legalize the 536 /// temporary op. 537 struct TestNonRootReplacement : public RewritePattern { 538 TestNonRootReplacement(MLIRContext *ctx) 539 : RewritePattern("test.replace_non_root", 1, ctx) {} 540 541 LogicalResult matchAndRewrite(Operation *op, 542 PatternRewriter &rewriter) const final { 543 auto resultType = *op->result_type_begin(); 544 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 545 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 546 547 rewriter.replaceOp(illegalOp, {legalOp}); 548 rewriter.replaceOp(op, {illegalOp}); 549 return success(); 550 } 551 }; 552 553 //===----------------------------------------------------------------------===// 554 // Recursive Rewrite Testing 555 /// This pattern is applied to the same operation multiple times, but has a 556 /// bounded recursion. 557 struct TestBoundedRecursiveRewrite 558 : public OpRewritePattern<TestRecursiveRewriteOp> { 559 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 560 561 void initialize() { 562 // The conversion target handles bounding the recursion of this pattern. 563 setHasBoundedRewriteRecursion(); 564 } 565 566 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 567 PatternRewriter &rewriter) const final { 568 // Decrement the depth of the op in-place. 569 rewriter.updateRootInPlace(op, [&] { 570 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 571 }); 572 return success(); 573 } 574 }; 575 576 struct TestNestedOpCreationUndoRewrite 577 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 578 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 579 580 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 581 PatternRewriter &rewriter) const final { 582 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 583 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 584 return success(); 585 }; 586 }; 587 588 // This pattern matches `test.blackhole` and delete this op and its producer. 589 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 590 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 591 592 LogicalResult matchAndRewrite(BlackHoleOp op, 593 PatternRewriter &rewriter) const final { 594 Operation *producer = op.getOperand().getDefiningOp(); 595 // Always erase the user before the producer, the framework should handle 596 // this correctly. 597 rewriter.eraseOp(op); 598 rewriter.eraseOp(producer); 599 return success(); 600 }; 601 }; 602 603 // This pattern replaces explicitly illegal op with explicitly legal op, 604 // but in addition creates unregistered operation. 605 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 606 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 607 608 LogicalResult matchAndRewrite(ILLegalOpG op, 609 PatternRewriter &rewriter) const final { 610 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 611 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 612 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 613 return success(); 614 }; 615 }; 616 } // namespace 617 618 namespace { 619 struct TestTypeConverter : public TypeConverter { 620 using TypeConverter::TypeConverter; 621 TestTypeConverter() { 622 addConversion(convertType); 623 addArgumentMaterialization(materializeCast); 624 addSourceMaterialization(materializeCast); 625 } 626 627 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 628 // Drop I16 types. 629 if (t.isSignlessInteger(16)) 630 return success(); 631 632 // Convert I64 to F64. 633 if (t.isSignlessInteger(64)) { 634 results.push_back(FloatType::getF64(t.getContext())); 635 return success(); 636 } 637 638 // Convert I42 to I43. 639 if (t.isInteger(42)) { 640 results.push_back(IntegerType::get(t.getContext(), 43)); 641 return success(); 642 } 643 644 // Split F32 into F16,F16. 645 if (t.isF32()) { 646 results.assign(2, FloatType::getF16(t.getContext())); 647 return success(); 648 } 649 650 // Otherwise, convert the type directly. 651 results.push_back(t); 652 return success(); 653 } 654 655 /// Hook for materializing a conversion. This is necessary because we generate 656 /// 1->N type mappings. 657 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 658 ValueRange inputs, Location loc) { 659 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 660 } 661 }; 662 663 struct TestLegalizePatternDriver 664 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 665 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 666 667 StringRef getArgument() const final { return "test-legalize-patterns"; } 668 StringRef getDescription() const final { 669 return "Run test dialect legalization patterns"; 670 } 671 /// The mode of conversion to use with the driver. 672 enum class ConversionMode { Analysis, Full, Partial }; 673 674 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 675 676 void getDependentDialects(DialectRegistry ®istry) const override { 677 registry.insert<func::FuncDialect>(); 678 } 679 680 void runOnOperation() override { 681 TestTypeConverter converter; 682 mlir::RewritePatternSet patterns(&getContext()); 683 populateWithGenerated(patterns); 684 patterns 685 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 686 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 687 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 688 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 689 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 690 TestNonRootReplacement, TestBoundedRecursiveRewrite, 691 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 692 TestCreateUnregisteredOp>(&getContext()); 693 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 694 mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( 695 patterns, converter); 696 mlir::populateCallOpTypeConversionPattern(patterns, converter); 697 698 // Define the conversion target used for the test. 699 ConversionTarget target(getContext()); 700 target.addLegalOp<ModuleOp>(); 701 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 702 TerminatorOp>(); 703 target 704 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 705 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 706 // Don't allow F32 operands. 707 return llvm::none_of(op.getOperandTypes(), 708 [](Type type) { return type.isF32(); }); 709 }); 710 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 711 return converter.isSignatureLegal(op.getFunctionType()) && 712 converter.isLegal(&op.getBody()); 713 }); 714 target.addDynamicallyLegalOp<func::CallOp>( 715 [&](func::CallOp op) { return converter.isLegal(op); }); 716 717 // TestCreateUnregisteredOp creates `arith.constant` operation, 718 // which was not added to target intentionally to test 719 // correct error code from conversion driver. 720 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 721 722 // Expect the type_producer/type_consumer operations to only operate on f64. 723 target.addDynamicallyLegalOp<TestTypeProducerOp>( 724 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 725 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 726 return op.getOperand().getType().isF64(); 727 }); 728 729 // Check support for marking certain operations as recursively legal. 730 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 731 return static_cast<bool>( 732 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 733 }); 734 735 // Mark the bound recursion operation as dynamically legal. 736 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 737 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 738 739 // Handle a partial conversion. 740 if (mode == ConversionMode::Partial) { 741 DenseSet<Operation *> unlegalizedOps; 742 if (failed(applyPartialConversion( 743 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 744 getOperation()->emitRemark() << "applyPartialConversion failed"; 745 } 746 // Emit remarks for each legalizable operation. 747 for (auto *op : unlegalizedOps) 748 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 749 return; 750 } 751 752 // Handle a full conversion. 753 if (mode == ConversionMode::Full) { 754 // Check support for marking unknown operations as dynamically legal. 755 target.markUnknownOpDynamicallyLegal([](Operation *op) { 756 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 757 }); 758 759 if (failed(applyFullConversion(getOperation(), target, 760 std::move(patterns)))) { 761 getOperation()->emitRemark() << "applyFullConversion failed"; 762 } 763 return; 764 } 765 766 // Otherwise, handle an analysis conversion. 767 assert(mode == ConversionMode::Analysis); 768 769 // Analyze the convertible operations. 770 DenseSet<Operation *> legalizedOps; 771 if (failed(applyAnalysisConversion(getOperation(), target, 772 std::move(patterns), legalizedOps))) 773 return signalPassFailure(); 774 775 // Emit remarks for each legalizable operation. 776 for (auto *op : legalizedOps) 777 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 778 } 779 780 /// The mode of conversion to use. 781 ConversionMode mode; 782 }; 783 } // namespace 784 785 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 786 legalizerConversionMode( 787 "test-legalize-mode", 788 llvm::cl::desc("The legalization mode to use with the test driver"), 789 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 790 llvm::cl::values( 791 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 792 "analysis", "Perform an analysis conversion"), 793 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 794 "Perform a full conversion"), 795 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 796 "partial", "Perform a partial conversion"))); 797 798 //===----------------------------------------------------------------------===// 799 // ConversionPatternRewriter::getRemappedValue testing. This method is used 800 // to get the remapped value of an original value that was replaced using 801 // ConversionPatternRewriter. 802 namespace { 803 struct TestRemapValueTypeConverter : public TypeConverter { 804 using TypeConverter::TypeConverter; 805 806 TestRemapValueTypeConverter() { 807 addConversion( 808 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 809 addConversion([](Type type) { return type; }); 810 } 811 }; 812 813 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 814 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 815 /// operand twice. 816 /// 817 /// Example: 818 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 819 /// is replaced with: 820 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 821 struct OneVResOneVOperandOp1Converter 822 : public OpConversionPattern<OneVResOneVOperandOp1> { 823 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 824 825 LogicalResult 826 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 827 ConversionPatternRewriter &rewriter) const override { 828 auto origOps = op.getOperands(); 829 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 830 "One operand expected"); 831 Value origOp = *origOps.begin(); 832 SmallVector<Value, 2> remappedOperands; 833 // Replicate the remapped original operand twice. Note that we don't used 834 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 835 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 836 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 837 838 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 839 remappedOperands); 840 return success(); 841 } 842 }; 843 844 /// A rewriter pattern that tests that blocks can be merged. 845 struct TestRemapValueInRegion 846 : public OpConversionPattern<TestRemappedValueRegionOp> { 847 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 848 849 LogicalResult 850 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 851 ConversionPatternRewriter &rewriter) const final { 852 Block &block = op.getBody().front(); 853 Operation *terminator = block.getTerminator(); 854 855 // Merge the block into the parent region. 856 Block *parentBlock = op->getBlock(); 857 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 858 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 859 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 860 861 // Replace the results of this operation with the remapped terminator 862 // values. 863 SmallVector<Value> terminatorOperands; 864 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 865 terminatorOperands))) 866 return failure(); 867 868 rewriter.eraseOp(terminator); 869 rewriter.replaceOp(op, terminatorOperands); 870 return success(); 871 } 872 }; 873 874 struct TestRemappedValue 875 : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> { 876 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 877 878 StringRef getArgument() const final { return "test-remapped-value"; } 879 StringRef getDescription() const final { 880 return "Test public remapped value mechanism in ConversionPatternRewriter"; 881 } 882 void runOnOperation() override { 883 TestRemapValueTypeConverter typeConverter; 884 885 mlir::RewritePatternSet patterns(&getContext()); 886 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 887 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 888 &getContext()); 889 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 890 891 mlir::ConversionTarget target(getContext()); 892 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 893 894 // Expect the type_producer/type_consumer operations to only operate on f64. 895 target.addDynamicallyLegalOp<TestTypeProducerOp>( 896 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 897 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 898 return op.getOperand().getType().isF64(); 899 }); 900 901 // We make OneVResOneVOperandOp1 legal only when it has more that one 902 // operand. This will trigger the conversion that will replace one-operand 903 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 904 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 905 [](Operation *op) { return op->getNumOperands() > 1; }); 906 907 if (failed(mlir::applyFullConversion(getOperation(), target, 908 std::move(patterns)))) { 909 signalPassFailure(); 910 } 911 } 912 }; 913 } // namespace 914 915 //===----------------------------------------------------------------------===// 916 // Test patterns without a specific root operation kind 917 //===----------------------------------------------------------------------===// 918 919 namespace { 920 /// This pattern matches and removes any operation in the test dialect. 921 struct RemoveTestDialectOps : public RewritePattern { 922 RemoveTestDialectOps(MLIRContext *context) 923 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 924 925 LogicalResult matchAndRewrite(Operation *op, 926 PatternRewriter &rewriter) const override { 927 if (!isa<TestDialect>(op->getDialect())) 928 return failure(); 929 rewriter.eraseOp(op); 930 return success(); 931 } 932 }; 933 934 struct TestUnknownRootOpDriver 935 : public mlir::PassWrapper<TestUnknownRootOpDriver, 936 OperationPass<func::FuncOp>> { 937 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 938 939 StringRef getArgument() const final { 940 return "test-legalize-unknown-root-patterns"; 941 } 942 StringRef getDescription() const final { 943 return "Test public remapped value mechanism in ConversionPatternRewriter"; 944 } 945 void runOnOperation() override { 946 mlir::RewritePatternSet patterns(&getContext()); 947 patterns.add<RemoveTestDialectOps>(&getContext()); 948 949 mlir::ConversionTarget target(getContext()); 950 target.addIllegalDialect<TestDialect>(); 951 if (failed(applyPartialConversion(getOperation(), target, 952 std::move(patterns)))) 953 signalPassFailure(); 954 } 955 }; 956 } // namespace 957 958 //===----------------------------------------------------------------------===// 959 // Test patterns that uses operations and types defined at runtime 960 //===----------------------------------------------------------------------===// 961 962 namespace { 963 /// This pattern matches dynamic operations 'test.one_operand_two_results' and 964 /// replace them with dynamic operations 'test.generic_dynamic_op'. 965 struct RewriteDynamicOp : public RewritePattern { 966 RewriteDynamicOp(MLIRContext *context) 967 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 968 context) {} 969 970 LogicalResult matchAndRewrite(Operation *op, 971 PatternRewriter &rewriter) const override { 972 assert(op->getName().getStringRef() == 973 "test.dynamic_one_operand_two_results" && 974 "rewrite pattern should only match operations with the right name"); 975 976 OperationState state(op->getLoc(), "test.dynamic_generic", 977 op->getOperands(), op->getResultTypes(), 978 op->getAttrs()); 979 auto *newOp = rewriter.create(state); 980 rewriter.replaceOp(op, newOp->getResults()); 981 return success(); 982 } 983 }; 984 985 struct TestRewriteDynamicOpDriver 986 : public PassWrapper<TestRewriteDynamicOpDriver, 987 OperationPass<func::FuncOp>> { 988 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 989 990 void getDependentDialects(DialectRegistry ®istry) const override { 991 registry.insert<TestDialect>(); 992 } 993 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 994 StringRef getDescription() const final { 995 return "Test rewritting on dynamic operations"; 996 } 997 void runOnOperation() override { 998 RewritePatternSet patterns(&getContext()); 999 patterns.add<RewriteDynamicOp>(&getContext()); 1000 1001 ConversionTarget target(getContext()); 1002 target.addIllegalOp( 1003 OperationName("test.dynamic_one_operand_two_results", &getContext())); 1004 target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 1005 if (failed(applyPartialConversion(getOperation(), target, 1006 std::move(patterns)))) 1007 signalPassFailure(); 1008 } 1009 }; 1010 } // end anonymous namespace 1011 1012 //===----------------------------------------------------------------------===// 1013 // Test type conversions 1014 //===----------------------------------------------------------------------===// 1015 1016 namespace { 1017 struct TestTypeConversionProducer 1018 : public OpConversionPattern<TestTypeProducerOp> { 1019 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 1020 LogicalResult 1021 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 1022 ConversionPatternRewriter &rewriter) const final { 1023 Type resultType = op.getType(); 1024 Type convertedType = getTypeConverter() 1025 ? getTypeConverter()->convertType(resultType) 1026 : resultType; 1027 if (resultType.isa<FloatType>()) 1028 resultType = rewriter.getF64Type(); 1029 else if (resultType.isInteger(16)) 1030 resultType = rewriter.getIntegerType(64); 1031 else if (resultType.isa<test::TestRecursiveType>() && 1032 convertedType != resultType) 1033 resultType = convertedType; 1034 else 1035 return failure(); 1036 1037 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 1038 return success(); 1039 } 1040 }; 1041 1042 /// Call signature conversion and then fail the rewrite to trigger the undo 1043 /// mechanism. 1044 struct TestSignatureConversionUndo 1045 : public OpConversionPattern<TestSignatureConversionUndoOp> { 1046 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 1047 1048 LogicalResult 1049 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 1050 ConversionPatternRewriter &rewriter) const final { 1051 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 1052 return failure(); 1053 } 1054 }; 1055 1056 /// Call signature conversion without providing a type converter to handle 1057 /// materializations. 1058 struct TestTestSignatureConversionNoConverter 1059 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1060 TestTestSignatureConversionNoConverter(TypeConverter &converter, 1061 MLIRContext *context) 1062 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1063 converter(converter) {} 1064 1065 LogicalResult 1066 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1067 ConversionPatternRewriter &rewriter) const final { 1068 Region ®ion = op->getRegion(0); 1069 Block *entry = ®ion.front(); 1070 1071 // Convert the original entry arguments. 1072 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1073 if (failed( 1074 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1075 return failure(); 1076 rewriter.updateRootInPlace( 1077 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1078 return success(); 1079 } 1080 1081 TypeConverter &converter; 1082 }; 1083 1084 /// Just forward the operands to the root op. This is essentially a no-op 1085 /// pattern that is used to trigger target materialization. 1086 struct TestTypeConsumerForward 1087 : public OpConversionPattern<TestTypeConsumerOp> { 1088 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1089 1090 LogicalResult 1091 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1092 ConversionPatternRewriter &rewriter) const final { 1093 rewriter.updateRootInPlace(op, 1094 [&] { op->setOperands(adaptor.getOperands()); }); 1095 return success(); 1096 } 1097 }; 1098 1099 struct TestTypeConversionAnotherProducer 1100 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1101 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1102 1103 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1104 PatternRewriter &rewriter) const final { 1105 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1106 return success(); 1107 } 1108 }; 1109 1110 struct TestTypeConversionDriver 1111 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1112 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1113 1114 void getDependentDialects(DialectRegistry ®istry) const override { 1115 registry.insert<TestDialect>(); 1116 } 1117 StringRef getArgument() const final { 1118 return "test-legalize-type-conversion"; 1119 } 1120 StringRef getDescription() const final { 1121 return "Test various type conversion functionalities in DialectConversion"; 1122 } 1123 1124 void runOnOperation() override { 1125 // Initialize the type converter. 1126 TypeConverter converter; 1127 1128 /// Add the legal set of type conversions. 1129 converter.addConversion([](Type type) -> Type { 1130 // Treat F64 as legal. 1131 if (type.isF64()) 1132 return type; 1133 // Allow converting BF16/F16/F32 to F64. 1134 if (type.isBF16() || type.isF16() || type.isF32()) 1135 return FloatType::getF64(type.getContext()); 1136 // Otherwise, the type is illegal. 1137 return nullptr; 1138 }); 1139 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1140 // Drop all integer types. 1141 return success(); 1142 }); 1143 converter.addConversion( 1144 // Convert a recursive self-referring type into a non-self-referring 1145 // type named "outer_converted_type" that contains a SimpleAType. 1146 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1147 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1148 // If the type is already converted, return it to indicate that it is 1149 // legal. 1150 if (type.getName() == "outer_converted_type") { 1151 results.push_back(type); 1152 return success(); 1153 } 1154 1155 // If the type is on the call stack more than once (it is there at 1156 // least once because of the _current_ call, which is always the last 1157 // element on the stack), we've hit the recursive case. Just return 1158 // SimpleAType here to create a non-recursive type as a result. 1159 if (llvm::is_contained(callStack.drop_back(), type)) { 1160 results.push_back(test::SimpleAType::get(type.getContext())); 1161 return success(); 1162 } 1163 1164 // Convert the body recursively. 1165 auto result = test::TestRecursiveType::get(type.getContext(), 1166 "outer_converted_type"); 1167 if (failed(result.setBody(converter.convertType(type.getBody())))) 1168 return failure(); 1169 results.push_back(result); 1170 return success(); 1171 }); 1172 1173 /// Add the legal set of type materializations. 1174 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1175 ValueRange inputs, 1176 Location loc) -> Value { 1177 // Allow casting from F64 back to F32. 1178 if (!resultType.isF16() && inputs.size() == 1 && 1179 inputs[0].getType().isF64()) 1180 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1181 // Allow producing an i32 or i64 from nothing. 1182 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1183 inputs.empty()) 1184 return builder.create<TestTypeProducerOp>(loc, resultType); 1185 // Allow producing an i64 from an integer. 1186 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1187 inputs[0].getType().isa<IntegerType>()) 1188 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1189 // Otherwise, fail. 1190 return nullptr; 1191 }); 1192 1193 // Initialize the conversion target. 1194 mlir::ConversionTarget target(getContext()); 1195 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1196 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1197 return op.getType().isF64() || op.getType().isInteger(64) || 1198 (recursiveType && 1199 recursiveType.getName() == "outer_converted_type"); 1200 }); 1201 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1202 return converter.isSignatureLegal(op.getFunctionType()) && 1203 converter.isLegal(&op.getBody()); 1204 }); 1205 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1206 // Allow casts from F64 to F32. 1207 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1208 }); 1209 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1210 [&](TestSignatureConversionNoConverterOp op) { 1211 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1212 }); 1213 1214 // Initialize the set of rewrite patterns. 1215 RewritePatternSet patterns(&getContext()); 1216 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1217 TestSignatureConversionUndo, 1218 TestTestSignatureConversionNoConverter>(converter, 1219 &getContext()); 1220 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1221 mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( 1222 patterns, converter); 1223 1224 if (failed(applyPartialConversion(getOperation(), target, 1225 std::move(patterns)))) 1226 signalPassFailure(); 1227 } 1228 }; 1229 } // namespace 1230 1231 //===----------------------------------------------------------------------===// 1232 // Test Target Materialization With No Uses 1233 //===----------------------------------------------------------------------===// 1234 1235 namespace { 1236 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1237 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1238 1239 LogicalResult 1240 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1241 ConversionPatternRewriter &rewriter) const final { 1242 rewriter.replaceOp(op, adaptor.getOperands()); 1243 return success(); 1244 } 1245 }; 1246 1247 struct TestTargetMaterializationWithNoUses 1248 : public PassWrapper<TestTargetMaterializationWithNoUses, 1249 OperationPass<ModuleOp>> { 1250 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1251 TestTargetMaterializationWithNoUses) 1252 1253 StringRef getArgument() const final { 1254 return "test-target-materialization-with-no-uses"; 1255 } 1256 StringRef getDescription() const final { 1257 return "Test a special case of target materialization in DialectConversion"; 1258 } 1259 1260 void runOnOperation() override { 1261 TypeConverter converter; 1262 converter.addConversion([](Type t) { return t; }); 1263 converter.addConversion([](IntegerType intTy) -> Type { 1264 if (intTy.getWidth() == 16) 1265 return IntegerType::get(intTy.getContext(), 64); 1266 return intTy; 1267 }); 1268 converter.addTargetMaterialization( 1269 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1270 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1271 }); 1272 1273 ConversionTarget target(getContext()); 1274 target.addIllegalOp<TestTypeChangerOp>(); 1275 1276 RewritePatternSet patterns(&getContext()); 1277 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1278 1279 if (failed(applyPartialConversion(getOperation(), target, 1280 std::move(patterns)))) 1281 signalPassFailure(); 1282 } 1283 }; 1284 } // namespace 1285 1286 //===----------------------------------------------------------------------===// 1287 // Test Block Merging 1288 //===----------------------------------------------------------------------===// 1289 1290 namespace { 1291 /// A rewriter pattern that tests that blocks can be merged. 1292 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1293 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1294 1295 LogicalResult 1296 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1297 ConversionPatternRewriter &rewriter) const final { 1298 Block &firstBlock = op.getBody().front(); 1299 Operation *branchOp = firstBlock.getTerminator(); 1300 Block *secondBlock = &*(std::next(op.getBody().begin())); 1301 auto succOperands = branchOp->getOperands(); 1302 SmallVector<Value, 2> replacements(succOperands); 1303 rewriter.eraseOp(branchOp); 1304 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1305 rewriter.updateRootInPlace(op, [] {}); 1306 return success(); 1307 } 1308 }; 1309 1310 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1311 struct TestUndoBlocksMerge : public ConversionPattern { 1312 TestUndoBlocksMerge(MLIRContext *ctx) 1313 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1314 LogicalResult 1315 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1316 ConversionPatternRewriter &rewriter) const final { 1317 Block &firstBlock = op->getRegion(0).front(); 1318 Operation *branchOp = firstBlock.getTerminator(); 1319 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1320 rewriter.setInsertionPointToStart(secondBlock); 1321 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1322 auto succOperands = branchOp->getOperands(); 1323 SmallVector<Value, 2> replacements(succOperands); 1324 rewriter.eraseOp(branchOp); 1325 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1326 rewriter.updateRootInPlace(op, [] {}); 1327 return success(); 1328 } 1329 }; 1330 1331 /// A rewrite mechanism to inline the body of the op into its parent, when both 1332 /// ops can have a single block. 1333 struct TestMergeSingleBlockOps 1334 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1335 using OpConversionPattern< 1336 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1337 1338 LogicalResult 1339 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1340 ConversionPatternRewriter &rewriter) const final { 1341 SingleBlockImplicitTerminatorOp parentOp = 1342 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1343 if (!parentOp) 1344 return failure(); 1345 Block &innerBlock = op.getRegion().front(); 1346 TerminatorOp innerTerminator = 1347 cast<TerminatorOp>(innerBlock.getTerminator()); 1348 rewriter.mergeBlockBefore(&innerBlock, op); 1349 rewriter.eraseOp(innerTerminator); 1350 rewriter.eraseOp(op); 1351 rewriter.updateRootInPlace(op, [] {}); 1352 return success(); 1353 } 1354 }; 1355 1356 struct TestMergeBlocksPatternDriver 1357 : public PassWrapper<TestMergeBlocksPatternDriver, 1358 OperationPass<ModuleOp>> { 1359 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1360 1361 StringRef getArgument() const final { return "test-merge-blocks"; } 1362 StringRef getDescription() const final { 1363 return "Test Merging operation in ConversionPatternRewriter"; 1364 } 1365 void runOnOperation() override { 1366 MLIRContext *context = &getContext(); 1367 mlir::RewritePatternSet patterns(context); 1368 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1369 context); 1370 ConversionTarget target(*context); 1371 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1372 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1373 target.addIllegalOp<ILLegalOpF>(); 1374 1375 /// Expect the op to have a single block after legalization. 1376 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1377 [&](TestMergeBlocksOp op) -> bool { 1378 return llvm::hasSingleElement(op.getBody()); 1379 }); 1380 1381 /// Only allow `test.br` within test.merge_blocks op. 1382 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1383 return op->getParentOfType<TestMergeBlocksOp>(); 1384 }); 1385 1386 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1387 /// inlined. 1388 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1389 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1390 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1391 }); 1392 1393 DenseSet<Operation *> unlegalizedOps; 1394 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1395 &unlegalizedOps); 1396 for (auto *op : unlegalizedOps) 1397 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1398 } 1399 }; 1400 } // namespace 1401 1402 //===----------------------------------------------------------------------===// 1403 // Test Selective Replacement 1404 //===----------------------------------------------------------------------===// 1405 1406 namespace { 1407 /// A rewrite mechanism to inline the body of the op into its parent, when both 1408 /// ops can have a single block. 1409 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1410 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1411 1412 LogicalResult matchAndRewrite(TestCastOp op, 1413 PatternRewriter &rewriter) const final { 1414 if (op.getNumOperands() != 2) 1415 return failure(); 1416 OperandRange operands = op.getOperands(); 1417 1418 // Replace non-terminator uses with the first operand. 1419 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1420 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1421 }); 1422 // Replace everything else with the second operand if the operation isn't 1423 // dead. 1424 rewriter.replaceOp(op, op.getOperand(1)); 1425 return success(); 1426 } 1427 }; 1428 1429 struct TestSelectiveReplacementPatternDriver 1430 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1431 OperationPass<>> { 1432 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1433 TestSelectiveReplacementPatternDriver) 1434 1435 StringRef getArgument() const final { 1436 return "test-pattern-selective-replacement"; 1437 } 1438 StringRef getDescription() const final { 1439 return "Test selective replacement in the PatternRewriter"; 1440 } 1441 void runOnOperation() override { 1442 MLIRContext *context = &getContext(); 1443 mlir::RewritePatternSet patterns(context); 1444 patterns.add<TestSelectiveOpReplacementPattern>(context); 1445 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1446 std::move(patterns)); 1447 } 1448 }; 1449 } // namespace 1450 1451 //===----------------------------------------------------------------------===// 1452 // PassRegistration 1453 //===----------------------------------------------------------------------===// 1454 1455 namespace mlir { 1456 namespace test { 1457 void registerPatternsTestPass() { 1458 PassRegistration<TestReturnTypeDriver>(); 1459 1460 PassRegistration<TestDerivedAttributeDriver>(); 1461 1462 PassRegistration<TestPatternDriver>(); 1463 1464 PassRegistration<TestLegalizePatternDriver>([] { 1465 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1466 }); 1467 1468 PassRegistration<TestRemappedValue>(); 1469 1470 PassRegistration<TestUnknownRootOpDriver>(); 1471 1472 PassRegistration<TestTypeConversionDriver>(); 1473 PassRegistration<TestTargetMaterializationWithNoUses>(); 1474 1475 PassRegistration<TestRewriteDynamicOpDriver>(); 1476 1477 PassRegistration<TestMergeBlocksPatternDriver>(); 1478 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1479 } 1480 } // namespace test 1481 } // namespace mlir 1482