1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 55 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 56 // This let us check the number of times OpM_Test was called by inspecting 57 // the returned value in the MLIR output. 58 static int64_t opMIncreasingValue = 314159265; 59 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 60 int64_t i = opMIncreasingValue++; 61 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 62 } 63 64 namespace { 65 #include "TestPatterns.inc" 66 } // namespace 67 68 //===----------------------------------------------------------------------===// 69 // Test Reduce Pattern Interface 70 //===----------------------------------------------------------------------===// 71 72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 73 populateWithGenerated(patterns); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Canonicalizer Driver. 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 struct FoldingPattern : public RewritePattern { 82 public: 83 FoldingPattern(MLIRContext *context) 84 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 85 /*benefit=*/1, context) {} 86 87 LogicalResult matchAndRewrite(Operation *op, 88 PatternRewriter &rewriter) const override { 89 // Exercise OperationFolder API for a single-result operation that is folded 90 // upon construction. The operation being created through the folder has an 91 // in-place folder, and it should be still present in the output. 92 // Furthermore, the folder should not crash when attempting to recover the 93 // (unchanged) operation result. 94 OperationFolder folder(op->getContext()); 95 Value result = folder.create<TestOpInPlaceFold>( 96 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 97 rewriter.getI32IntegerAttr(0)); 98 assert(result); 99 rewriter.replaceOp(op, result); 100 return success(); 101 } 102 }; 103 104 /// This pattern creates a foldable operation at the entry point of the block. 105 /// This tests the situation where the operation folder will need to replace an 106 /// operation with a previously created constant that does not initially 107 /// dominate the operation to replace. 108 struct FolderInsertBeforePreviouslyFoldedConstantPattern 109 : public OpRewritePattern<TestCastOp> { 110 public: 111 using OpRewritePattern<TestCastOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(TestCastOp op, 114 PatternRewriter &rewriter) const override { 115 if (!op->hasAttr("test_fold_before_previously_folded_op")) 116 return failure(); 117 rewriter.setInsertionPointToStart(op->getBlock()); 118 119 auto constOp = rewriter.create<arith::ConstantOp>( 120 op.getLoc(), rewriter.getBoolAttr(true)); 121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 122 Value(constOp)); 123 return success(); 124 } 125 }; 126 127 /// This pattern matches test.op_commutative2 with the first operand being 128 /// another test.op_commutative2 with a constant on the right side and fold it 129 /// away by propagating it as its result. This is intend to check that patterns 130 /// are applied after the commutative property moves constant to the right. 131 struct FolderCommutativeOp2WithConstant 132 : public OpRewritePattern<TestCommutative2Op> { 133 public: 134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 135 136 LogicalResult matchAndRewrite(TestCommutative2Op op, 137 PatternRewriter &rewriter) const override { 138 auto operand = 139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 140 if (!operand) 141 return failure(); 142 Attribute constInput; 143 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 144 return failure(); 145 rewriter.replaceOp(op, operand->getOperand(1)); 146 return success(); 147 } 148 }; 149 150 struct TestPatternDriver 151 : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> { 152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) 153 154 StringRef getArgument() const final { return "test-patterns"; } 155 StringRef getDescription() const final { return "Run test dialect patterns"; } 156 void runOnOperation() override { 157 mlir::RewritePatternSet patterns(&getContext()); 158 populateWithGenerated(patterns); 159 160 // Verify named pattern is generated with expected name. 161 patterns.add<FoldingPattern, TestNamedPatternRule, 162 FolderInsertBeforePreviouslyFoldedConstantPattern, 163 FolderCommutativeOp2WithConstant>(&getContext()); 164 165 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 166 } 167 }; 168 } // namespace 169 170 //===----------------------------------------------------------------------===// 171 // ReturnType Driver. 172 //===----------------------------------------------------------------------===// 173 174 namespace { 175 // Generate ops for each instance where the type can be successfully inferred. 176 template <typename OpTy> 177 static void invokeCreateWithInferredReturnType(Operation *op) { 178 auto *context = op->getContext(); 179 auto fop = op->getParentOfType<func::FuncOp>(); 180 auto location = UnknownLoc::get(context); 181 OpBuilder b(op); 182 b.setInsertionPointAfter(op); 183 184 // Use permutations of 2 args as operands. 185 assert(fop.getNumArguments() >= 2); 186 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 187 for (int j = 0; j < e; ++j) { 188 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 189 SmallVector<Type, 2> inferredReturnTypes; 190 if (succeeded(OpTy::inferReturnTypes( 191 context, llvm::None, values, op->getAttrDictionary(), 192 op->getRegions(), inferredReturnTypes))) { 193 OperationState state(location, OpTy::getOperationName()); 194 // TODO: Expand to regions. 195 OpTy::build(b, state, values, op->getAttrs()); 196 (void)b.create(state); 197 } 198 } 199 } 200 } 201 202 static void reifyReturnShape(Operation *op) { 203 OpBuilder b(op); 204 205 // Use permutations of 2 args as operands. 206 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 207 SmallVector<Value, 2> shapes; 208 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 209 !llvm::hasSingleElement(shapes)) 210 return; 211 for (const auto &it : llvm::enumerate(shapes)) { 212 op->emitRemark() << "value " << it.index() << ": " 213 << it.value().getDefiningOp(); 214 } 215 } 216 217 struct TestReturnTypeDriver 218 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 219 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 220 221 void getDependentDialects(DialectRegistry ®istry) const override { 222 registry.insert<tensor::TensorDialect>(); 223 } 224 StringRef getArgument() const final { return "test-return-type"; } 225 StringRef getDescription() const final { return "Run return type functions"; } 226 227 void runOnOperation() override { 228 if (getOperation().getName() == "testCreateFunctions") { 229 std::vector<Operation *> ops; 230 // Collect ops to avoid triggering on inserted ops. 231 for (auto &op : getOperation().getBody().front()) 232 ops.push_back(&op); 233 // Generate test patterns for each, but skip terminator. 234 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 235 // Test create method of each of the Op classes below. The resultant 236 // output would be in reverse order underneath `op` from which 237 // the attributes and regions are used. 238 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 239 invokeCreateWithInferredReturnType< 240 OpWithShapedTypeInferTypeInterfaceOp>(op); 241 }; 242 return; 243 } 244 if (getOperation().getName() == "testReifyFunctions") { 245 std::vector<Operation *> ops; 246 // Collect ops to avoid triggering on inserted ops. 247 for (auto &op : getOperation().getBody().front()) 248 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 249 ops.push_back(&op); 250 // Generate test patterns for each, but skip terminator. 251 for (auto *op : ops) 252 reifyReturnShape(op); 253 } 254 } 255 }; 256 } // namespace 257 258 namespace { 259 struct TestDerivedAttributeDriver 260 : public PassWrapper<TestDerivedAttributeDriver, 261 OperationPass<func::FuncOp>> { 262 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 263 264 StringRef getArgument() const final { return "test-derived-attr"; } 265 StringRef getDescription() const final { 266 return "Run test derived attributes"; 267 } 268 void runOnOperation() override; 269 }; 270 } // namespace 271 272 void TestDerivedAttributeDriver::runOnOperation() { 273 getOperation().walk([](DerivedAttributeOpInterface dOp) { 274 auto dAttr = dOp.materializeDerivedAttributes(); 275 if (!dAttr) 276 return; 277 for (auto d : dAttr) 278 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 279 }); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // Legalization Driver. 284 //===----------------------------------------------------------------------===// 285 286 namespace { 287 //===----------------------------------------------------------------------===// 288 // Region-Block Rewrite Testing 289 290 /// This pattern is a simple pattern that inlines the first region of a given 291 /// operation into the parent region. 292 struct TestRegionRewriteBlockMovement : public ConversionPattern { 293 TestRegionRewriteBlockMovement(MLIRContext *ctx) 294 : ConversionPattern("test.region", 1, ctx) {} 295 296 LogicalResult 297 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 298 ConversionPatternRewriter &rewriter) const final { 299 // Inline this region into the parent region. 300 auto &parentRegion = *op->getParentRegion(); 301 auto &opRegion = op->getRegion(0); 302 if (op->getAttr("legalizer.should_clone")) 303 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 304 else 305 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 306 307 if (op->getAttr("legalizer.erase_old_blocks")) { 308 while (!opRegion.empty()) 309 rewriter.eraseBlock(&opRegion.front()); 310 } 311 312 // Drop this operation. 313 rewriter.eraseOp(op); 314 return success(); 315 } 316 }; 317 /// This pattern is a simple pattern that generates a region containing an 318 /// illegal operation. 319 struct TestRegionRewriteUndo : public RewritePattern { 320 TestRegionRewriteUndo(MLIRContext *ctx) 321 : RewritePattern("test.region_builder", 1, ctx) {} 322 323 LogicalResult matchAndRewrite(Operation *op, 324 PatternRewriter &rewriter) const final { 325 // Create the region operation with an entry block containing arguments. 326 OperationState newRegion(op->getLoc(), "test.region"); 327 newRegion.addRegion(); 328 auto *regionOp = rewriter.create(newRegion); 329 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 330 entryBlock->addArgument(rewriter.getIntegerType(64), 331 rewriter.getUnknownLoc()); 332 333 // Add an explicitly illegal operation to ensure the conversion fails. 334 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 335 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 336 337 // Drop this operation. 338 rewriter.eraseOp(op); 339 return success(); 340 } 341 }; 342 /// A simple pattern that creates a block at the end of the parent region of the 343 /// matched operation. 344 struct TestCreateBlock : public RewritePattern { 345 TestCreateBlock(MLIRContext *ctx) 346 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 347 348 LogicalResult matchAndRewrite(Operation *op, 349 PatternRewriter &rewriter) const final { 350 Region ®ion = *op->getParentRegion(); 351 Type i32Type = rewriter.getIntegerType(32); 352 Location loc = op->getLoc(); 353 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 354 rewriter.create<TerminatorOp>(loc); 355 rewriter.replaceOp(op, {}); 356 return success(); 357 } 358 }; 359 360 /// A simple pattern that creates a block containing an invalid operation in 361 /// order to trigger the block creation undo mechanism. 362 struct TestCreateIllegalBlock : public RewritePattern { 363 TestCreateIllegalBlock(MLIRContext *ctx) 364 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 365 366 LogicalResult matchAndRewrite(Operation *op, 367 PatternRewriter &rewriter) const final { 368 Region ®ion = *op->getParentRegion(); 369 Type i32Type = rewriter.getIntegerType(32); 370 Location loc = op->getLoc(); 371 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 372 // Create an illegal op to ensure the conversion fails. 373 rewriter.create<ILLegalOpF>(loc, i32Type); 374 rewriter.create<TerminatorOp>(loc); 375 rewriter.replaceOp(op, {}); 376 return success(); 377 } 378 }; 379 380 /// A simple pattern that tests the undo mechanism when replacing the uses of a 381 /// block argument. 382 struct TestUndoBlockArgReplace : public ConversionPattern { 383 TestUndoBlockArgReplace(MLIRContext *ctx) 384 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 385 386 LogicalResult 387 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 388 ConversionPatternRewriter &rewriter) const final { 389 auto illegalOp = 390 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 391 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 392 illegalOp); 393 rewriter.updateRootInPlace(op, [] {}); 394 return success(); 395 } 396 }; 397 398 /// A rewrite pattern that tests the undo mechanism when erasing a block. 399 struct TestUndoBlockErase : public ConversionPattern { 400 TestUndoBlockErase(MLIRContext *ctx) 401 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 402 403 LogicalResult 404 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 405 ConversionPatternRewriter &rewriter) const final { 406 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 407 rewriter.setInsertionPointToStart(secondBlock); 408 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 409 rewriter.eraseBlock(secondBlock); 410 rewriter.updateRootInPlace(op, [] {}); 411 return success(); 412 } 413 }; 414 415 //===----------------------------------------------------------------------===// 416 // Type-Conversion Rewrite Testing 417 418 /// This patterns erases a region operation that has had a type conversion. 419 struct TestDropOpSignatureConversion : public ConversionPattern { 420 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 421 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 422 LogicalResult 423 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 424 ConversionPatternRewriter &rewriter) const override { 425 Region ®ion = op->getRegion(0); 426 Block *entry = ®ion.front(); 427 428 // Convert the original entry arguments. 429 TypeConverter &converter = *getTypeConverter(); 430 TypeConverter::SignatureConversion result(entry->getNumArguments()); 431 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 432 result)) || 433 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 434 return failure(); 435 436 // Convert the region signature and just drop the operation. 437 rewriter.eraseOp(op); 438 return success(); 439 } 440 }; 441 /// This pattern simply updates the operands of the given operation. 442 struct TestPassthroughInvalidOp : public ConversionPattern { 443 TestPassthroughInvalidOp(MLIRContext *ctx) 444 : ConversionPattern("test.invalid", 1, ctx) {} 445 LogicalResult 446 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 447 ConversionPatternRewriter &rewriter) const final { 448 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 449 llvm::None); 450 return success(); 451 } 452 }; 453 /// This pattern handles the case of a split return value. 454 struct TestSplitReturnType : public ConversionPattern { 455 TestSplitReturnType(MLIRContext *ctx) 456 : ConversionPattern("test.return", 1, ctx) {} 457 LogicalResult 458 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 459 ConversionPatternRewriter &rewriter) const final { 460 // Check for a return of F32. 461 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 462 return failure(); 463 464 // Check if the first operation is a cast operation, if it is we use the 465 // results directly. 466 auto *defOp = operands[0].getDefiningOp(); 467 if (auto packerOp = 468 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 469 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 470 return success(); 471 } 472 473 // Otherwise, fail to match. 474 return failure(); 475 } 476 }; 477 478 //===----------------------------------------------------------------------===// 479 // Multi-Level Type-Conversion Rewrite Testing 480 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 481 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 482 : ConversionPattern("test.type_producer", 1, ctx) {} 483 LogicalResult 484 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 485 ConversionPatternRewriter &rewriter) const final { 486 // If the type is I32, change the type to F32. 487 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 488 return failure(); 489 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 490 return success(); 491 } 492 }; 493 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 494 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 495 : ConversionPattern("test.type_producer", 1, ctx) {} 496 LogicalResult 497 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 498 ConversionPatternRewriter &rewriter) const final { 499 // If the type is F32, change the type to F64. 500 if (!Type(*op->result_type_begin()).isF32()) 501 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 502 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 503 return success(); 504 } 505 }; 506 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 507 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 508 : ConversionPattern("test.type_producer", 10, ctx) {} 509 LogicalResult 510 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 511 ConversionPatternRewriter &rewriter) const final { 512 // Always convert to B16, even though it is not a legal type. This tests 513 // that values are unmapped correctly. 514 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 515 return success(); 516 } 517 }; 518 struct TestUpdateConsumerType : public ConversionPattern { 519 TestUpdateConsumerType(MLIRContext *ctx) 520 : ConversionPattern("test.type_consumer", 1, ctx) {} 521 LogicalResult 522 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 523 ConversionPatternRewriter &rewriter) const final { 524 // Verify that the incoming operand has been successfully remapped to F64. 525 if (!operands[0].getType().isF64()) 526 return failure(); 527 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 528 return success(); 529 } 530 }; 531 532 //===----------------------------------------------------------------------===// 533 // Non-Root Replacement Rewrite Testing 534 /// This pattern generates an invalid operation, but replaces it before the 535 /// pattern is finished. This checks that we don't need to legalize the 536 /// temporary op. 537 struct TestNonRootReplacement : public RewritePattern { 538 TestNonRootReplacement(MLIRContext *ctx) 539 : RewritePattern("test.replace_non_root", 1, ctx) {} 540 541 LogicalResult matchAndRewrite(Operation *op, 542 PatternRewriter &rewriter) const final { 543 auto resultType = *op->result_type_begin(); 544 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 545 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 546 547 rewriter.replaceOp(illegalOp, {legalOp}); 548 rewriter.replaceOp(op, {illegalOp}); 549 return success(); 550 } 551 }; 552 553 //===----------------------------------------------------------------------===// 554 // Recursive Rewrite Testing 555 /// This pattern is applied to the same operation multiple times, but has a 556 /// bounded recursion. 557 struct TestBoundedRecursiveRewrite 558 : public OpRewritePattern<TestRecursiveRewriteOp> { 559 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 560 561 void initialize() { 562 // The conversion target handles bounding the recursion of this pattern. 563 setHasBoundedRewriteRecursion(); 564 } 565 566 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 567 PatternRewriter &rewriter) const final { 568 // Decrement the depth of the op in-place. 569 rewriter.updateRootInPlace(op, [&] { 570 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 571 }); 572 return success(); 573 } 574 }; 575 576 struct TestNestedOpCreationUndoRewrite 577 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 578 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 579 580 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 581 PatternRewriter &rewriter) const final { 582 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 583 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 584 return success(); 585 }; 586 }; 587 588 // This pattern matches `test.blackhole` and delete this op and its producer. 589 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 590 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 591 592 LogicalResult matchAndRewrite(BlackHoleOp op, 593 PatternRewriter &rewriter) const final { 594 Operation *producer = op.getOperand().getDefiningOp(); 595 // Always erase the user before the producer, the framework should handle 596 // this correctly. 597 rewriter.eraseOp(op); 598 rewriter.eraseOp(producer); 599 return success(); 600 }; 601 }; 602 603 // This pattern replaces explicitly illegal op with explicitly legal op, 604 // but in addition creates unregistered operation. 605 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 606 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 607 608 LogicalResult matchAndRewrite(ILLegalOpG op, 609 PatternRewriter &rewriter) const final { 610 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 611 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 612 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 613 return success(); 614 }; 615 }; 616 } // namespace 617 618 namespace { 619 struct TestTypeConverter : public TypeConverter { 620 using TypeConverter::TypeConverter; 621 TestTypeConverter() { 622 addConversion(convertType); 623 addArgumentMaterialization(materializeCast); 624 addSourceMaterialization(materializeCast); 625 } 626 627 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 628 // Drop I16 types. 629 if (t.isSignlessInteger(16)) 630 return success(); 631 632 // Convert I64 to F64. 633 if (t.isSignlessInteger(64)) { 634 results.push_back(FloatType::getF64(t.getContext())); 635 return success(); 636 } 637 638 // Convert I42 to I43. 639 if (t.isInteger(42)) { 640 results.push_back(IntegerType::get(t.getContext(), 43)); 641 return success(); 642 } 643 644 // Split F32 into F16,F16. 645 if (t.isF32()) { 646 results.assign(2, FloatType::getF16(t.getContext())); 647 return success(); 648 } 649 650 // Otherwise, convert the type directly. 651 results.push_back(t); 652 return success(); 653 } 654 655 /// Hook for materializing a conversion. This is necessary because we generate 656 /// 1->N type mappings. 657 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 658 ValueRange inputs, Location loc) { 659 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 660 } 661 }; 662 663 struct TestLegalizePatternDriver 664 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 665 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 666 667 StringRef getArgument() const final { return "test-legalize-patterns"; } 668 StringRef getDescription() const final { 669 return "Run test dialect legalization patterns"; 670 } 671 /// The mode of conversion to use with the driver. 672 enum class ConversionMode { Analysis, Full, Partial }; 673 674 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 675 676 void getDependentDialects(DialectRegistry ®istry) const override { 677 registry.insert<func::FuncDialect>(); 678 } 679 680 void runOnOperation() override { 681 TestTypeConverter converter; 682 mlir::RewritePatternSet patterns(&getContext()); 683 populateWithGenerated(patterns); 684 patterns 685 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 686 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 687 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 688 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 689 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 690 TestNonRootReplacement, TestBoundedRecursiveRewrite, 691 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 692 TestCreateUnregisteredOp>(&getContext()); 693 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 694 mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( 695 patterns, converter); 696 mlir::populateCallOpTypeConversionPattern(patterns, converter); 697 698 // Define the conversion target used for the test. 699 ConversionTarget target(getContext()); 700 target.addLegalOp<ModuleOp>(); 701 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 702 TerminatorOp>(); 703 target 704 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 705 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 706 // Don't allow F32 operands. 707 return llvm::none_of(op.getOperandTypes(), 708 [](Type type) { return type.isF32(); }); 709 }); 710 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 711 return converter.isSignatureLegal(op.getFunctionType()) && 712 converter.isLegal(&op.getBody()); 713 }); 714 target.addDynamicallyLegalOp<func::CallOp>( 715 [&](func::CallOp op) { return converter.isLegal(op); }); 716 717 // TestCreateUnregisteredOp creates `arith.constant` operation, 718 // which was not added to target intentionally to test 719 // correct error code from conversion driver. 720 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 721 722 // Expect the type_producer/type_consumer operations to only operate on f64. 723 target.addDynamicallyLegalOp<TestTypeProducerOp>( 724 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 725 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 726 return op.getOperand().getType().isF64(); 727 }); 728 729 // Check support for marking certain operations as recursively legal. 730 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 731 return static_cast<bool>( 732 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 733 }); 734 735 // Mark the bound recursion operation as dynamically legal. 736 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 737 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 738 739 // Handle a partial conversion. 740 if (mode == ConversionMode::Partial) { 741 DenseSet<Operation *> unlegalizedOps; 742 if (failed(applyPartialConversion( 743 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 744 getOperation()->emitRemark() << "applyPartialConversion failed"; 745 } 746 // Emit remarks for each legalizable operation. 747 for (auto *op : unlegalizedOps) 748 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 749 return; 750 } 751 752 // Handle a full conversion. 753 if (mode == ConversionMode::Full) { 754 // Check support for marking unknown operations as dynamically legal. 755 target.markUnknownOpDynamicallyLegal([](Operation *op) { 756 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 757 }); 758 759 if (failed(applyFullConversion(getOperation(), target, 760 std::move(patterns)))) { 761 getOperation()->emitRemark() << "applyFullConversion failed"; 762 } 763 return; 764 } 765 766 // Otherwise, handle an analysis conversion. 767 assert(mode == ConversionMode::Analysis); 768 769 // Analyze the convertible operations. 770 DenseSet<Operation *> legalizedOps; 771 if (failed(applyAnalysisConversion(getOperation(), target, 772 std::move(patterns), legalizedOps))) 773 return signalPassFailure(); 774 775 // Emit remarks for each legalizable operation. 776 for (auto *op : legalizedOps) 777 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 778 } 779 780 /// The mode of conversion to use. 781 ConversionMode mode; 782 }; 783 } // namespace 784 785 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 786 legalizerConversionMode( 787 "test-legalize-mode", 788 llvm::cl::desc("The legalization mode to use with the test driver"), 789 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 790 llvm::cl::values( 791 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 792 "analysis", "Perform an analysis conversion"), 793 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 794 "Perform a full conversion"), 795 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 796 "partial", "Perform a partial conversion"))); 797 798 //===----------------------------------------------------------------------===// 799 // ConversionPatternRewriter::getRemappedValue testing. This method is used 800 // to get the remapped value of an original value that was replaced using 801 // ConversionPatternRewriter. 802 namespace { 803 struct TestRemapValueTypeConverter : public TypeConverter { 804 using TypeConverter::TypeConverter; 805 806 TestRemapValueTypeConverter() { 807 addConversion( 808 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 809 addConversion([](Type type) { return type; }); 810 } 811 }; 812 813 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 814 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 815 /// operand twice. 816 /// 817 /// Example: 818 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 819 /// is replaced with: 820 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 821 struct OneVResOneVOperandOp1Converter 822 : public OpConversionPattern<OneVResOneVOperandOp1> { 823 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 824 825 LogicalResult 826 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 827 ConversionPatternRewriter &rewriter) const override { 828 auto origOps = op.getOperands(); 829 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 830 "One operand expected"); 831 Value origOp = *origOps.begin(); 832 SmallVector<Value, 2> remappedOperands; 833 // Replicate the remapped original operand twice. Note that we don't used 834 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 835 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 836 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 837 838 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 839 remappedOperands); 840 return success(); 841 } 842 }; 843 844 /// A rewriter pattern that tests that blocks can be merged. 845 struct TestRemapValueInRegion 846 : public OpConversionPattern<TestRemappedValueRegionOp> { 847 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 848 849 LogicalResult 850 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 851 ConversionPatternRewriter &rewriter) const final { 852 Block &block = op.getBody().front(); 853 Operation *terminator = block.getTerminator(); 854 855 // Merge the block into the parent region. 856 Block *parentBlock = op->getBlock(); 857 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 858 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 859 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 860 861 // Replace the results of this operation with the remapped terminator 862 // values. 863 SmallVector<Value> terminatorOperands; 864 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 865 terminatorOperands))) 866 return failure(); 867 868 rewriter.eraseOp(terminator); 869 rewriter.replaceOp(op, terminatorOperands); 870 return success(); 871 } 872 }; 873 874 struct TestRemappedValue 875 : public mlir::PassWrapper<TestRemappedValue, OperationPass<func::FuncOp>> { 876 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 877 878 StringRef getArgument() const final { return "test-remapped-value"; } 879 StringRef getDescription() const final { 880 return "Test public remapped value mechanism in ConversionPatternRewriter"; 881 } 882 void runOnOperation() override { 883 TestRemapValueTypeConverter typeConverter; 884 885 mlir::RewritePatternSet patterns(&getContext()); 886 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 887 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 888 &getContext()); 889 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 890 891 mlir::ConversionTarget target(getContext()); 892 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 893 894 // Expect the type_producer/type_consumer operations to only operate on f64. 895 target.addDynamicallyLegalOp<TestTypeProducerOp>( 896 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 897 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 898 return op.getOperand().getType().isF64(); 899 }); 900 901 // We make OneVResOneVOperandOp1 legal only when it has more that one 902 // operand. This will trigger the conversion that will replace one-operand 903 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 904 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 905 [](Operation *op) { return op->getNumOperands() > 1; }); 906 907 if (failed(mlir::applyFullConversion(getOperation(), target, 908 std::move(patterns)))) { 909 signalPassFailure(); 910 } 911 } 912 }; 913 } // namespace 914 915 //===----------------------------------------------------------------------===// 916 // Test patterns without a specific root operation kind 917 //===----------------------------------------------------------------------===// 918 919 namespace { 920 /// This pattern matches and removes any operation in the test dialect. 921 struct RemoveTestDialectOps : public RewritePattern { 922 RemoveTestDialectOps(MLIRContext *context) 923 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 924 925 LogicalResult matchAndRewrite(Operation *op, 926 PatternRewriter &rewriter) const override { 927 if (!isa<TestDialect>(op->getDialect())) 928 return failure(); 929 rewriter.eraseOp(op); 930 return success(); 931 } 932 }; 933 934 struct TestUnknownRootOpDriver 935 : public mlir::PassWrapper<TestUnknownRootOpDriver, 936 OperationPass<func::FuncOp>> { 937 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 938 939 StringRef getArgument() const final { 940 return "test-legalize-unknown-root-patterns"; 941 } 942 StringRef getDescription() const final { 943 return "Test public remapped value mechanism in ConversionPatternRewriter"; 944 } 945 void runOnOperation() override { 946 mlir::RewritePatternSet patterns(&getContext()); 947 patterns.add<RemoveTestDialectOps>(&getContext()); 948 949 mlir::ConversionTarget target(getContext()); 950 target.addIllegalDialect<TestDialect>(); 951 if (failed(applyPartialConversion(getOperation(), target, 952 std::move(patterns)))) 953 signalPassFailure(); 954 } 955 }; 956 } // namespace 957 958 //===----------------------------------------------------------------------===// 959 // Test type conversions 960 //===----------------------------------------------------------------------===// 961 962 namespace { 963 struct TestTypeConversionProducer 964 : public OpConversionPattern<TestTypeProducerOp> { 965 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 966 LogicalResult 967 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 968 ConversionPatternRewriter &rewriter) const final { 969 Type resultType = op.getType(); 970 Type convertedType = getTypeConverter() 971 ? getTypeConverter()->convertType(resultType) 972 : resultType; 973 if (resultType.isa<FloatType>()) 974 resultType = rewriter.getF64Type(); 975 else if (resultType.isInteger(16)) 976 resultType = rewriter.getIntegerType(64); 977 else if (resultType.isa<test::TestRecursiveType>() && 978 convertedType != resultType) 979 resultType = convertedType; 980 else 981 return failure(); 982 983 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 984 return success(); 985 } 986 }; 987 988 /// Call signature conversion and then fail the rewrite to trigger the undo 989 /// mechanism. 990 struct TestSignatureConversionUndo 991 : public OpConversionPattern<TestSignatureConversionUndoOp> { 992 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 993 994 LogicalResult 995 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 996 ConversionPatternRewriter &rewriter) const final { 997 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 998 return failure(); 999 } 1000 }; 1001 1002 /// Call signature conversion without providing a type converter to handle 1003 /// materializations. 1004 struct TestTestSignatureConversionNoConverter 1005 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1006 TestTestSignatureConversionNoConverter(TypeConverter &converter, 1007 MLIRContext *context) 1008 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 1009 converter(converter) {} 1010 1011 LogicalResult 1012 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 1013 ConversionPatternRewriter &rewriter) const final { 1014 Region ®ion = op->getRegion(0); 1015 Block *entry = ®ion.front(); 1016 1017 // Convert the original entry arguments. 1018 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1019 if (failed( 1020 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1021 return failure(); 1022 rewriter.updateRootInPlace( 1023 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1024 return success(); 1025 } 1026 1027 TypeConverter &converter; 1028 }; 1029 1030 /// Just forward the operands to the root op. This is essentially a no-op 1031 /// pattern that is used to trigger target materialization. 1032 struct TestTypeConsumerForward 1033 : public OpConversionPattern<TestTypeConsumerOp> { 1034 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1035 1036 LogicalResult 1037 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1038 ConversionPatternRewriter &rewriter) const final { 1039 rewriter.updateRootInPlace(op, 1040 [&] { op->setOperands(adaptor.getOperands()); }); 1041 return success(); 1042 } 1043 }; 1044 1045 struct TestTypeConversionAnotherProducer 1046 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1047 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1048 1049 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1050 PatternRewriter &rewriter) const final { 1051 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1052 return success(); 1053 } 1054 }; 1055 1056 struct TestTypeConversionDriver 1057 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1058 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 1059 1060 void getDependentDialects(DialectRegistry ®istry) const override { 1061 registry.insert<TestDialect>(); 1062 } 1063 StringRef getArgument() const final { 1064 return "test-legalize-type-conversion"; 1065 } 1066 StringRef getDescription() const final { 1067 return "Test various type conversion functionalities in DialectConversion"; 1068 } 1069 1070 void runOnOperation() override { 1071 // Initialize the type converter. 1072 TypeConverter converter; 1073 1074 /// Add the legal set of type conversions. 1075 converter.addConversion([](Type type) -> Type { 1076 // Treat F64 as legal. 1077 if (type.isF64()) 1078 return type; 1079 // Allow converting BF16/F16/F32 to F64. 1080 if (type.isBF16() || type.isF16() || type.isF32()) 1081 return FloatType::getF64(type.getContext()); 1082 // Otherwise, the type is illegal. 1083 return nullptr; 1084 }); 1085 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1086 // Drop all integer types. 1087 return success(); 1088 }); 1089 converter.addConversion( 1090 // Convert a recursive self-referring type into a non-self-referring 1091 // type named "outer_converted_type" that contains a SimpleAType. 1092 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1093 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1094 // If the type is already converted, return it to indicate that it is 1095 // legal. 1096 if (type.getName() == "outer_converted_type") { 1097 results.push_back(type); 1098 return success(); 1099 } 1100 1101 // If the type is on the call stack more than once (it is there at 1102 // least once because of the _current_ call, which is always the last 1103 // element on the stack), we've hit the recursive case. Just return 1104 // SimpleAType here to create a non-recursive type as a result. 1105 if (llvm::is_contained(callStack.drop_back(), type)) { 1106 results.push_back(test::SimpleAType::get(type.getContext())); 1107 return success(); 1108 } 1109 1110 // Convert the body recursively. 1111 auto result = test::TestRecursiveType::get(type.getContext(), 1112 "outer_converted_type"); 1113 if (failed(result.setBody(converter.convertType(type.getBody())))) 1114 return failure(); 1115 results.push_back(result); 1116 return success(); 1117 }); 1118 1119 /// Add the legal set of type materializations. 1120 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1121 ValueRange inputs, 1122 Location loc) -> Value { 1123 // Allow casting from F64 back to F32. 1124 if (!resultType.isF16() && inputs.size() == 1 && 1125 inputs[0].getType().isF64()) 1126 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1127 // Allow producing an i32 or i64 from nothing. 1128 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1129 inputs.empty()) 1130 return builder.create<TestTypeProducerOp>(loc, resultType); 1131 // Allow producing an i64 from an integer. 1132 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1133 inputs[0].getType().isa<IntegerType>()) 1134 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1135 // Otherwise, fail. 1136 return nullptr; 1137 }); 1138 1139 // Initialize the conversion target. 1140 mlir::ConversionTarget target(getContext()); 1141 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1142 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1143 return op.getType().isF64() || op.getType().isInteger(64) || 1144 (recursiveType && 1145 recursiveType.getName() == "outer_converted_type"); 1146 }); 1147 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1148 return converter.isSignatureLegal(op.getFunctionType()) && 1149 converter.isLegal(&op.getBody()); 1150 }); 1151 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1152 // Allow casts from F64 to F32. 1153 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1154 }); 1155 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1156 [&](TestSignatureConversionNoConverterOp op) { 1157 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1158 }); 1159 1160 // Initialize the set of rewrite patterns. 1161 RewritePatternSet patterns(&getContext()); 1162 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1163 TestSignatureConversionUndo, 1164 TestTestSignatureConversionNoConverter>(converter, 1165 &getContext()); 1166 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1167 mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( 1168 patterns, converter); 1169 1170 if (failed(applyPartialConversion(getOperation(), target, 1171 std::move(patterns)))) 1172 signalPassFailure(); 1173 } 1174 }; 1175 } // namespace 1176 1177 //===----------------------------------------------------------------------===// 1178 // Test Target Materialization With No Uses 1179 //===----------------------------------------------------------------------===// 1180 1181 namespace { 1182 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1183 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1184 1185 LogicalResult 1186 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1187 ConversionPatternRewriter &rewriter) const final { 1188 rewriter.replaceOp(op, adaptor.getOperands()); 1189 return success(); 1190 } 1191 }; 1192 1193 struct TestTargetMaterializationWithNoUses 1194 : public PassWrapper<TestTargetMaterializationWithNoUses, 1195 OperationPass<ModuleOp>> { 1196 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1197 TestTargetMaterializationWithNoUses) 1198 1199 StringRef getArgument() const final { 1200 return "test-target-materialization-with-no-uses"; 1201 } 1202 StringRef getDescription() const final { 1203 return "Test a special case of target materialization in DialectConversion"; 1204 } 1205 1206 void runOnOperation() override { 1207 TypeConverter converter; 1208 converter.addConversion([](Type t) { return t; }); 1209 converter.addConversion([](IntegerType intTy) -> Type { 1210 if (intTy.getWidth() == 16) 1211 return IntegerType::get(intTy.getContext(), 64); 1212 return intTy; 1213 }); 1214 converter.addTargetMaterialization( 1215 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1216 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1217 }); 1218 1219 ConversionTarget target(getContext()); 1220 target.addIllegalOp<TestTypeChangerOp>(); 1221 1222 RewritePatternSet patterns(&getContext()); 1223 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1224 1225 if (failed(applyPartialConversion(getOperation(), target, 1226 std::move(patterns)))) 1227 signalPassFailure(); 1228 } 1229 }; 1230 } // namespace 1231 1232 //===----------------------------------------------------------------------===// 1233 // Test Block Merging 1234 //===----------------------------------------------------------------------===// 1235 1236 namespace { 1237 /// A rewriter pattern that tests that blocks can be merged. 1238 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1239 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1240 1241 LogicalResult 1242 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1243 ConversionPatternRewriter &rewriter) const final { 1244 Block &firstBlock = op.getBody().front(); 1245 Operation *branchOp = firstBlock.getTerminator(); 1246 Block *secondBlock = &*(std::next(op.getBody().begin())); 1247 auto succOperands = branchOp->getOperands(); 1248 SmallVector<Value, 2> replacements(succOperands); 1249 rewriter.eraseOp(branchOp); 1250 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1251 rewriter.updateRootInPlace(op, [] {}); 1252 return success(); 1253 } 1254 }; 1255 1256 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1257 struct TestUndoBlocksMerge : public ConversionPattern { 1258 TestUndoBlocksMerge(MLIRContext *ctx) 1259 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1260 LogicalResult 1261 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1262 ConversionPatternRewriter &rewriter) const final { 1263 Block &firstBlock = op->getRegion(0).front(); 1264 Operation *branchOp = firstBlock.getTerminator(); 1265 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1266 rewriter.setInsertionPointToStart(secondBlock); 1267 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1268 auto succOperands = branchOp->getOperands(); 1269 SmallVector<Value, 2> replacements(succOperands); 1270 rewriter.eraseOp(branchOp); 1271 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1272 rewriter.updateRootInPlace(op, [] {}); 1273 return success(); 1274 } 1275 }; 1276 1277 /// A rewrite mechanism to inline the body of the op into its parent, when both 1278 /// ops can have a single block. 1279 struct TestMergeSingleBlockOps 1280 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1281 using OpConversionPattern< 1282 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1283 1284 LogicalResult 1285 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1286 ConversionPatternRewriter &rewriter) const final { 1287 SingleBlockImplicitTerminatorOp parentOp = 1288 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1289 if (!parentOp) 1290 return failure(); 1291 Block &innerBlock = op.getRegion().front(); 1292 TerminatorOp innerTerminator = 1293 cast<TerminatorOp>(innerBlock.getTerminator()); 1294 rewriter.mergeBlockBefore(&innerBlock, op); 1295 rewriter.eraseOp(innerTerminator); 1296 rewriter.eraseOp(op); 1297 rewriter.updateRootInPlace(op, [] {}); 1298 return success(); 1299 } 1300 }; 1301 1302 struct TestMergeBlocksPatternDriver 1303 : public PassWrapper<TestMergeBlocksPatternDriver, 1304 OperationPass<ModuleOp>> { 1305 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 1306 1307 StringRef getArgument() const final { return "test-merge-blocks"; } 1308 StringRef getDescription() const final { 1309 return "Test Merging operation in ConversionPatternRewriter"; 1310 } 1311 void runOnOperation() override { 1312 MLIRContext *context = &getContext(); 1313 mlir::RewritePatternSet patterns(context); 1314 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1315 context); 1316 ConversionTarget target(*context); 1317 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1318 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1319 target.addIllegalOp<ILLegalOpF>(); 1320 1321 /// Expect the op to have a single block after legalization. 1322 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1323 [&](TestMergeBlocksOp op) -> bool { 1324 return llvm::hasSingleElement(op.getBody()); 1325 }); 1326 1327 /// Only allow `test.br` within test.merge_blocks op. 1328 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1329 return op->getParentOfType<TestMergeBlocksOp>(); 1330 }); 1331 1332 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1333 /// inlined. 1334 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1335 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1336 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1337 }); 1338 1339 DenseSet<Operation *> unlegalizedOps; 1340 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1341 &unlegalizedOps); 1342 for (auto *op : unlegalizedOps) 1343 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1344 } 1345 }; 1346 } // namespace 1347 1348 //===----------------------------------------------------------------------===// 1349 // Test Selective Replacement 1350 //===----------------------------------------------------------------------===// 1351 1352 namespace { 1353 /// A rewrite mechanism to inline the body of the op into its parent, when both 1354 /// ops can have a single block. 1355 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1356 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1357 1358 LogicalResult matchAndRewrite(TestCastOp op, 1359 PatternRewriter &rewriter) const final { 1360 if (op.getNumOperands() != 2) 1361 return failure(); 1362 OperandRange operands = op.getOperands(); 1363 1364 // Replace non-terminator uses with the first operand. 1365 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1366 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1367 }); 1368 // Replace everything else with the second operand if the operation isn't 1369 // dead. 1370 rewriter.replaceOp(op, op.getOperand(1)); 1371 return success(); 1372 } 1373 }; 1374 1375 struct TestSelectiveReplacementPatternDriver 1376 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1377 OperationPass<>> { 1378 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 1379 TestSelectiveReplacementPatternDriver) 1380 1381 StringRef getArgument() const final { 1382 return "test-pattern-selective-replacement"; 1383 } 1384 StringRef getDescription() const final { 1385 return "Test selective replacement in the PatternRewriter"; 1386 } 1387 void runOnOperation() override { 1388 MLIRContext *context = &getContext(); 1389 mlir::RewritePatternSet patterns(context); 1390 patterns.add<TestSelectiveOpReplacementPattern>(context); 1391 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1392 std::move(patterns)); 1393 } 1394 }; 1395 } // namespace 1396 1397 //===----------------------------------------------------------------------===// 1398 // PassRegistration 1399 //===----------------------------------------------------------------------===// 1400 1401 namespace mlir { 1402 namespace test { 1403 void registerPatternsTestPass() { 1404 PassRegistration<TestReturnTypeDriver>(); 1405 1406 PassRegistration<TestDerivedAttributeDriver>(); 1407 1408 PassRegistration<TestPatternDriver>(); 1409 1410 PassRegistration<TestLegalizePatternDriver>([] { 1411 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1412 }); 1413 1414 PassRegistration<TestRemappedValue>(); 1415 1416 PassRegistration<TestUnknownRootOpDriver>(); 1417 1418 PassRegistration<TestTypeConversionDriver>(); 1419 PassRegistration<TestTargetMaterializationWithNoUses>(); 1420 1421 PassRegistration<TestMergeBlocksPatternDriver>(); 1422 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1423 } 1424 } // namespace test 1425 } // namespace mlir 1426