1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 55 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 56 // This let us check the number of times OpM_Test was called by inspecting 57 // the returned value in the MLIR output. 58 static int64_t opMIncreasingValue = 314159265; 59 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 60 int64_t i = opMIncreasingValue++; 61 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 62 } 63 64 namespace { 65 #include "TestPatterns.inc" 66 } // namespace 67 68 //===----------------------------------------------------------------------===// 69 // Test Reduce Pattern Interface 70 //===----------------------------------------------------------------------===// 71 72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 73 populateWithGenerated(patterns); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Canonicalizer Driver. 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 struct FoldingPattern : public RewritePattern { 82 public: 83 FoldingPattern(MLIRContext *context) 84 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 85 /*benefit=*/1, context) {} 86 87 LogicalResult matchAndRewrite(Operation *op, 88 PatternRewriter &rewriter) const override { 89 // Exercise OperationFolder API for a single-result operation that is folded 90 // upon construction. The operation being created through the folder has an 91 // in-place folder, and it should be still present in the output. 92 // Furthermore, the folder should not crash when attempting to recover the 93 // (unchanged) operation result. 94 OperationFolder folder(op->getContext()); 95 Value result = folder.create<TestOpInPlaceFold>( 96 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 97 rewriter.getI32IntegerAttr(0)); 98 assert(result); 99 rewriter.replaceOp(op, result); 100 return success(); 101 } 102 }; 103 104 /// This pattern creates a foldable operation at the entry point of the block. 105 /// This tests the situation where the operation folder will need to replace an 106 /// operation with a previously created constant that does not initially 107 /// dominate the operation to replace. 108 struct FolderInsertBeforePreviouslyFoldedConstantPattern 109 : public OpRewritePattern<TestCastOp> { 110 public: 111 using OpRewritePattern<TestCastOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(TestCastOp op, 114 PatternRewriter &rewriter) const override { 115 if (!op->hasAttr("test_fold_before_previously_folded_op")) 116 return failure(); 117 rewriter.setInsertionPointToStart(op->getBlock()); 118 119 auto constOp = rewriter.create<arith::ConstantOp>( 120 op.getLoc(), rewriter.getBoolAttr(true)); 121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 122 Value(constOp)); 123 return success(); 124 } 125 }; 126 127 /// This pattern matches test.op_commutative2 with the first operand being 128 /// another test.op_commutative2 with a constant on the right side and fold it 129 /// away by propagating it as its result. This is intend to check that patterns 130 /// are applied after the commutative property moves constant to the right. 131 struct FolderCommutativeOp2WithConstant 132 : public OpRewritePattern<TestCommutative2Op> { 133 public: 134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 135 136 LogicalResult matchAndRewrite(TestCommutative2Op op, 137 PatternRewriter &rewriter) const override { 138 auto operand = 139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 140 if (!operand) 141 return failure(); 142 Attribute constInput; 143 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 144 return failure(); 145 rewriter.replaceOp(op, operand->getOperand(1)); 146 return success(); 147 } 148 }; 149 150 struct TestPatternDriver 151 : public PassWrapper<TestPatternDriver, OperationPass<FuncOp>> { 152 StringRef getArgument() const final { return "test-patterns"; } 153 StringRef getDescription() const final { return "Run test dialect patterns"; } 154 void runOnOperation() override { 155 mlir::RewritePatternSet patterns(&getContext()); 156 populateWithGenerated(patterns); 157 158 // Verify named pattern is generated with expected name. 159 patterns.add<FoldingPattern, TestNamedPatternRule, 160 FolderInsertBeforePreviouslyFoldedConstantPattern, 161 FolderCommutativeOp2WithConstant>(&getContext()); 162 163 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 164 } 165 }; 166 } // namespace 167 168 //===----------------------------------------------------------------------===// 169 // ReturnType Driver. 170 //===----------------------------------------------------------------------===// 171 172 namespace { 173 // Generate ops for each instance where the type can be successfully inferred. 174 template <typename OpTy> 175 static void invokeCreateWithInferredReturnType(Operation *op) { 176 auto *context = op->getContext(); 177 auto fop = op->getParentOfType<FuncOp>(); 178 auto location = UnknownLoc::get(context); 179 OpBuilder b(op); 180 b.setInsertionPointAfter(op); 181 182 // Use permutations of 2 args as operands. 183 assert(fop.getNumArguments() >= 2); 184 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 185 for (int j = 0; j < e; ++j) { 186 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 187 SmallVector<Type, 2> inferredReturnTypes; 188 if (succeeded(OpTy::inferReturnTypes( 189 context, llvm::None, values, op->getAttrDictionary(), 190 op->getRegions(), inferredReturnTypes))) { 191 OperationState state(location, OpTy::getOperationName()); 192 // TODO: Expand to regions. 193 OpTy::build(b, state, values, op->getAttrs()); 194 (void)b.create(state); 195 } 196 } 197 } 198 } 199 200 static void reifyReturnShape(Operation *op) { 201 OpBuilder b(op); 202 203 // Use permutations of 2 args as operands. 204 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 205 SmallVector<Value, 2> shapes; 206 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 207 !llvm::hasSingleElement(shapes)) 208 return; 209 for (const auto &it : llvm::enumerate(shapes)) { 210 op->emitRemark() << "value " << it.index() << ": " 211 << it.value().getDefiningOp(); 212 } 213 } 214 215 struct TestReturnTypeDriver 216 : public PassWrapper<TestReturnTypeDriver, OperationPass<FuncOp>> { 217 void getDependentDialects(DialectRegistry ®istry) const override { 218 registry.insert<tensor::TensorDialect>(); 219 } 220 StringRef getArgument() const final { return "test-return-type"; } 221 StringRef getDescription() const final { return "Run return type functions"; } 222 223 void runOnOperation() override { 224 if (getOperation().getName() == "testCreateFunctions") { 225 std::vector<Operation *> ops; 226 // Collect ops to avoid triggering on inserted ops. 227 for (auto &op : getOperation().getBody().front()) 228 ops.push_back(&op); 229 // Generate test patterns for each, but skip terminator. 230 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 231 // Test create method of each of the Op classes below. The resultant 232 // output would be in reverse order underneath `op` from which 233 // the attributes and regions are used. 234 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 235 invokeCreateWithInferredReturnType< 236 OpWithShapedTypeInferTypeInterfaceOp>(op); 237 }; 238 return; 239 } 240 if (getOperation().getName() == "testReifyFunctions") { 241 std::vector<Operation *> ops; 242 // Collect ops to avoid triggering on inserted ops. 243 for (auto &op : getOperation().getBody().front()) 244 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 245 ops.push_back(&op); 246 // Generate test patterns for each, but skip terminator. 247 for (auto *op : ops) 248 reifyReturnShape(op); 249 } 250 } 251 }; 252 } // namespace 253 254 namespace { 255 struct TestDerivedAttributeDriver 256 : public PassWrapper<TestDerivedAttributeDriver, OperationPass<FuncOp>> { 257 StringRef getArgument() const final { return "test-derived-attr"; } 258 StringRef getDescription() const final { 259 return "Run test derived attributes"; 260 } 261 void runOnOperation() override; 262 }; 263 } // namespace 264 265 void TestDerivedAttributeDriver::runOnOperation() { 266 getOperation().walk([](DerivedAttributeOpInterface dOp) { 267 auto dAttr = dOp.materializeDerivedAttributes(); 268 if (!dAttr) 269 return; 270 for (auto d : dAttr) 271 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 272 }); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // Legalization Driver. 277 //===----------------------------------------------------------------------===// 278 279 namespace { 280 //===----------------------------------------------------------------------===// 281 // Region-Block Rewrite Testing 282 283 /// This pattern is a simple pattern that inlines the first region of a given 284 /// operation into the parent region. 285 struct TestRegionRewriteBlockMovement : public ConversionPattern { 286 TestRegionRewriteBlockMovement(MLIRContext *ctx) 287 : ConversionPattern("test.region", 1, ctx) {} 288 289 LogicalResult 290 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 291 ConversionPatternRewriter &rewriter) const final { 292 // Inline this region into the parent region. 293 auto &parentRegion = *op->getParentRegion(); 294 auto &opRegion = op->getRegion(0); 295 if (op->getAttr("legalizer.should_clone")) 296 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 297 else 298 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 299 300 if (op->getAttr("legalizer.erase_old_blocks")) { 301 while (!opRegion.empty()) 302 rewriter.eraseBlock(&opRegion.front()); 303 } 304 305 // Drop this operation. 306 rewriter.eraseOp(op); 307 return success(); 308 } 309 }; 310 /// This pattern is a simple pattern that generates a region containing an 311 /// illegal operation. 312 struct TestRegionRewriteUndo : public RewritePattern { 313 TestRegionRewriteUndo(MLIRContext *ctx) 314 : RewritePattern("test.region_builder", 1, ctx) {} 315 316 LogicalResult matchAndRewrite(Operation *op, 317 PatternRewriter &rewriter) const final { 318 // Create the region operation with an entry block containing arguments. 319 OperationState newRegion(op->getLoc(), "test.region"); 320 newRegion.addRegion(); 321 auto *regionOp = rewriter.create(newRegion); 322 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 323 entryBlock->addArgument(rewriter.getIntegerType(64), 324 rewriter.getUnknownLoc()); 325 326 // Add an explicitly illegal operation to ensure the conversion fails. 327 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 328 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 329 330 // Drop this operation. 331 rewriter.eraseOp(op); 332 return success(); 333 } 334 }; 335 /// A simple pattern that creates a block at the end of the parent region of the 336 /// matched operation. 337 struct TestCreateBlock : public RewritePattern { 338 TestCreateBlock(MLIRContext *ctx) 339 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 340 341 LogicalResult matchAndRewrite(Operation *op, 342 PatternRewriter &rewriter) const final { 343 Region ®ion = *op->getParentRegion(); 344 Type i32Type = rewriter.getIntegerType(32); 345 Location loc = op->getLoc(); 346 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 347 rewriter.create<TerminatorOp>(loc); 348 rewriter.replaceOp(op, {}); 349 return success(); 350 } 351 }; 352 353 /// A simple pattern that creates a block containing an invalid operation in 354 /// order to trigger the block creation undo mechanism. 355 struct TestCreateIllegalBlock : public RewritePattern { 356 TestCreateIllegalBlock(MLIRContext *ctx) 357 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 358 359 LogicalResult matchAndRewrite(Operation *op, 360 PatternRewriter &rewriter) const final { 361 Region ®ion = *op->getParentRegion(); 362 Type i32Type = rewriter.getIntegerType(32); 363 Location loc = op->getLoc(); 364 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 365 // Create an illegal op to ensure the conversion fails. 366 rewriter.create<ILLegalOpF>(loc, i32Type); 367 rewriter.create<TerminatorOp>(loc); 368 rewriter.replaceOp(op, {}); 369 return success(); 370 } 371 }; 372 373 /// A simple pattern that tests the undo mechanism when replacing the uses of a 374 /// block argument. 375 struct TestUndoBlockArgReplace : public ConversionPattern { 376 TestUndoBlockArgReplace(MLIRContext *ctx) 377 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 378 379 LogicalResult 380 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 381 ConversionPatternRewriter &rewriter) const final { 382 auto illegalOp = 383 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 384 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 385 illegalOp); 386 rewriter.updateRootInPlace(op, [] {}); 387 return success(); 388 } 389 }; 390 391 /// A rewrite pattern that tests the undo mechanism when erasing a block. 392 struct TestUndoBlockErase : public ConversionPattern { 393 TestUndoBlockErase(MLIRContext *ctx) 394 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 395 396 LogicalResult 397 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 398 ConversionPatternRewriter &rewriter) const final { 399 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 400 rewriter.setInsertionPointToStart(secondBlock); 401 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 402 rewriter.eraseBlock(secondBlock); 403 rewriter.updateRootInPlace(op, [] {}); 404 return success(); 405 } 406 }; 407 408 //===----------------------------------------------------------------------===// 409 // Type-Conversion Rewrite Testing 410 411 /// This patterns erases a region operation that has had a type conversion. 412 struct TestDropOpSignatureConversion : public ConversionPattern { 413 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 414 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 415 LogicalResult 416 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 417 ConversionPatternRewriter &rewriter) const override { 418 Region ®ion = op->getRegion(0); 419 Block *entry = ®ion.front(); 420 421 // Convert the original entry arguments. 422 TypeConverter &converter = *getTypeConverter(); 423 TypeConverter::SignatureConversion result(entry->getNumArguments()); 424 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 425 result)) || 426 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 427 return failure(); 428 429 // Convert the region signature and just drop the operation. 430 rewriter.eraseOp(op); 431 return success(); 432 } 433 }; 434 /// This pattern simply updates the operands of the given operation. 435 struct TestPassthroughInvalidOp : public ConversionPattern { 436 TestPassthroughInvalidOp(MLIRContext *ctx) 437 : ConversionPattern("test.invalid", 1, ctx) {} 438 LogicalResult 439 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 440 ConversionPatternRewriter &rewriter) const final { 441 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 442 llvm::None); 443 return success(); 444 } 445 }; 446 /// This pattern handles the case of a split return value. 447 struct TestSplitReturnType : public ConversionPattern { 448 TestSplitReturnType(MLIRContext *ctx) 449 : ConversionPattern("test.return", 1, ctx) {} 450 LogicalResult 451 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 452 ConversionPatternRewriter &rewriter) const final { 453 // Check for a return of F32. 454 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 455 return failure(); 456 457 // Check if the first operation is a cast operation, if it is we use the 458 // results directly. 459 auto *defOp = operands[0].getDefiningOp(); 460 if (auto packerOp = 461 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 462 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 463 return success(); 464 } 465 466 // Otherwise, fail to match. 467 return failure(); 468 } 469 }; 470 471 //===----------------------------------------------------------------------===// 472 // Multi-Level Type-Conversion Rewrite Testing 473 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 474 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 475 : ConversionPattern("test.type_producer", 1, ctx) {} 476 LogicalResult 477 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 478 ConversionPatternRewriter &rewriter) const final { 479 // If the type is I32, change the type to F32. 480 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 481 return failure(); 482 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 483 return success(); 484 } 485 }; 486 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 487 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 488 : ConversionPattern("test.type_producer", 1, ctx) {} 489 LogicalResult 490 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 491 ConversionPatternRewriter &rewriter) const final { 492 // If the type is F32, change the type to F64. 493 if (!Type(*op->result_type_begin()).isF32()) 494 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 495 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 496 return success(); 497 } 498 }; 499 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 500 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 501 : ConversionPattern("test.type_producer", 10, ctx) {} 502 LogicalResult 503 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 504 ConversionPatternRewriter &rewriter) const final { 505 // Always convert to B16, even though it is not a legal type. This tests 506 // that values are unmapped correctly. 507 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 508 return success(); 509 } 510 }; 511 struct TestUpdateConsumerType : public ConversionPattern { 512 TestUpdateConsumerType(MLIRContext *ctx) 513 : ConversionPattern("test.type_consumer", 1, ctx) {} 514 LogicalResult 515 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 516 ConversionPatternRewriter &rewriter) const final { 517 // Verify that the incoming operand has been successfully remapped to F64. 518 if (!operands[0].getType().isF64()) 519 return failure(); 520 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 521 return success(); 522 } 523 }; 524 525 //===----------------------------------------------------------------------===// 526 // Non-Root Replacement Rewrite Testing 527 /// This pattern generates an invalid operation, but replaces it before the 528 /// pattern is finished. This checks that we don't need to legalize the 529 /// temporary op. 530 struct TestNonRootReplacement : public RewritePattern { 531 TestNonRootReplacement(MLIRContext *ctx) 532 : RewritePattern("test.replace_non_root", 1, ctx) {} 533 534 LogicalResult matchAndRewrite(Operation *op, 535 PatternRewriter &rewriter) const final { 536 auto resultType = *op->result_type_begin(); 537 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 538 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 539 540 rewriter.replaceOp(illegalOp, {legalOp}); 541 rewriter.replaceOp(op, {illegalOp}); 542 return success(); 543 } 544 }; 545 546 //===----------------------------------------------------------------------===// 547 // Recursive Rewrite Testing 548 /// This pattern is applied to the same operation multiple times, but has a 549 /// bounded recursion. 550 struct TestBoundedRecursiveRewrite 551 : public OpRewritePattern<TestRecursiveRewriteOp> { 552 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 553 554 void initialize() { 555 // The conversion target handles bounding the recursion of this pattern. 556 setHasBoundedRewriteRecursion(); 557 } 558 559 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 560 PatternRewriter &rewriter) const final { 561 // Decrement the depth of the op in-place. 562 rewriter.updateRootInPlace(op, [&] { 563 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 564 }); 565 return success(); 566 } 567 }; 568 569 struct TestNestedOpCreationUndoRewrite 570 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 571 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 572 573 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 574 PatternRewriter &rewriter) const final { 575 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 576 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 577 return success(); 578 }; 579 }; 580 581 // This pattern matches `test.blackhole` and delete this op and its producer. 582 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 583 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 584 585 LogicalResult matchAndRewrite(BlackHoleOp op, 586 PatternRewriter &rewriter) const final { 587 Operation *producer = op.getOperand().getDefiningOp(); 588 // Always erase the user before the producer, the framework should handle 589 // this correctly. 590 rewriter.eraseOp(op); 591 rewriter.eraseOp(producer); 592 return success(); 593 }; 594 }; 595 596 // This pattern replaces explicitly illegal op with explicitly legal op, 597 // but in addition creates unregistered operation. 598 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 599 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 600 601 LogicalResult matchAndRewrite(ILLegalOpG op, 602 PatternRewriter &rewriter) const final { 603 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 604 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 605 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 606 return success(); 607 }; 608 }; 609 } // namespace 610 611 namespace { 612 struct TestTypeConverter : public TypeConverter { 613 using TypeConverter::TypeConverter; 614 TestTypeConverter() { 615 addConversion(convertType); 616 addArgumentMaterialization(materializeCast); 617 addSourceMaterialization(materializeCast); 618 } 619 620 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 621 // Drop I16 types. 622 if (t.isSignlessInteger(16)) 623 return success(); 624 625 // Convert I64 to F64. 626 if (t.isSignlessInteger(64)) { 627 results.push_back(FloatType::getF64(t.getContext())); 628 return success(); 629 } 630 631 // Convert I42 to I43. 632 if (t.isInteger(42)) { 633 results.push_back(IntegerType::get(t.getContext(), 43)); 634 return success(); 635 } 636 637 // Split F32 into F16,F16. 638 if (t.isF32()) { 639 results.assign(2, FloatType::getF16(t.getContext())); 640 return success(); 641 } 642 643 // Otherwise, convert the type directly. 644 results.push_back(t); 645 return success(); 646 } 647 648 /// Hook for materializing a conversion. This is necessary because we generate 649 /// 1->N type mappings. 650 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 651 ValueRange inputs, Location loc) { 652 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 653 } 654 }; 655 656 struct TestLegalizePatternDriver 657 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 658 StringRef getArgument() const final { return "test-legalize-patterns"; } 659 StringRef getDescription() const final { 660 return "Run test dialect legalization patterns"; 661 } 662 /// The mode of conversion to use with the driver. 663 enum class ConversionMode { Analysis, Full, Partial }; 664 665 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 666 667 void getDependentDialects(DialectRegistry ®istry) const override { 668 registry.insert<func::FuncDialect>(); 669 } 670 671 void runOnOperation() override { 672 TestTypeConverter converter; 673 mlir::RewritePatternSet patterns(&getContext()); 674 populateWithGenerated(patterns); 675 patterns 676 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 677 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 678 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 679 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 680 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 681 TestNonRootReplacement, TestBoundedRecursiveRewrite, 682 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 683 TestCreateUnregisteredOp>(&getContext()); 684 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 685 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 686 converter); 687 mlir::populateCallOpTypeConversionPattern(patterns, converter); 688 689 // Define the conversion target used for the test. 690 ConversionTarget target(getContext()); 691 target.addLegalOp<ModuleOp>(); 692 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 693 TerminatorOp>(); 694 target 695 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 696 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 697 // Don't allow F32 operands. 698 return llvm::none_of(op.getOperandTypes(), 699 [](Type type) { return type.isF32(); }); 700 }); 701 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 702 return converter.isSignatureLegal(op.getFunctionType()) && 703 converter.isLegal(&op.getBody()); 704 }); 705 target.addDynamicallyLegalOp<func::CallOp>( 706 [&](func::CallOp op) { return converter.isLegal(op); }); 707 708 // TestCreateUnregisteredOp creates `arith.constant` operation, 709 // which was not added to target intentionally to test 710 // correct error code from conversion driver. 711 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 712 713 // Expect the type_producer/type_consumer operations to only operate on f64. 714 target.addDynamicallyLegalOp<TestTypeProducerOp>( 715 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 716 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 717 return op.getOperand().getType().isF64(); 718 }); 719 720 // Check support for marking certain operations as recursively legal. 721 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 722 return static_cast<bool>( 723 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 724 }); 725 726 // Mark the bound recursion operation as dynamically legal. 727 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 728 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 729 730 // Handle a partial conversion. 731 if (mode == ConversionMode::Partial) { 732 DenseSet<Operation *> unlegalizedOps; 733 if (failed(applyPartialConversion( 734 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 735 getOperation()->emitRemark() << "applyPartialConversion failed"; 736 } 737 // Emit remarks for each legalizable operation. 738 for (auto *op : unlegalizedOps) 739 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 740 return; 741 } 742 743 // Handle a full conversion. 744 if (mode == ConversionMode::Full) { 745 // Check support for marking unknown operations as dynamically legal. 746 target.markUnknownOpDynamicallyLegal([](Operation *op) { 747 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 748 }); 749 750 if (failed(applyFullConversion(getOperation(), target, 751 std::move(patterns)))) { 752 getOperation()->emitRemark() << "applyFullConversion failed"; 753 } 754 return; 755 } 756 757 // Otherwise, handle an analysis conversion. 758 assert(mode == ConversionMode::Analysis); 759 760 // Analyze the convertible operations. 761 DenseSet<Operation *> legalizedOps; 762 if (failed(applyAnalysisConversion(getOperation(), target, 763 std::move(patterns), legalizedOps))) 764 return signalPassFailure(); 765 766 // Emit remarks for each legalizable operation. 767 for (auto *op : legalizedOps) 768 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 769 } 770 771 /// The mode of conversion to use. 772 ConversionMode mode; 773 }; 774 } // namespace 775 776 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 777 legalizerConversionMode( 778 "test-legalize-mode", 779 llvm::cl::desc("The legalization mode to use with the test driver"), 780 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 781 llvm::cl::values( 782 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 783 "analysis", "Perform an analysis conversion"), 784 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 785 "Perform a full conversion"), 786 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 787 "partial", "Perform a partial conversion"))); 788 789 //===----------------------------------------------------------------------===// 790 // ConversionPatternRewriter::getRemappedValue testing. This method is used 791 // to get the remapped value of an original value that was replaced using 792 // ConversionPatternRewriter. 793 namespace { 794 struct TestRemapValueTypeConverter : public TypeConverter { 795 using TypeConverter::TypeConverter; 796 797 TestRemapValueTypeConverter() { 798 addConversion( 799 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 800 addConversion([](Type type) { return type; }); 801 } 802 }; 803 804 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 805 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 806 /// operand twice. 807 /// 808 /// Example: 809 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 810 /// is replaced with: 811 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 812 struct OneVResOneVOperandOp1Converter 813 : public OpConversionPattern<OneVResOneVOperandOp1> { 814 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 815 816 LogicalResult 817 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 818 ConversionPatternRewriter &rewriter) const override { 819 auto origOps = op.getOperands(); 820 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 821 "One operand expected"); 822 Value origOp = *origOps.begin(); 823 SmallVector<Value, 2> remappedOperands; 824 // Replicate the remapped original operand twice. Note that we don't used 825 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 826 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 827 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 828 829 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 830 remappedOperands); 831 return success(); 832 } 833 }; 834 835 /// A rewriter pattern that tests that blocks can be merged. 836 struct TestRemapValueInRegion 837 : public OpConversionPattern<TestRemappedValueRegionOp> { 838 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 839 840 LogicalResult 841 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 842 ConversionPatternRewriter &rewriter) const final { 843 Block &block = op.getBody().front(); 844 Operation *terminator = block.getTerminator(); 845 846 // Merge the block into the parent region. 847 Block *parentBlock = op->getBlock(); 848 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 849 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 850 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 851 852 // Replace the results of this operation with the remapped terminator 853 // values. 854 SmallVector<Value> terminatorOperands; 855 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 856 terminatorOperands))) 857 return failure(); 858 859 rewriter.eraseOp(terminator); 860 rewriter.replaceOp(op, terminatorOperands); 861 return success(); 862 } 863 }; 864 865 struct TestRemappedValue 866 : public mlir::PassWrapper<TestRemappedValue, OperationPass<FuncOp>> { 867 StringRef getArgument() const final { return "test-remapped-value"; } 868 StringRef getDescription() const final { 869 return "Test public remapped value mechanism in ConversionPatternRewriter"; 870 } 871 void runOnOperation() override { 872 TestRemapValueTypeConverter typeConverter; 873 874 mlir::RewritePatternSet patterns(&getContext()); 875 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 876 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 877 &getContext()); 878 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 879 880 mlir::ConversionTarget target(getContext()); 881 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 882 883 // Expect the type_producer/type_consumer operations to only operate on f64. 884 target.addDynamicallyLegalOp<TestTypeProducerOp>( 885 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 886 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 887 return op.getOperand().getType().isF64(); 888 }); 889 890 // We make OneVResOneVOperandOp1 legal only when it has more that one 891 // operand. This will trigger the conversion that will replace one-operand 892 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 893 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 894 [](Operation *op) { return op->getNumOperands() > 1; }); 895 896 if (failed(mlir::applyFullConversion(getOperation(), target, 897 std::move(patterns)))) { 898 signalPassFailure(); 899 } 900 } 901 }; 902 } // namespace 903 904 //===----------------------------------------------------------------------===// 905 // Test patterns without a specific root operation kind 906 //===----------------------------------------------------------------------===// 907 908 namespace { 909 /// This pattern matches and removes any operation in the test dialect. 910 struct RemoveTestDialectOps : public RewritePattern { 911 RemoveTestDialectOps(MLIRContext *context) 912 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 913 914 LogicalResult matchAndRewrite(Operation *op, 915 PatternRewriter &rewriter) const override { 916 if (!isa<TestDialect>(op->getDialect())) 917 return failure(); 918 rewriter.eraseOp(op); 919 return success(); 920 } 921 }; 922 923 struct TestUnknownRootOpDriver 924 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<FuncOp>> { 925 StringRef getArgument() const final { 926 return "test-legalize-unknown-root-patterns"; 927 } 928 StringRef getDescription() const final { 929 return "Test public remapped value mechanism in ConversionPatternRewriter"; 930 } 931 void runOnOperation() override { 932 mlir::RewritePatternSet patterns(&getContext()); 933 patterns.add<RemoveTestDialectOps>(&getContext()); 934 935 mlir::ConversionTarget target(getContext()); 936 target.addIllegalDialect<TestDialect>(); 937 if (failed(applyPartialConversion(getOperation(), target, 938 std::move(patterns)))) 939 signalPassFailure(); 940 } 941 }; 942 } // namespace 943 944 //===----------------------------------------------------------------------===// 945 // Test type conversions 946 //===----------------------------------------------------------------------===// 947 948 namespace { 949 struct TestTypeConversionProducer 950 : public OpConversionPattern<TestTypeProducerOp> { 951 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 952 LogicalResult 953 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 954 ConversionPatternRewriter &rewriter) const final { 955 Type resultType = op.getType(); 956 Type convertedType = getTypeConverter() 957 ? getTypeConverter()->convertType(resultType) 958 : resultType; 959 if (resultType.isa<FloatType>()) 960 resultType = rewriter.getF64Type(); 961 else if (resultType.isInteger(16)) 962 resultType = rewriter.getIntegerType(64); 963 else if (resultType.isa<test::TestRecursiveType>() && 964 convertedType != resultType) 965 resultType = convertedType; 966 else 967 return failure(); 968 969 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 970 return success(); 971 } 972 }; 973 974 /// Call signature conversion and then fail the rewrite to trigger the undo 975 /// mechanism. 976 struct TestSignatureConversionUndo 977 : public OpConversionPattern<TestSignatureConversionUndoOp> { 978 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 979 980 LogicalResult 981 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 982 ConversionPatternRewriter &rewriter) const final { 983 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 984 return failure(); 985 } 986 }; 987 988 /// Call signature conversion without providing a type converter to handle 989 /// materializations. 990 struct TestTestSignatureConversionNoConverter 991 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 992 TestTestSignatureConversionNoConverter(TypeConverter &converter, 993 MLIRContext *context) 994 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 995 converter(converter) {} 996 997 LogicalResult 998 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 999 ConversionPatternRewriter &rewriter) const final { 1000 Region ®ion = op->getRegion(0); 1001 Block *entry = ®ion.front(); 1002 1003 // Convert the original entry arguments. 1004 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1005 if (failed( 1006 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1007 return failure(); 1008 rewriter.updateRootInPlace( 1009 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1010 return success(); 1011 } 1012 1013 TypeConverter &converter; 1014 }; 1015 1016 /// Just forward the operands to the root op. This is essentially a no-op 1017 /// pattern that is used to trigger target materialization. 1018 struct TestTypeConsumerForward 1019 : public OpConversionPattern<TestTypeConsumerOp> { 1020 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1021 1022 LogicalResult 1023 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1024 ConversionPatternRewriter &rewriter) const final { 1025 rewriter.updateRootInPlace(op, 1026 [&] { op->setOperands(adaptor.getOperands()); }); 1027 return success(); 1028 } 1029 }; 1030 1031 struct TestTypeConversionAnotherProducer 1032 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1033 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1034 1035 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1036 PatternRewriter &rewriter) const final { 1037 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1038 return success(); 1039 } 1040 }; 1041 1042 struct TestTypeConversionDriver 1043 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1044 void getDependentDialects(DialectRegistry ®istry) const override { 1045 registry.insert<TestDialect>(); 1046 } 1047 StringRef getArgument() const final { 1048 return "test-legalize-type-conversion"; 1049 } 1050 StringRef getDescription() const final { 1051 return "Test various type conversion functionalities in DialectConversion"; 1052 } 1053 1054 void runOnOperation() override { 1055 // Initialize the type converter. 1056 TypeConverter converter; 1057 1058 /// Add the legal set of type conversions. 1059 converter.addConversion([](Type type) -> Type { 1060 // Treat F64 as legal. 1061 if (type.isF64()) 1062 return type; 1063 // Allow converting BF16/F16/F32 to F64. 1064 if (type.isBF16() || type.isF16() || type.isF32()) 1065 return FloatType::getF64(type.getContext()); 1066 // Otherwise, the type is illegal. 1067 return nullptr; 1068 }); 1069 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1070 // Drop all integer types. 1071 return success(); 1072 }); 1073 converter.addConversion( 1074 // Convert a recursive self-referring type into a non-self-referring 1075 // type named "outer_converted_type" that contains a SimpleAType. 1076 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1077 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1078 // If the type is already converted, return it to indicate that it is 1079 // legal. 1080 if (type.getName() == "outer_converted_type") { 1081 results.push_back(type); 1082 return success(); 1083 } 1084 1085 // If the type is on the call stack more than once (it is there at 1086 // least once because of the _current_ call, which is always the last 1087 // element on the stack), we've hit the recursive case. Just return 1088 // SimpleAType here to create a non-recursive type as a result. 1089 if (llvm::is_contained(callStack.drop_back(), type)) { 1090 results.push_back(test::SimpleAType::get(type.getContext())); 1091 return success(); 1092 } 1093 1094 // Convert the body recursively. 1095 auto result = test::TestRecursiveType::get(type.getContext(), 1096 "outer_converted_type"); 1097 if (failed(result.setBody(converter.convertType(type.getBody())))) 1098 return failure(); 1099 results.push_back(result); 1100 return success(); 1101 }); 1102 1103 /// Add the legal set of type materializations. 1104 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1105 ValueRange inputs, 1106 Location loc) -> Value { 1107 // Allow casting from F64 back to F32. 1108 if (!resultType.isF16() && inputs.size() == 1 && 1109 inputs[0].getType().isF64()) 1110 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1111 // Allow producing an i32 or i64 from nothing. 1112 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1113 inputs.empty()) 1114 return builder.create<TestTypeProducerOp>(loc, resultType); 1115 // Allow producing an i64 from an integer. 1116 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1117 inputs[0].getType().isa<IntegerType>()) 1118 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1119 // Otherwise, fail. 1120 return nullptr; 1121 }); 1122 1123 // Initialize the conversion target. 1124 mlir::ConversionTarget target(getContext()); 1125 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1126 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1127 return op.getType().isF64() || op.getType().isInteger(64) || 1128 (recursiveType && 1129 recursiveType.getName() == "outer_converted_type"); 1130 }); 1131 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 1132 return converter.isSignatureLegal(op.getFunctionType()) && 1133 converter.isLegal(&op.getBody()); 1134 }); 1135 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1136 // Allow casts from F64 to F32. 1137 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1138 }); 1139 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1140 [&](TestSignatureConversionNoConverterOp op) { 1141 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1142 }); 1143 1144 // Initialize the set of rewrite patterns. 1145 RewritePatternSet patterns(&getContext()); 1146 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1147 TestSignatureConversionUndo, 1148 TestTestSignatureConversionNoConverter>(converter, 1149 &getContext()); 1150 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1151 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 1152 converter); 1153 1154 if (failed(applyPartialConversion(getOperation(), target, 1155 std::move(patterns)))) 1156 signalPassFailure(); 1157 } 1158 }; 1159 } // namespace 1160 1161 //===----------------------------------------------------------------------===// 1162 // Test Target Materialization With No Uses 1163 //===----------------------------------------------------------------------===// 1164 1165 namespace { 1166 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1167 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1168 1169 LogicalResult 1170 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1171 ConversionPatternRewriter &rewriter) const final { 1172 rewriter.replaceOp(op, adaptor.getOperands()); 1173 return success(); 1174 } 1175 }; 1176 1177 struct TestTargetMaterializationWithNoUses 1178 : public PassWrapper<TestTargetMaterializationWithNoUses, 1179 OperationPass<ModuleOp>> { 1180 StringRef getArgument() const final { 1181 return "test-target-materialization-with-no-uses"; 1182 } 1183 StringRef getDescription() const final { 1184 return "Test a special case of target materialization in DialectConversion"; 1185 } 1186 1187 void runOnOperation() override { 1188 TypeConverter converter; 1189 converter.addConversion([](Type t) { return t; }); 1190 converter.addConversion([](IntegerType intTy) -> Type { 1191 if (intTy.getWidth() == 16) 1192 return IntegerType::get(intTy.getContext(), 64); 1193 return intTy; 1194 }); 1195 converter.addTargetMaterialization( 1196 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1197 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1198 }); 1199 1200 ConversionTarget target(getContext()); 1201 target.addIllegalOp<TestTypeChangerOp>(); 1202 1203 RewritePatternSet patterns(&getContext()); 1204 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1205 1206 if (failed(applyPartialConversion(getOperation(), target, 1207 std::move(patterns)))) 1208 signalPassFailure(); 1209 } 1210 }; 1211 } // namespace 1212 1213 //===----------------------------------------------------------------------===// 1214 // Test Block Merging 1215 //===----------------------------------------------------------------------===// 1216 1217 namespace { 1218 /// A rewriter pattern that tests that blocks can be merged. 1219 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1220 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1221 1222 LogicalResult 1223 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1224 ConversionPatternRewriter &rewriter) const final { 1225 Block &firstBlock = op.getBody().front(); 1226 Operation *branchOp = firstBlock.getTerminator(); 1227 Block *secondBlock = &*(std::next(op.getBody().begin())); 1228 auto succOperands = branchOp->getOperands(); 1229 SmallVector<Value, 2> replacements(succOperands); 1230 rewriter.eraseOp(branchOp); 1231 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1232 rewriter.updateRootInPlace(op, [] {}); 1233 return success(); 1234 } 1235 }; 1236 1237 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1238 struct TestUndoBlocksMerge : public ConversionPattern { 1239 TestUndoBlocksMerge(MLIRContext *ctx) 1240 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1241 LogicalResult 1242 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1243 ConversionPatternRewriter &rewriter) const final { 1244 Block &firstBlock = op->getRegion(0).front(); 1245 Operation *branchOp = firstBlock.getTerminator(); 1246 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1247 rewriter.setInsertionPointToStart(secondBlock); 1248 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1249 auto succOperands = branchOp->getOperands(); 1250 SmallVector<Value, 2> replacements(succOperands); 1251 rewriter.eraseOp(branchOp); 1252 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1253 rewriter.updateRootInPlace(op, [] {}); 1254 return success(); 1255 } 1256 }; 1257 1258 /// A rewrite mechanism to inline the body of the op into its parent, when both 1259 /// ops can have a single block. 1260 struct TestMergeSingleBlockOps 1261 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1262 using OpConversionPattern< 1263 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1264 1265 LogicalResult 1266 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1267 ConversionPatternRewriter &rewriter) const final { 1268 SingleBlockImplicitTerminatorOp parentOp = 1269 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1270 if (!parentOp) 1271 return failure(); 1272 Block &innerBlock = op.getRegion().front(); 1273 TerminatorOp innerTerminator = 1274 cast<TerminatorOp>(innerBlock.getTerminator()); 1275 rewriter.mergeBlockBefore(&innerBlock, op); 1276 rewriter.eraseOp(innerTerminator); 1277 rewriter.eraseOp(op); 1278 rewriter.updateRootInPlace(op, [] {}); 1279 return success(); 1280 } 1281 }; 1282 1283 struct TestMergeBlocksPatternDriver 1284 : public PassWrapper<TestMergeBlocksPatternDriver, 1285 OperationPass<ModuleOp>> { 1286 StringRef getArgument() const final { return "test-merge-blocks"; } 1287 StringRef getDescription() const final { 1288 return "Test Merging operation in ConversionPatternRewriter"; 1289 } 1290 void runOnOperation() override { 1291 MLIRContext *context = &getContext(); 1292 mlir::RewritePatternSet patterns(context); 1293 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1294 context); 1295 ConversionTarget target(*context); 1296 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1297 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1298 target.addIllegalOp<ILLegalOpF>(); 1299 1300 /// Expect the op to have a single block after legalization. 1301 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1302 [&](TestMergeBlocksOp op) -> bool { 1303 return llvm::hasSingleElement(op.getBody()); 1304 }); 1305 1306 /// Only allow `test.br` within test.merge_blocks op. 1307 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1308 return op->getParentOfType<TestMergeBlocksOp>(); 1309 }); 1310 1311 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1312 /// inlined. 1313 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1314 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1315 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1316 }); 1317 1318 DenseSet<Operation *> unlegalizedOps; 1319 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1320 &unlegalizedOps); 1321 for (auto *op : unlegalizedOps) 1322 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1323 } 1324 }; 1325 } // namespace 1326 1327 //===----------------------------------------------------------------------===// 1328 // Test Selective Replacement 1329 //===----------------------------------------------------------------------===// 1330 1331 namespace { 1332 /// A rewrite mechanism to inline the body of the op into its parent, when both 1333 /// ops can have a single block. 1334 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1335 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1336 1337 LogicalResult matchAndRewrite(TestCastOp op, 1338 PatternRewriter &rewriter) const final { 1339 if (op.getNumOperands() != 2) 1340 return failure(); 1341 OperandRange operands = op.getOperands(); 1342 1343 // Replace non-terminator uses with the first operand. 1344 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1345 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1346 }); 1347 // Replace everything else with the second operand if the operation isn't 1348 // dead. 1349 rewriter.replaceOp(op, op.getOperand(1)); 1350 return success(); 1351 } 1352 }; 1353 1354 struct TestSelectiveReplacementPatternDriver 1355 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1356 OperationPass<>> { 1357 StringRef getArgument() const final { 1358 return "test-pattern-selective-replacement"; 1359 } 1360 StringRef getDescription() const final { 1361 return "Test selective replacement in the PatternRewriter"; 1362 } 1363 void runOnOperation() override { 1364 MLIRContext *context = &getContext(); 1365 mlir::RewritePatternSet patterns(context); 1366 patterns.add<TestSelectiveOpReplacementPattern>(context); 1367 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1368 std::move(patterns)); 1369 } 1370 }; 1371 } // namespace 1372 1373 //===----------------------------------------------------------------------===// 1374 // PassRegistration 1375 //===----------------------------------------------------------------------===// 1376 1377 namespace mlir { 1378 namespace test { 1379 void registerPatternsTestPass() { 1380 PassRegistration<TestReturnTypeDriver>(); 1381 1382 PassRegistration<TestDerivedAttributeDriver>(); 1383 1384 PassRegistration<TestPatternDriver>(); 1385 1386 PassRegistration<TestLegalizePatternDriver>([] { 1387 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1388 }); 1389 1390 PassRegistration<TestRemappedValue>(); 1391 1392 PassRegistration<TestUnknownRootOpDriver>(); 1393 1394 PassRegistration<TestTypeConversionDriver>(); 1395 PassRegistration<TestTargetMaterializationWithNoUses>(); 1396 1397 PassRegistration<TestMergeBlocksPatternDriver>(); 1398 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1399 } 1400 } // namespace test 1401 } // namespace mlir 1402