1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestTypes.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/IR/Matchers.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/FoldUtils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace test; 23 24 // Native function for testing NativeCodeCall 25 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 30 rewriter.create<OpI>(loc, input); 31 } 32 33 static void handleNoResultOp(PatternRewriter &rewriter, 34 OpSymbolBindingNoResult op) { 35 // Turn the no result op to a one-result op. 36 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 37 op.getOperand()); 38 } 39 40 static bool getFirstI32Result(Operation *op, Value &value) { 41 if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 42 return false; 43 value = op->getResult(0); 44 return true; 45 } 46 47 static Value bindNativeCodeCallResult(Value value) { return value; } 48 49 static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 50 Value input2) { 51 return SmallVector<Value, 2>({input2, input1}); 52 } 53 54 // Test that natives calls are only called once during rewrites. 55 // OpM_Test will return Pi, increased by 1 for each subsequent calls. 56 // This let us check the number of times OpM_Test was called by inspecting 57 // the returned value in the MLIR output. 58 static int64_t opMIncreasingValue = 314159265; 59 static Attribute opMTest(PatternRewriter &rewriter, Value val) { 60 int64_t i = opMIncreasingValue++; 61 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 62 } 63 64 namespace { 65 #include "TestPatterns.inc" 66 } // namespace 67 68 //===----------------------------------------------------------------------===// 69 // Test Reduce Pattern Interface 70 //===----------------------------------------------------------------------===// 71 72 void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 73 populateWithGenerated(patterns); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Canonicalizer Driver. 78 //===----------------------------------------------------------------------===// 79 80 namespace { 81 struct FoldingPattern : public RewritePattern { 82 public: 83 FoldingPattern(MLIRContext *context) 84 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 85 /*benefit=*/1, context) {} 86 87 LogicalResult matchAndRewrite(Operation *op, 88 PatternRewriter &rewriter) const override { 89 // Exercise OperationFolder API for a single-result operation that is folded 90 // upon construction. The operation being created through the folder has an 91 // in-place folder, and it should be still present in the output. 92 // Furthermore, the folder should not crash when attempting to recover the 93 // (unchanged) operation result. 94 OperationFolder folder(op->getContext()); 95 Value result = folder.create<TestOpInPlaceFold>( 96 rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 97 rewriter.getI32IntegerAttr(0)); 98 assert(result); 99 rewriter.replaceOp(op, result); 100 return success(); 101 } 102 }; 103 104 /// This pattern creates a foldable operation at the entry point of the block. 105 /// This tests the situation where the operation folder will need to replace an 106 /// operation with a previously created constant that does not initially 107 /// dominate the operation to replace. 108 struct FolderInsertBeforePreviouslyFoldedConstantPattern 109 : public OpRewritePattern<TestCastOp> { 110 public: 111 using OpRewritePattern<TestCastOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(TestCastOp op, 114 PatternRewriter &rewriter) const override { 115 if (!op->hasAttr("test_fold_before_previously_folded_op")) 116 return failure(); 117 rewriter.setInsertionPointToStart(op->getBlock()); 118 119 auto constOp = rewriter.create<arith::ConstantOp>( 120 op.getLoc(), rewriter.getBoolAttr(true)); 121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 122 Value(constOp)); 123 return success(); 124 } 125 }; 126 127 /// This pattern matches test.op_commutative2 with the first operand being 128 /// another test.op_commutative2 with a constant on the right side and fold it 129 /// away by propagating it as its result. This is intend to check that patterns 130 /// are applied after the commutative property moves constant to the right. 131 struct FolderCommutativeOp2WithConstant 132 : public OpRewritePattern<TestCommutative2Op> { 133 public: 134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 135 136 LogicalResult matchAndRewrite(TestCommutative2Op op, 137 PatternRewriter &rewriter) const override { 138 auto operand = 139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 140 if (!operand) 141 return failure(); 142 Attribute constInput; 143 if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 144 return failure(); 145 rewriter.replaceOp(op, operand->getOperand(1)); 146 return success(); 147 } 148 }; 149 150 struct TestPatternDriver 151 : public PassWrapper<TestPatternDriver, OperationPass<FuncOp>> { 152 StringRef getArgument() const final { return "test-patterns"; } 153 StringRef getDescription() const final { return "Run test dialect patterns"; } 154 void runOnOperation() override { 155 mlir::RewritePatternSet patterns(&getContext()); 156 157 // Verify named pattern is generated with expected name. 158 patterns.add<FoldingPattern, TestNamedPatternRule, 159 FolderInsertBeforePreviouslyFoldedConstantPattern, 160 FolderCommutativeOp2WithConstant>(&getContext()); 161 162 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 163 } 164 }; 165 } // namespace 166 167 //===----------------------------------------------------------------------===// 168 // ReturnType Driver. 169 //===----------------------------------------------------------------------===// 170 171 namespace { 172 // Generate ops for each instance where the type can be successfully inferred. 173 template <typename OpTy> 174 static void invokeCreateWithInferredReturnType(Operation *op) { 175 auto *context = op->getContext(); 176 auto fop = op->getParentOfType<FuncOp>(); 177 auto location = UnknownLoc::get(context); 178 OpBuilder b(op); 179 b.setInsertionPointAfter(op); 180 181 // Use permutations of 2 args as operands. 182 assert(fop.getNumArguments() >= 2); 183 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 184 for (int j = 0; j < e; ++j) { 185 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 186 SmallVector<Type, 2> inferredReturnTypes; 187 if (succeeded(OpTy::inferReturnTypes( 188 context, llvm::None, values, op->getAttrDictionary(), 189 op->getRegions(), inferredReturnTypes))) { 190 OperationState state(location, OpTy::getOperationName()); 191 // TODO: Expand to regions. 192 OpTy::build(b, state, values, op->getAttrs()); 193 (void)b.create(state); 194 } 195 } 196 } 197 } 198 199 static void reifyReturnShape(Operation *op) { 200 OpBuilder b(op); 201 202 // Use permutations of 2 args as operands. 203 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 204 SmallVector<Value, 2> shapes; 205 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 206 !llvm::hasSingleElement(shapes)) 207 return; 208 for (const auto &it : llvm::enumerate(shapes)) { 209 op->emitRemark() << "value " << it.index() << ": " 210 << it.value().getDefiningOp(); 211 } 212 } 213 214 struct TestReturnTypeDriver 215 : public PassWrapper<TestReturnTypeDriver, OperationPass<FuncOp>> { 216 void getDependentDialects(DialectRegistry ®istry) const override { 217 registry.insert<tensor::TensorDialect>(); 218 } 219 StringRef getArgument() const final { return "test-return-type"; } 220 StringRef getDescription() const final { return "Run return type functions"; } 221 222 void runOnOperation() override { 223 if (getOperation().getName() == "testCreateFunctions") { 224 std::vector<Operation *> ops; 225 // Collect ops to avoid triggering on inserted ops. 226 for (auto &op : getOperation().getBody().front()) 227 ops.push_back(&op); 228 // Generate test patterns for each, but skip terminator. 229 for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 230 // Test create method of each of the Op classes below. The resultant 231 // output would be in reverse order underneath `op` from which 232 // the attributes and regions are used. 233 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 234 invokeCreateWithInferredReturnType< 235 OpWithShapedTypeInferTypeInterfaceOp>(op); 236 }; 237 return; 238 } 239 if (getOperation().getName() == "testReifyFunctions") { 240 std::vector<Operation *> ops; 241 // Collect ops to avoid triggering on inserted ops. 242 for (auto &op : getOperation().getBody().front()) 243 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 244 ops.push_back(&op); 245 // Generate test patterns for each, but skip terminator. 246 for (auto *op : ops) 247 reifyReturnShape(op); 248 } 249 } 250 }; 251 } // namespace 252 253 namespace { 254 struct TestDerivedAttributeDriver 255 : public PassWrapper<TestDerivedAttributeDriver, OperationPass<FuncOp>> { 256 StringRef getArgument() const final { return "test-derived-attr"; } 257 StringRef getDescription() const final { 258 return "Run test derived attributes"; 259 } 260 void runOnOperation() override; 261 }; 262 } // namespace 263 264 void TestDerivedAttributeDriver::runOnOperation() { 265 getOperation().walk([](DerivedAttributeOpInterface dOp) { 266 auto dAttr = dOp.materializeDerivedAttributes(); 267 if (!dAttr) 268 return; 269 for (auto d : dAttr) 270 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 271 }); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // Legalization Driver. 276 //===----------------------------------------------------------------------===// 277 278 namespace { 279 //===----------------------------------------------------------------------===// 280 // Region-Block Rewrite Testing 281 282 /// This pattern is a simple pattern that inlines the first region of a given 283 /// operation into the parent region. 284 struct TestRegionRewriteBlockMovement : public ConversionPattern { 285 TestRegionRewriteBlockMovement(MLIRContext *ctx) 286 : ConversionPattern("test.region", 1, ctx) {} 287 288 LogicalResult 289 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 290 ConversionPatternRewriter &rewriter) const final { 291 // Inline this region into the parent region. 292 auto &parentRegion = *op->getParentRegion(); 293 auto &opRegion = op->getRegion(0); 294 if (op->getAttr("legalizer.should_clone")) 295 rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 296 else 297 rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 298 299 if (op->getAttr("legalizer.erase_old_blocks")) { 300 while (!opRegion.empty()) 301 rewriter.eraseBlock(&opRegion.front()); 302 } 303 304 // Drop this operation. 305 rewriter.eraseOp(op); 306 return success(); 307 } 308 }; 309 /// This pattern is a simple pattern that generates a region containing an 310 /// illegal operation. 311 struct TestRegionRewriteUndo : public RewritePattern { 312 TestRegionRewriteUndo(MLIRContext *ctx) 313 : RewritePattern("test.region_builder", 1, ctx) {} 314 315 LogicalResult matchAndRewrite(Operation *op, 316 PatternRewriter &rewriter) const final { 317 // Create the region operation with an entry block containing arguments. 318 OperationState newRegion(op->getLoc(), "test.region"); 319 newRegion.addRegion(); 320 auto *regionOp = rewriter.create(newRegion); 321 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 322 entryBlock->addArgument(rewriter.getIntegerType(64), 323 rewriter.getUnknownLoc()); 324 325 // Add an explicitly illegal operation to ensure the conversion fails. 326 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 327 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 328 329 // Drop this operation. 330 rewriter.eraseOp(op); 331 return success(); 332 } 333 }; 334 /// A simple pattern that creates a block at the end of the parent region of the 335 /// matched operation. 336 struct TestCreateBlock : public RewritePattern { 337 TestCreateBlock(MLIRContext *ctx) 338 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 339 340 LogicalResult matchAndRewrite(Operation *op, 341 PatternRewriter &rewriter) const final { 342 Region ®ion = *op->getParentRegion(); 343 Type i32Type = rewriter.getIntegerType(32); 344 Location loc = op->getLoc(); 345 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 346 rewriter.create<TerminatorOp>(loc); 347 rewriter.replaceOp(op, {}); 348 return success(); 349 } 350 }; 351 352 /// A simple pattern that creates a block containing an invalid operation in 353 /// order to trigger the block creation undo mechanism. 354 struct TestCreateIllegalBlock : public RewritePattern { 355 TestCreateIllegalBlock(MLIRContext *ctx) 356 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 357 358 LogicalResult matchAndRewrite(Operation *op, 359 PatternRewriter &rewriter) const final { 360 Region ®ion = *op->getParentRegion(); 361 Type i32Type = rewriter.getIntegerType(32); 362 Location loc = op->getLoc(); 363 rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 364 // Create an illegal op to ensure the conversion fails. 365 rewriter.create<ILLegalOpF>(loc, i32Type); 366 rewriter.create<TerminatorOp>(loc); 367 rewriter.replaceOp(op, {}); 368 return success(); 369 } 370 }; 371 372 /// A simple pattern that tests the undo mechanism when replacing the uses of a 373 /// block argument. 374 struct TestUndoBlockArgReplace : public ConversionPattern { 375 TestUndoBlockArgReplace(MLIRContext *ctx) 376 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 377 378 LogicalResult 379 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 380 ConversionPatternRewriter &rewriter) const final { 381 auto illegalOp = 382 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 383 rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 384 illegalOp); 385 rewriter.updateRootInPlace(op, [] {}); 386 return success(); 387 } 388 }; 389 390 /// A rewrite pattern that tests the undo mechanism when erasing a block. 391 struct TestUndoBlockErase : public ConversionPattern { 392 TestUndoBlockErase(MLIRContext *ctx) 393 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 394 395 LogicalResult 396 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 397 ConversionPatternRewriter &rewriter) const final { 398 Block *secondBlock = &*std::next(op->getRegion(0).begin()); 399 rewriter.setInsertionPointToStart(secondBlock); 400 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 401 rewriter.eraseBlock(secondBlock); 402 rewriter.updateRootInPlace(op, [] {}); 403 return success(); 404 } 405 }; 406 407 //===----------------------------------------------------------------------===// 408 // Type-Conversion Rewrite Testing 409 410 /// This patterns erases a region operation that has had a type conversion. 411 struct TestDropOpSignatureConversion : public ConversionPattern { 412 TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 413 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 414 LogicalResult 415 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 416 ConversionPatternRewriter &rewriter) const override { 417 Region ®ion = op->getRegion(0); 418 Block *entry = ®ion.front(); 419 420 // Convert the original entry arguments. 421 TypeConverter &converter = *getTypeConverter(); 422 TypeConverter::SignatureConversion result(entry->getNumArguments()); 423 if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 424 result)) || 425 failed(rewriter.convertRegionTypes(®ion, converter, &result))) 426 return failure(); 427 428 // Convert the region signature and just drop the operation. 429 rewriter.eraseOp(op); 430 return success(); 431 } 432 }; 433 /// This pattern simply updates the operands of the given operation. 434 struct TestPassthroughInvalidOp : public ConversionPattern { 435 TestPassthroughInvalidOp(MLIRContext *ctx) 436 : ConversionPattern("test.invalid", 1, ctx) {} 437 LogicalResult 438 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 439 ConversionPatternRewriter &rewriter) const final { 440 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 441 llvm::None); 442 return success(); 443 } 444 }; 445 /// This pattern handles the case of a split return value. 446 struct TestSplitReturnType : public ConversionPattern { 447 TestSplitReturnType(MLIRContext *ctx) 448 : ConversionPattern("test.return", 1, ctx) {} 449 LogicalResult 450 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 451 ConversionPatternRewriter &rewriter) const final { 452 // Check for a return of F32. 453 if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 454 return failure(); 455 456 // Check if the first operation is a cast operation, if it is we use the 457 // results directly. 458 auto *defOp = operands[0].getDefiningOp(); 459 if (auto packerOp = 460 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { 461 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 462 return success(); 463 } 464 465 // Otherwise, fail to match. 466 return failure(); 467 } 468 }; 469 470 //===----------------------------------------------------------------------===// 471 // Multi-Level Type-Conversion Rewrite Testing 472 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 473 TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 474 : ConversionPattern("test.type_producer", 1, ctx) {} 475 LogicalResult 476 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 477 ConversionPatternRewriter &rewriter) const final { 478 // If the type is I32, change the type to F32. 479 if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 480 return failure(); 481 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 482 return success(); 483 } 484 }; 485 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 486 TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 487 : ConversionPattern("test.type_producer", 1, ctx) {} 488 LogicalResult 489 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 490 ConversionPatternRewriter &rewriter) const final { 491 // If the type is F32, change the type to F64. 492 if (!Type(*op->result_type_begin()).isF32()) 493 return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 494 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 495 return success(); 496 } 497 }; 498 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 499 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 500 : ConversionPattern("test.type_producer", 10, ctx) {} 501 LogicalResult 502 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 503 ConversionPatternRewriter &rewriter) const final { 504 // Always convert to B16, even though it is not a legal type. This tests 505 // that values are unmapped correctly. 506 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 507 return success(); 508 } 509 }; 510 struct TestUpdateConsumerType : public ConversionPattern { 511 TestUpdateConsumerType(MLIRContext *ctx) 512 : ConversionPattern("test.type_consumer", 1, ctx) {} 513 LogicalResult 514 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 515 ConversionPatternRewriter &rewriter) const final { 516 // Verify that the incoming operand has been successfully remapped to F64. 517 if (!operands[0].getType().isF64()) 518 return failure(); 519 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 520 return success(); 521 } 522 }; 523 524 //===----------------------------------------------------------------------===// 525 // Non-Root Replacement Rewrite Testing 526 /// This pattern generates an invalid operation, but replaces it before the 527 /// pattern is finished. This checks that we don't need to legalize the 528 /// temporary op. 529 struct TestNonRootReplacement : public RewritePattern { 530 TestNonRootReplacement(MLIRContext *ctx) 531 : RewritePattern("test.replace_non_root", 1, ctx) {} 532 533 LogicalResult matchAndRewrite(Operation *op, 534 PatternRewriter &rewriter) const final { 535 auto resultType = *op->result_type_begin(); 536 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 537 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 538 539 rewriter.replaceOp(illegalOp, {legalOp}); 540 rewriter.replaceOp(op, {illegalOp}); 541 return success(); 542 } 543 }; 544 545 //===----------------------------------------------------------------------===// 546 // Recursive Rewrite Testing 547 /// This pattern is applied to the same operation multiple times, but has a 548 /// bounded recursion. 549 struct TestBoundedRecursiveRewrite 550 : public OpRewritePattern<TestRecursiveRewriteOp> { 551 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 552 553 void initialize() { 554 // The conversion target handles bounding the recursion of this pattern. 555 setHasBoundedRewriteRecursion(); 556 } 557 558 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 559 PatternRewriter &rewriter) const final { 560 // Decrement the depth of the op in-place. 561 rewriter.updateRootInPlace(op, [&] { 562 op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 563 }); 564 return success(); 565 } 566 }; 567 568 struct TestNestedOpCreationUndoRewrite 569 : public OpRewritePattern<IllegalOpWithRegionAnchor> { 570 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 571 572 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 573 PatternRewriter &rewriter) const final { 574 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 575 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 576 return success(); 577 }; 578 }; 579 580 // This pattern matches `test.blackhole` and delete this op and its producer. 581 struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 582 using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 583 584 LogicalResult matchAndRewrite(BlackHoleOp op, 585 PatternRewriter &rewriter) const final { 586 Operation *producer = op.getOperand().getDefiningOp(); 587 // Always erase the user before the producer, the framework should handle 588 // this correctly. 589 rewriter.eraseOp(op); 590 rewriter.eraseOp(producer); 591 return success(); 592 }; 593 }; 594 595 // This pattern replaces explicitly illegal op with explicitly legal op, 596 // but in addition creates unregistered operation. 597 struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 598 using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 599 600 LogicalResult matchAndRewrite(ILLegalOpG op, 601 PatternRewriter &rewriter) const final { 602 IntegerAttr attr = rewriter.getI32IntegerAttr(0); 603 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 604 rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 605 return success(); 606 }; 607 }; 608 } // namespace 609 610 namespace { 611 struct TestTypeConverter : public TypeConverter { 612 using TypeConverter::TypeConverter; 613 TestTypeConverter() { 614 addConversion(convertType); 615 addArgumentMaterialization(materializeCast); 616 addSourceMaterialization(materializeCast); 617 } 618 619 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 620 // Drop I16 types. 621 if (t.isSignlessInteger(16)) 622 return success(); 623 624 // Convert I64 to F64. 625 if (t.isSignlessInteger(64)) { 626 results.push_back(FloatType::getF64(t.getContext())); 627 return success(); 628 } 629 630 // Convert I42 to I43. 631 if (t.isInteger(42)) { 632 results.push_back(IntegerType::get(t.getContext(), 43)); 633 return success(); 634 } 635 636 // Split F32 into F16,F16. 637 if (t.isF32()) { 638 results.assign(2, FloatType::getF16(t.getContext())); 639 return success(); 640 } 641 642 // Otherwise, convert the type directly. 643 results.push_back(t); 644 return success(); 645 } 646 647 /// Hook for materializing a conversion. This is necessary because we generate 648 /// 1->N type mappings. 649 static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 650 ValueRange inputs, Location loc) { 651 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 652 } 653 }; 654 655 struct TestLegalizePatternDriver 656 : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 657 StringRef getArgument() const final { return "test-legalize-patterns"; } 658 StringRef getDescription() const final { 659 return "Run test dialect legalization patterns"; 660 } 661 /// The mode of conversion to use with the driver. 662 enum class ConversionMode { Analysis, Full, Partial }; 663 664 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 665 666 void getDependentDialects(DialectRegistry ®istry) const override { 667 registry.insert<func::FuncDialect>(); 668 } 669 670 void runOnOperation() override { 671 TestTypeConverter converter; 672 mlir::RewritePatternSet patterns(&getContext()); 673 populateWithGenerated(patterns); 674 patterns 675 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 676 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 677 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 678 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 679 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 680 TestNonRootReplacement, TestBoundedRecursiveRewrite, 681 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 682 TestCreateUnregisteredOp>(&getContext()); 683 patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 684 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 685 converter); 686 mlir::populateCallOpTypeConversionPattern(patterns, converter); 687 688 // Define the conversion target used for the test. 689 ConversionTarget target(getContext()); 690 target.addLegalOp<ModuleOp>(); 691 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 692 TerminatorOp>(); 693 target 694 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 695 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 696 // Don't allow F32 operands. 697 return llvm::none_of(op.getOperandTypes(), 698 [](Type type) { return type.isF32(); }); 699 }); 700 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 701 return converter.isSignatureLegal(op.getFunctionType()) && 702 converter.isLegal(&op.getBody()); 703 }); 704 target.addDynamicallyLegalOp<func::CallOp>( 705 [&](func::CallOp op) { return converter.isLegal(op); }); 706 707 // TestCreateUnregisteredOp creates `arith.constant` operation, 708 // which was not added to target intentionally to test 709 // correct error code from conversion driver. 710 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 711 712 // Expect the type_producer/type_consumer operations to only operate on f64. 713 target.addDynamicallyLegalOp<TestTypeProducerOp>( 714 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 715 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 716 return op.getOperand().getType().isF64(); 717 }); 718 719 // Check support for marking certain operations as recursively legal. 720 target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 721 return static_cast<bool>( 722 op->getAttrOfType<UnitAttr>("test.recursively_legal")); 723 }); 724 725 // Mark the bound recursion operation as dynamically legal. 726 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 727 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 728 729 // Handle a partial conversion. 730 if (mode == ConversionMode::Partial) { 731 DenseSet<Operation *> unlegalizedOps; 732 if (failed(applyPartialConversion( 733 getOperation(), target, std::move(patterns), &unlegalizedOps))) { 734 getOperation()->emitRemark() << "applyPartialConversion failed"; 735 } 736 // Emit remarks for each legalizable operation. 737 for (auto *op : unlegalizedOps) 738 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 739 return; 740 } 741 742 // Handle a full conversion. 743 if (mode == ConversionMode::Full) { 744 // Check support for marking unknown operations as dynamically legal. 745 target.markUnknownOpDynamicallyLegal([](Operation *op) { 746 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 747 }); 748 749 if (failed(applyFullConversion(getOperation(), target, 750 std::move(patterns)))) { 751 getOperation()->emitRemark() << "applyFullConversion failed"; 752 } 753 return; 754 } 755 756 // Otherwise, handle an analysis conversion. 757 assert(mode == ConversionMode::Analysis); 758 759 // Analyze the convertible operations. 760 DenseSet<Operation *> legalizedOps; 761 if (failed(applyAnalysisConversion(getOperation(), target, 762 std::move(patterns), legalizedOps))) 763 return signalPassFailure(); 764 765 // Emit remarks for each legalizable operation. 766 for (auto *op : legalizedOps) 767 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 768 } 769 770 /// The mode of conversion to use. 771 ConversionMode mode; 772 }; 773 } // namespace 774 775 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 776 legalizerConversionMode( 777 "test-legalize-mode", 778 llvm::cl::desc("The legalization mode to use with the test driver"), 779 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 780 llvm::cl::values( 781 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 782 "analysis", "Perform an analysis conversion"), 783 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 784 "Perform a full conversion"), 785 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 786 "partial", "Perform a partial conversion"))); 787 788 //===----------------------------------------------------------------------===// 789 // ConversionPatternRewriter::getRemappedValue testing. This method is used 790 // to get the remapped value of an original value that was replaced using 791 // ConversionPatternRewriter. 792 namespace { 793 struct TestRemapValueTypeConverter : public TypeConverter { 794 using TypeConverter::TypeConverter; 795 796 TestRemapValueTypeConverter() { 797 addConversion( 798 [](Float32Type type) { return Float64Type::get(type.getContext()); }); 799 addConversion([](Type type) { return type; }); 800 } 801 }; 802 803 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 804 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 805 /// operand twice. 806 /// 807 /// Example: 808 /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 809 /// is replaced with: 810 /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 811 struct OneVResOneVOperandOp1Converter 812 : public OpConversionPattern<OneVResOneVOperandOp1> { 813 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 814 815 LogicalResult 816 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 817 ConversionPatternRewriter &rewriter) const override { 818 auto origOps = op.getOperands(); 819 assert(std::distance(origOps.begin(), origOps.end()) == 1 && 820 "One operand expected"); 821 Value origOp = *origOps.begin(); 822 SmallVector<Value, 2> remappedOperands; 823 // Replicate the remapped original operand twice. Note that we don't used 824 // the remapped 'operand' since the goal is testing 'getRemappedValue'. 825 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 826 remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 827 828 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 829 remappedOperands); 830 return success(); 831 } 832 }; 833 834 /// A rewriter pattern that tests that blocks can be merged. 835 struct TestRemapValueInRegion 836 : public OpConversionPattern<TestRemappedValueRegionOp> { 837 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 838 839 LogicalResult 840 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 841 ConversionPatternRewriter &rewriter) const final { 842 Block &block = op.getBody().front(); 843 Operation *terminator = block.getTerminator(); 844 845 // Merge the block into the parent region. 846 Block *parentBlock = op->getBlock(); 847 Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 848 rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 849 rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 850 851 // Replace the results of this operation with the remapped terminator 852 // values. 853 SmallVector<Value> terminatorOperands; 854 if (failed(rewriter.getRemappedValues(terminator->getOperands(), 855 terminatorOperands))) 856 return failure(); 857 858 rewriter.eraseOp(terminator); 859 rewriter.replaceOp(op, terminatorOperands); 860 return success(); 861 } 862 }; 863 864 struct TestRemappedValue 865 : public mlir::PassWrapper<TestRemappedValue, OperationPass<FuncOp>> { 866 StringRef getArgument() const final { return "test-remapped-value"; } 867 StringRef getDescription() const final { 868 return "Test public remapped value mechanism in ConversionPatternRewriter"; 869 } 870 void runOnOperation() override { 871 TestRemapValueTypeConverter typeConverter; 872 873 mlir::RewritePatternSet patterns(&getContext()); 874 patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 875 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 876 &getContext()); 877 patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 878 879 mlir::ConversionTarget target(getContext()); 880 target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 881 882 // Expect the type_producer/type_consumer operations to only operate on f64. 883 target.addDynamicallyLegalOp<TestTypeProducerOp>( 884 [](TestTypeProducerOp op) { return op.getType().isF64(); }); 885 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 886 return op.getOperand().getType().isF64(); 887 }); 888 889 // We make OneVResOneVOperandOp1 legal only when it has more that one 890 // operand. This will trigger the conversion that will replace one-operand 891 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 892 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 893 [](Operation *op) { return op->getNumOperands() > 1; }); 894 895 if (failed(mlir::applyFullConversion(getOperation(), target, 896 std::move(patterns)))) { 897 signalPassFailure(); 898 } 899 } 900 }; 901 } // namespace 902 903 //===----------------------------------------------------------------------===// 904 // Test patterns without a specific root operation kind 905 //===----------------------------------------------------------------------===// 906 907 namespace { 908 /// This pattern matches and removes any operation in the test dialect. 909 struct RemoveTestDialectOps : public RewritePattern { 910 RemoveTestDialectOps(MLIRContext *context) 911 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 912 913 LogicalResult matchAndRewrite(Operation *op, 914 PatternRewriter &rewriter) const override { 915 if (!isa<TestDialect>(op->getDialect())) 916 return failure(); 917 rewriter.eraseOp(op); 918 return success(); 919 } 920 }; 921 922 struct TestUnknownRootOpDriver 923 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<FuncOp>> { 924 StringRef getArgument() const final { 925 return "test-legalize-unknown-root-patterns"; 926 } 927 StringRef getDescription() const final { 928 return "Test public remapped value mechanism in ConversionPatternRewriter"; 929 } 930 void runOnOperation() override { 931 mlir::RewritePatternSet patterns(&getContext()); 932 patterns.add<RemoveTestDialectOps>(&getContext()); 933 934 mlir::ConversionTarget target(getContext()); 935 target.addIllegalDialect<TestDialect>(); 936 if (failed(applyPartialConversion(getOperation(), target, 937 std::move(patterns)))) 938 signalPassFailure(); 939 } 940 }; 941 } // namespace 942 943 //===----------------------------------------------------------------------===// 944 // Test type conversions 945 //===----------------------------------------------------------------------===// 946 947 namespace { 948 struct TestTypeConversionProducer 949 : public OpConversionPattern<TestTypeProducerOp> { 950 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 951 LogicalResult 952 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 953 ConversionPatternRewriter &rewriter) const final { 954 Type resultType = op.getType(); 955 Type convertedType = getTypeConverter() 956 ? getTypeConverter()->convertType(resultType) 957 : resultType; 958 if (resultType.isa<FloatType>()) 959 resultType = rewriter.getF64Type(); 960 else if (resultType.isInteger(16)) 961 resultType = rewriter.getIntegerType(64); 962 else if (resultType.isa<test::TestRecursiveType>() && 963 convertedType != resultType) 964 resultType = convertedType; 965 else 966 return failure(); 967 968 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 969 return success(); 970 } 971 }; 972 973 /// Call signature conversion and then fail the rewrite to trigger the undo 974 /// mechanism. 975 struct TestSignatureConversionUndo 976 : public OpConversionPattern<TestSignatureConversionUndoOp> { 977 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 978 979 LogicalResult 980 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 981 ConversionPatternRewriter &rewriter) const final { 982 (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 983 return failure(); 984 } 985 }; 986 987 /// Call signature conversion without providing a type converter to handle 988 /// materializations. 989 struct TestTestSignatureConversionNoConverter 990 : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 991 TestTestSignatureConversionNoConverter(TypeConverter &converter, 992 MLIRContext *context) 993 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 994 converter(converter) {} 995 996 LogicalResult 997 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 998 ConversionPatternRewriter &rewriter) const final { 999 Region ®ion = op->getRegion(0); 1000 Block *entry = ®ion.front(); 1001 1002 // Convert the original entry arguments. 1003 TypeConverter::SignatureConversion result(entry->getNumArguments()); 1004 if (failed( 1005 converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 1006 return failure(); 1007 rewriter.updateRootInPlace( 1008 op, [&] { rewriter.applySignatureConversion(®ion, result); }); 1009 return success(); 1010 } 1011 1012 TypeConverter &converter; 1013 }; 1014 1015 /// Just forward the operands to the root op. This is essentially a no-op 1016 /// pattern that is used to trigger target materialization. 1017 struct TestTypeConsumerForward 1018 : public OpConversionPattern<TestTypeConsumerOp> { 1019 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 1020 1021 LogicalResult 1022 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 1023 ConversionPatternRewriter &rewriter) const final { 1024 rewriter.updateRootInPlace(op, 1025 [&] { op->setOperands(adaptor.getOperands()); }); 1026 return success(); 1027 } 1028 }; 1029 1030 struct TestTypeConversionAnotherProducer 1031 : public OpRewritePattern<TestAnotherTypeProducerOp> { 1032 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 1033 1034 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 1035 PatternRewriter &rewriter) const final { 1036 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 1037 return success(); 1038 } 1039 }; 1040 1041 struct TestTypeConversionDriver 1042 : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 1043 void getDependentDialects(DialectRegistry ®istry) const override { 1044 registry.insert<TestDialect>(); 1045 } 1046 StringRef getArgument() const final { 1047 return "test-legalize-type-conversion"; 1048 } 1049 StringRef getDescription() const final { 1050 return "Test various type conversion functionalities in DialectConversion"; 1051 } 1052 1053 void runOnOperation() override { 1054 // Initialize the type converter. 1055 TypeConverter converter; 1056 1057 /// Add the legal set of type conversions. 1058 converter.addConversion([](Type type) -> Type { 1059 // Treat F64 as legal. 1060 if (type.isF64()) 1061 return type; 1062 // Allow converting BF16/F16/F32 to F64. 1063 if (type.isBF16() || type.isF16() || type.isF32()) 1064 return FloatType::getF64(type.getContext()); 1065 // Otherwise, the type is illegal. 1066 return nullptr; 1067 }); 1068 converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 1069 // Drop all integer types. 1070 return success(); 1071 }); 1072 converter.addConversion( 1073 // Convert a recursive self-referring type into a non-self-referring 1074 // type named "outer_converted_type" that contains a SimpleAType. 1075 [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results, 1076 ArrayRef<Type> callStack) -> Optional<LogicalResult> { 1077 // If the type is already converted, return it to indicate that it is 1078 // legal. 1079 if (type.getName() == "outer_converted_type") { 1080 results.push_back(type); 1081 return success(); 1082 } 1083 1084 // If the type is on the call stack more than once (it is there at 1085 // least once because of the _current_ call, which is always the last 1086 // element on the stack), we've hit the recursive case. Just return 1087 // SimpleAType here to create a non-recursive type as a result. 1088 if (llvm::is_contained(callStack.drop_back(), type)) { 1089 results.push_back(test::SimpleAType::get(type.getContext())); 1090 return success(); 1091 } 1092 1093 // Convert the body recursively. 1094 auto result = test::TestRecursiveType::get(type.getContext(), 1095 "outer_converted_type"); 1096 if (failed(result.setBody(converter.convertType(type.getBody())))) 1097 return failure(); 1098 results.push_back(result); 1099 return success(); 1100 }); 1101 1102 /// Add the legal set of type materializations. 1103 converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 1104 ValueRange inputs, 1105 Location loc) -> Value { 1106 // Allow casting from F64 back to F32. 1107 if (!resultType.isF16() && inputs.size() == 1 && 1108 inputs[0].getType().isF64()) 1109 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1110 // Allow producing an i32 or i64 from nothing. 1111 if ((resultType.isInteger(32) || resultType.isInteger(64)) && 1112 inputs.empty()) 1113 return builder.create<TestTypeProducerOp>(loc, resultType); 1114 // Allow producing an i64 from an integer. 1115 if (resultType.isa<IntegerType>() && inputs.size() == 1 && 1116 inputs[0].getType().isa<IntegerType>()) 1117 return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 1118 // Otherwise, fail. 1119 return nullptr; 1120 }); 1121 1122 // Initialize the conversion target. 1123 mlir::ConversionTarget target(getContext()); 1124 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 1125 auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>(); 1126 return op.getType().isF64() || op.getType().isInteger(64) || 1127 (recursiveType && 1128 recursiveType.getName() == "outer_converted_type"); 1129 }); 1130 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 1131 return converter.isSignatureLegal(op.getFunctionType()) && 1132 converter.isLegal(&op.getBody()); 1133 }); 1134 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 1135 // Allow casts from F64 to F32. 1136 return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 1137 }); 1138 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 1139 [&](TestSignatureConversionNoConverterOp op) { 1140 return converter.isLegal(op.getRegion().front().getArgumentTypes()); 1141 }); 1142 1143 // Initialize the set of rewrite patterns. 1144 RewritePatternSet patterns(&getContext()); 1145 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 1146 TestSignatureConversionUndo, 1147 TestTestSignatureConversionNoConverter>(converter, 1148 &getContext()); 1149 patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1150 mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, 1151 converter); 1152 1153 if (failed(applyPartialConversion(getOperation(), target, 1154 std::move(patterns)))) 1155 signalPassFailure(); 1156 } 1157 }; 1158 } // namespace 1159 1160 //===----------------------------------------------------------------------===// 1161 // Test Target Materialization With No Uses 1162 //===----------------------------------------------------------------------===// 1163 1164 namespace { 1165 struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1166 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1167 1168 LogicalResult 1169 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1170 ConversionPatternRewriter &rewriter) const final { 1171 rewriter.replaceOp(op, adaptor.getOperands()); 1172 return success(); 1173 } 1174 }; 1175 1176 struct TestTargetMaterializationWithNoUses 1177 : public PassWrapper<TestTargetMaterializationWithNoUses, 1178 OperationPass<ModuleOp>> { 1179 StringRef getArgument() const final { 1180 return "test-target-materialization-with-no-uses"; 1181 } 1182 StringRef getDescription() const final { 1183 return "Test a special case of target materialization in DialectConversion"; 1184 } 1185 1186 void runOnOperation() override { 1187 TypeConverter converter; 1188 converter.addConversion([](Type t) { return t; }); 1189 converter.addConversion([](IntegerType intTy) -> Type { 1190 if (intTy.getWidth() == 16) 1191 return IntegerType::get(intTy.getContext(), 64); 1192 return intTy; 1193 }); 1194 converter.addTargetMaterialization( 1195 [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1196 return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1197 }); 1198 1199 ConversionTarget target(getContext()); 1200 target.addIllegalOp<TestTypeChangerOp>(); 1201 1202 RewritePatternSet patterns(&getContext()); 1203 patterns.add<ForwardOperandPattern>(converter, &getContext()); 1204 1205 if (failed(applyPartialConversion(getOperation(), target, 1206 std::move(patterns)))) 1207 signalPassFailure(); 1208 } 1209 }; 1210 } // namespace 1211 1212 //===----------------------------------------------------------------------===// 1213 // Test Block Merging 1214 //===----------------------------------------------------------------------===// 1215 1216 namespace { 1217 /// A rewriter pattern that tests that blocks can be merged. 1218 struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1219 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1220 1221 LogicalResult 1222 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1223 ConversionPatternRewriter &rewriter) const final { 1224 Block &firstBlock = op.getBody().front(); 1225 Operation *branchOp = firstBlock.getTerminator(); 1226 Block *secondBlock = &*(std::next(op.getBody().begin())); 1227 auto succOperands = branchOp->getOperands(); 1228 SmallVector<Value, 2> replacements(succOperands); 1229 rewriter.eraseOp(branchOp); 1230 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1231 rewriter.updateRootInPlace(op, [] {}); 1232 return success(); 1233 } 1234 }; 1235 1236 /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1237 struct TestUndoBlocksMerge : public ConversionPattern { 1238 TestUndoBlocksMerge(MLIRContext *ctx) 1239 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1240 LogicalResult 1241 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1242 ConversionPatternRewriter &rewriter) const final { 1243 Block &firstBlock = op->getRegion(0).front(); 1244 Operation *branchOp = firstBlock.getTerminator(); 1245 Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1246 rewriter.setInsertionPointToStart(secondBlock); 1247 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1248 auto succOperands = branchOp->getOperands(); 1249 SmallVector<Value, 2> replacements(succOperands); 1250 rewriter.eraseOp(branchOp); 1251 rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1252 rewriter.updateRootInPlace(op, [] {}); 1253 return success(); 1254 } 1255 }; 1256 1257 /// A rewrite mechanism to inline the body of the op into its parent, when both 1258 /// ops can have a single block. 1259 struct TestMergeSingleBlockOps 1260 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1261 using OpConversionPattern< 1262 SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1263 1264 LogicalResult 1265 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 1266 ConversionPatternRewriter &rewriter) const final { 1267 SingleBlockImplicitTerminatorOp parentOp = 1268 op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1269 if (!parentOp) 1270 return failure(); 1271 Block &innerBlock = op.getRegion().front(); 1272 TerminatorOp innerTerminator = 1273 cast<TerminatorOp>(innerBlock.getTerminator()); 1274 rewriter.mergeBlockBefore(&innerBlock, op); 1275 rewriter.eraseOp(innerTerminator); 1276 rewriter.eraseOp(op); 1277 rewriter.updateRootInPlace(op, [] {}); 1278 return success(); 1279 } 1280 }; 1281 1282 struct TestMergeBlocksPatternDriver 1283 : public PassWrapper<TestMergeBlocksPatternDriver, 1284 OperationPass<ModuleOp>> { 1285 StringRef getArgument() const final { return "test-merge-blocks"; } 1286 StringRef getDescription() const final { 1287 return "Test Merging operation in ConversionPatternRewriter"; 1288 } 1289 void runOnOperation() override { 1290 MLIRContext *context = &getContext(); 1291 mlir::RewritePatternSet patterns(context); 1292 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1293 context); 1294 ConversionTarget target(*context); 1295 target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1296 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1297 target.addIllegalOp<ILLegalOpF>(); 1298 1299 /// Expect the op to have a single block after legalization. 1300 target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1301 [&](TestMergeBlocksOp op) -> bool { 1302 return llvm::hasSingleElement(op.getBody()); 1303 }); 1304 1305 /// Only allow `test.br` within test.merge_blocks op. 1306 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 1307 return op->getParentOfType<TestMergeBlocksOp>(); 1308 }); 1309 1310 /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1311 /// inlined. 1312 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1313 [&](SingleBlockImplicitTerminatorOp op) -> bool { 1314 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1315 }); 1316 1317 DenseSet<Operation *> unlegalizedOps; 1318 (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1319 &unlegalizedOps); 1320 for (auto *op : unlegalizedOps) 1321 op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1322 } 1323 }; 1324 } // namespace 1325 1326 //===----------------------------------------------------------------------===// 1327 // Test Selective Replacement 1328 //===----------------------------------------------------------------------===// 1329 1330 namespace { 1331 /// A rewrite mechanism to inline the body of the op into its parent, when both 1332 /// ops can have a single block. 1333 struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1334 using OpRewritePattern<TestCastOp>::OpRewritePattern; 1335 1336 LogicalResult matchAndRewrite(TestCastOp op, 1337 PatternRewriter &rewriter) const final { 1338 if (op.getNumOperands() != 2) 1339 return failure(); 1340 OperandRange operands = op.getOperands(); 1341 1342 // Replace non-terminator uses with the first operand. 1343 rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1344 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1345 }); 1346 // Replace everything else with the second operand if the operation isn't 1347 // dead. 1348 rewriter.replaceOp(op, op.getOperand(1)); 1349 return success(); 1350 } 1351 }; 1352 1353 struct TestSelectiveReplacementPatternDriver 1354 : public PassWrapper<TestSelectiveReplacementPatternDriver, 1355 OperationPass<>> { 1356 StringRef getArgument() const final { 1357 return "test-pattern-selective-replacement"; 1358 } 1359 StringRef getDescription() const final { 1360 return "Test selective replacement in the PatternRewriter"; 1361 } 1362 void runOnOperation() override { 1363 MLIRContext *context = &getContext(); 1364 mlir::RewritePatternSet patterns(context); 1365 patterns.add<TestSelectiveOpReplacementPattern>(context); 1366 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1367 std::move(patterns)); 1368 } 1369 }; 1370 } // namespace 1371 1372 //===----------------------------------------------------------------------===// 1373 // PassRegistration 1374 //===----------------------------------------------------------------------===// 1375 1376 namespace mlir { 1377 namespace test { 1378 void registerPatternsTestPass() { 1379 PassRegistration<TestReturnTypeDriver>(); 1380 1381 PassRegistration<TestDerivedAttributeDriver>(); 1382 1383 PassRegistration<TestPatternDriver>(); 1384 1385 PassRegistration<TestLegalizePatternDriver>([] { 1386 return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1387 }); 1388 1389 PassRegistration<TestRemappedValue>(); 1390 1391 PassRegistration<TestUnknownRootOpDriver>(); 1392 1393 PassRegistration<TestTypeConversionDriver>(); 1394 PassRegistration<TestTargetMaterializationWithNoUses>(); 1395 1396 PassRegistration<TestMergeBlocksPatternDriver>(); 1397 PassRegistration<TestSelectiveReplacementPatternDriver>(); 1398 } 1399 } // namespace test 1400 } // namespace mlir 1401