1fec6c5acSUday Bondhugula //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2fec6c5acSUday Bondhugula // 3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information. 5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fec6c5acSUday Bondhugula // 7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 8fec6c5acSUday Bondhugula 9fec6c5acSUday Bondhugula #include "TestDialect.h" 1026f93d9fSAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h" 11473bdaf2SAlex Zinenko #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 12c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 132bf423b0SRob Suderman #include "mlir/IR/Matchers.h" 14fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h" 15fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h" 1626f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h" 17b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 189ba37b3bSJacques Pienaar 19fec6c5acSUday Bondhugula using namespace mlir; 207776b19eSStephen Neuendorffer using namespace test; 21fec6c5acSUday Bondhugula 22fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall 23fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 24fec6c5acSUday Bondhugula return choice.getValue() ? input1 : input2; 25fec6c5acSUday Bondhugula } 26fec6c5acSUday Bondhugula 2729429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 2829429d1aSJacques Pienaar rewriter.create<OpI>(loc, input); 29fec6c5acSUday Bondhugula } 30fec6c5acSUday Bondhugula 31fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter, 32fec6c5acSUday Bondhugula OpSymbolBindingNoResult op) { 33fec6c5acSUday Bondhugula // Turn the no result op to a one-result op. 34fec6c5acSUday Bondhugula rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(), 35fec6c5acSUday Bondhugula op.operand()); 36fec6c5acSUday Bondhugula } 37fec6c5acSUday Bondhugula 3834b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) { 3934b5482bSChia-hung Duan if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 4034b5482bSChia-hung Duan return false; 4134b5482bSChia-hung Duan value = op->getResult(0); 4234b5482bSChia-hung Duan return true; 4334b5482bSChia-hung Duan } 4434b5482bSChia-hung Duan 4534b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; } 4634b5482bSChia-hung Duan 47d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 48d7314b3cSChia-hung Duan Value input2) { 49d7314b3cSChia-hung Duan return SmallVector<Value, 2>({input2, input1}); 50d7314b3cSChia-hung Duan } 51d7314b3cSChia-hung Duan 5201641197SAlexEichenberger // Test that natives calls are only called once during rewrites. 5301641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls. 5401641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting 5501641197SAlexEichenberger // the returned value in the MLIR output. 5601641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265; 5701641197SAlexEichenberger static Attribute OpMTest(PatternRewriter &rewriter, Value val) { 5801641197SAlexEichenberger int64_t i = opMIncreasingValue++; 5901641197SAlexEichenberger return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 6001641197SAlexEichenberger } 6101641197SAlexEichenberger 62fec6c5acSUday Bondhugula namespace { 63fec6c5acSUday Bondhugula #include "TestPatterns.inc" 64fec6c5acSUday Bondhugula } // end anonymous namespace 65fec6c5acSUday Bondhugula 66fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 67c484c7ddSChia-hung Duan // Test Reduce Pattern Interface 68c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 69c484c7ddSChia-hung Duan 707776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 71c484c7ddSChia-hung Duan populateWithGenerated(patterns); 72c484c7ddSChia-hung Duan } 73c484c7ddSChia-hung Duan 74c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 75fec6c5acSUday Bondhugula // Canonicalizer Driver. 76fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 77fec6c5acSUday Bondhugula 78fec6c5acSUday Bondhugula namespace { 7926f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern { 8026f93d9fSAlex Zinenko public: 8126f93d9fSAlex Zinenko FoldingPattern(MLIRContext *context) 8226f93d9fSAlex Zinenko : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 8326f93d9fSAlex Zinenko /*benefit=*/1, context) {} 8426f93d9fSAlex Zinenko 8526f93d9fSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 8626f93d9fSAlex Zinenko PatternRewriter &rewriter) const override { 872b638ed5SKazuaki Ishizaki // Exercise OperationFolder API for a single-result operation that is folded 8826f93d9fSAlex Zinenko // upon construction. The operation being created through the folder has an 8926f93d9fSAlex Zinenko // in-place folder, and it should be still present in the output. 9026f93d9fSAlex Zinenko // Furthermore, the folder should not crash when attempting to recover the 91a23d0559SKazuaki Ishizaki // (unchanged) operation result. 9226f93d9fSAlex Zinenko OperationFolder folder(op->getContext()); 9326f93d9fSAlex Zinenko Value result = folder.create<TestOpInPlaceFold>( 9426f93d9fSAlex Zinenko rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), 9526f93d9fSAlex Zinenko rewriter.getI32IntegerAttr(0)); 9626f93d9fSAlex Zinenko assert(result); 9726f93d9fSAlex Zinenko rewriter.replaceOp(op, result); 9826f93d9fSAlex Zinenko return success(); 9926f93d9fSAlex Zinenko } 10026f93d9fSAlex Zinenko }; 10126f93d9fSAlex Zinenko 102*e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block. 103*e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an 104*e4635e63SRiver Riddle /// operation with a previously created constant that does not initially 105*e4635e63SRiver Riddle /// dominate the operation to replace. 106*e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern 107*e4635e63SRiver Riddle : public OpRewritePattern<TestCastOp> { 108*e4635e63SRiver Riddle public: 109*e4635e63SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern; 110*e4635e63SRiver Riddle 111*e4635e63SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op, 112*e4635e63SRiver Riddle PatternRewriter &rewriter) const override { 113*e4635e63SRiver Riddle if (!op->hasAttr("test_fold_before_previously_folded_op")) 114*e4635e63SRiver Riddle return failure(); 115*e4635e63SRiver Riddle rewriter.setInsertionPointToStart(op->getBlock()); 116*e4635e63SRiver Riddle 117*e4635e63SRiver Riddle auto constOp = 118*e4635e63SRiver Riddle rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true)); 119*e4635e63SRiver Riddle rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 120*e4635e63SRiver Riddle Value(constOp)); 121*e4635e63SRiver Riddle return success(); 122*e4635e63SRiver Riddle } 123*e4635e63SRiver Riddle }; 124*e4635e63SRiver Riddle 12580aca1eaSRiver Riddle struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> { 126b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-patterns"; } 127b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run test dialect patterns"; } 128fec6c5acSUday Bondhugula void runOnFunction() override { 129dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 1301d909c9aSChris Lattner populateWithGenerated(patterns); 131fec6c5acSUday Bondhugula 132fec6c5acSUday Bondhugula // Verify named pattern is generated with expected name. 133*e4635e63SRiver Riddle patterns.add<FoldingPattern, TestNamedPatternRule, 134*e4635e63SRiver Riddle FolderInsertBeforePreviouslyFoldedConstantPattern>( 135*e4635e63SRiver Riddle &getContext()); 136fec6c5acSUday Bondhugula 137e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 138fec6c5acSUday Bondhugula } 139fec6c5acSUday Bondhugula }; 140fec6c5acSUday Bondhugula } // end anonymous namespace 141fec6c5acSUday Bondhugula 142fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 143fec6c5acSUday Bondhugula // ReturnType Driver. 144fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 145fec6c5acSUday Bondhugula 146fec6c5acSUday Bondhugula namespace { 147fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred. 148fec6c5acSUday Bondhugula template <typename OpTy> 149fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) { 150fec6c5acSUday Bondhugula auto *context = op->getContext(); 151fec6c5acSUday Bondhugula auto fop = op->getParentOfType<FuncOp>(); 152fec6c5acSUday Bondhugula auto location = UnknownLoc::get(context); 153fec6c5acSUday Bondhugula OpBuilder b(op); 154fec6c5acSUday Bondhugula b.setInsertionPointAfter(op); 155fec6c5acSUday Bondhugula 156fec6c5acSUday Bondhugula // Use permutations of 2 args as operands. 157fec6c5acSUday Bondhugula assert(fop.getNumArguments() >= 2); 158fec6c5acSUday Bondhugula for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 159fec6c5acSUday Bondhugula for (int j = 0; j < e; ++j) { 160fec6c5acSUday Bondhugula std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 161fec6c5acSUday Bondhugula SmallVector<Type, 2> inferredReturnTypes; 1625eae715aSJacques Pienaar if (succeeded(OpTy::inferReturnTypes( 1635eae715aSJacques Pienaar context, llvm::None, values, op->getAttrDictionary(), 1645eae715aSJacques Pienaar op->getRegions(), inferredReturnTypes))) { 165fec6c5acSUday Bondhugula OperationState state(location, OpTy::getOperationName()); 1669db53a18SRiver Riddle // TODO: Expand to regions. 167bb1d976fSAlex Zinenko OpTy::build(b, state, values, op->getAttrs()); 168fec6c5acSUday Bondhugula (void)b.createOperation(state); 169fec6c5acSUday Bondhugula } 170fec6c5acSUday Bondhugula } 171fec6c5acSUday Bondhugula } 172fec6c5acSUday Bondhugula } 173fec6c5acSUday Bondhugula 174fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) { 175fec6c5acSUday Bondhugula OpBuilder b(op); 176fec6c5acSUday Bondhugula 177fec6c5acSUday Bondhugula // Use permutations of 2 args as operands. 178fec6c5acSUday Bondhugula auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 179fec6c5acSUday Bondhugula SmallVector<Value, 2> shapes; 180851d02f6SWenyi Zhao if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 1819b051703SMaheshRavishankar !llvm::hasSingleElement(shapes)) 182fec6c5acSUday Bondhugula return; 1839b051703SMaheshRavishankar for (auto it : llvm::enumerate(shapes)) { 184fec6c5acSUday Bondhugula op->emitRemark() << "value " << it.index() << ": " 185fec6c5acSUday Bondhugula << it.value().getDefiningOp(); 186fec6c5acSUday Bondhugula } 1879b051703SMaheshRavishankar } 188fec6c5acSUday Bondhugula 18980aca1eaSRiver Riddle struct TestReturnTypeDriver 19080aca1eaSRiver Riddle : public PassWrapper<TestReturnTypeDriver, FunctionPass> { 191e2310704SJulian Gross void getDependentDialects(DialectRegistry ®istry) const override { 192c0a6318dSMatthias Springer registry.insert<tensor::TensorDialect>(); 193e2310704SJulian Gross } 194b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-return-type"; } 195b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run return type functions"; } 196e2310704SJulian Gross 197fec6c5acSUday Bondhugula void runOnFunction() override { 198fec6c5acSUday Bondhugula if (getFunction().getName() == "testCreateFunctions") { 199fec6c5acSUday Bondhugula std::vector<Operation *> ops; 200fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops. 201fec6c5acSUday Bondhugula for (auto &op : getFunction().getBody().front()) 202fec6c5acSUday Bondhugula ops.push_back(&op); 203fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator. 204fec6c5acSUday Bondhugula for (auto *op : llvm::makeArrayRef(ops).drop_back()) { 205fec6c5acSUday Bondhugula // Test create method of each of the Op classes below. The resultant 206fec6c5acSUday Bondhugula // output would be in reverse order underneath `op` from which 207fec6c5acSUday Bondhugula // the attributes and regions are used. 208fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 209fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType< 210fec6c5acSUday Bondhugula OpWithShapedTypeInferTypeInterfaceOp>(op); 211fec6c5acSUday Bondhugula }; 212fec6c5acSUday Bondhugula return; 213fec6c5acSUday Bondhugula } 214fec6c5acSUday Bondhugula if (getFunction().getName() == "testReifyFunctions") { 215fec6c5acSUday Bondhugula std::vector<Operation *> ops; 216fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops. 217fec6c5acSUday Bondhugula for (auto &op : getFunction().getBody().front()) 218fec6c5acSUday Bondhugula if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 219fec6c5acSUday Bondhugula ops.push_back(&op); 220fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator. 221fec6c5acSUday Bondhugula for (auto *op : ops) 222fec6c5acSUday Bondhugula reifyReturnShape(op); 223fec6c5acSUday Bondhugula } 224fec6c5acSUday Bondhugula } 225fec6c5acSUday Bondhugula }; 226fec6c5acSUday Bondhugula } // end anonymous namespace 227fec6c5acSUday Bondhugula 2289ba37b3bSJacques Pienaar namespace { 2299ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver 2309ba37b3bSJacques Pienaar : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> { 231b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-derived-attr"; } 232b5e22e6dSMehdi Amini StringRef getDescription() const final { 233b5e22e6dSMehdi Amini return "Run test derived attributes"; 234b5e22e6dSMehdi Amini } 2359ba37b3bSJacques Pienaar void runOnFunction() override; 2369ba37b3bSJacques Pienaar }; 2379ba37b3bSJacques Pienaar } // end anonymous namespace 2389ba37b3bSJacques Pienaar 2399ba37b3bSJacques Pienaar void TestDerivedAttributeDriver::runOnFunction() { 2409ba37b3bSJacques Pienaar getFunction().walk([](DerivedAttributeOpInterface dOp) { 2419ba37b3bSJacques Pienaar auto dAttr = dOp.materializeDerivedAttributes(); 2429ba37b3bSJacques Pienaar if (!dAttr) 2439ba37b3bSJacques Pienaar return; 2449ba37b3bSJacques Pienaar for (auto d : dAttr) 2459ba37b3bSJacques Pienaar dOp.emitRemark() << d.first << " = " << d.second; 2469ba37b3bSJacques Pienaar }); 2479ba37b3bSJacques Pienaar } 2489ba37b3bSJacques Pienaar 249fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 250fec6c5acSUday Bondhugula // Legalization Driver. 251fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 252fec6c5acSUday Bondhugula 253fec6c5acSUday Bondhugula namespace { 254fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 255fec6c5acSUday Bondhugula // Region-Block Rewrite Testing 256fec6c5acSUday Bondhugula 257fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given 258fec6c5acSUday Bondhugula /// operation into the parent region. 259fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern { 260fec6c5acSUday Bondhugula TestRegionRewriteBlockMovement(MLIRContext *ctx) 261fec6c5acSUday Bondhugula : ConversionPattern("test.region", 1, ctx) {} 262fec6c5acSUday Bondhugula 263fec6c5acSUday Bondhugula LogicalResult 264fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 265fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 266fec6c5acSUday Bondhugula // Inline this region into the parent region. 267fec6c5acSUday Bondhugula auto &parentRegion = *op->getParentRegion(); 268b0750e2dSTres Popp auto &opRegion = op->getRegion(0); 269fec6c5acSUday Bondhugula if (op->getAttr("legalizer.should_clone")) 270b0750e2dSTres Popp rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 271fec6c5acSUday Bondhugula else 272b0750e2dSTres Popp rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 273b0750e2dSTres Popp 274b0750e2dSTres Popp if (op->getAttr("legalizer.erase_old_blocks")) { 275b0750e2dSTres Popp while (!opRegion.empty()) 276b0750e2dSTres Popp rewriter.eraseBlock(&opRegion.front()); 277b0750e2dSTres Popp } 278fec6c5acSUday Bondhugula 279fec6c5acSUday Bondhugula // Drop this operation. 280fec6c5acSUday Bondhugula rewriter.eraseOp(op); 281fec6c5acSUday Bondhugula return success(); 282fec6c5acSUday Bondhugula } 283fec6c5acSUday Bondhugula }; 284fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an 285fec6c5acSUday Bondhugula /// illegal operation. 286fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern { 287fec6c5acSUday Bondhugula TestRegionRewriteUndo(MLIRContext *ctx) 288fec6c5acSUday Bondhugula : RewritePattern("test.region_builder", 1, ctx) {} 289fec6c5acSUday Bondhugula 290fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op, 291fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final { 292fec6c5acSUday Bondhugula // Create the region operation with an entry block containing arguments. 293fec6c5acSUday Bondhugula OperationState newRegion(op->getLoc(), "test.region"); 294fec6c5acSUday Bondhugula newRegion.addRegion(); 295fec6c5acSUday Bondhugula auto *regionOp = rewriter.createOperation(newRegion); 296fec6c5acSUday Bondhugula auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 297fec6c5acSUday Bondhugula entryBlock->addArgument(rewriter.getIntegerType(64)); 298fec6c5acSUday Bondhugula 299fec6c5acSUday Bondhugula // Add an explicitly illegal operation to ensure the conversion fails. 300fec6c5acSUday Bondhugula rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 301fec6c5acSUday Bondhugula rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 302fec6c5acSUday Bondhugula 303fec6c5acSUday Bondhugula // Drop this operation. 304fec6c5acSUday Bondhugula rewriter.eraseOp(op); 305fec6c5acSUday Bondhugula return success(); 306fec6c5acSUday Bondhugula } 307fec6c5acSUday Bondhugula }; 308f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the 309f27f1e8cSAlex Zinenko /// matched operation. 310f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern { 311f27f1e8cSAlex Zinenko TestCreateBlock(MLIRContext *ctx) 312f27f1e8cSAlex Zinenko : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 313f27f1e8cSAlex Zinenko 314f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 315f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final { 316f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion(); 317f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32); 318f27f1e8cSAlex Zinenko rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 319f27f1e8cSAlex Zinenko rewriter.create<TerminatorOp>(op->getLoc()); 320f27f1e8cSAlex Zinenko rewriter.replaceOp(op, {}); 321f27f1e8cSAlex Zinenko return success(); 322f27f1e8cSAlex Zinenko } 323f27f1e8cSAlex Zinenko }; 324f27f1e8cSAlex Zinenko 325a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in 326f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism. 327f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern { 328f27f1e8cSAlex Zinenko TestCreateIllegalBlock(MLIRContext *ctx) 329f27f1e8cSAlex Zinenko : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 330f27f1e8cSAlex Zinenko 331f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 332f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final { 333f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion(); 334f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32); 335f27f1e8cSAlex Zinenko rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); 336f27f1e8cSAlex Zinenko // Create an illegal op to ensure the conversion fails. 337f27f1e8cSAlex Zinenko rewriter.create<ILLegalOpF>(op->getLoc(), i32Type); 338f27f1e8cSAlex Zinenko rewriter.create<TerminatorOp>(op->getLoc()); 339f27f1e8cSAlex Zinenko rewriter.replaceOp(op, {}); 340f27f1e8cSAlex Zinenko return success(); 341f27f1e8cSAlex Zinenko } 342f27f1e8cSAlex Zinenko }; 343fec6c5acSUday Bondhugula 3440816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a 3450816de16SRiver Riddle /// block argument. 3460816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern { 3470816de16SRiver Riddle TestUndoBlockArgReplace(MLIRContext *ctx) 3480816de16SRiver Riddle : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 3490816de16SRiver Riddle 3500816de16SRiver Riddle LogicalResult 3510816de16SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 3520816de16SRiver Riddle ConversionPatternRewriter &rewriter) const final { 3530816de16SRiver Riddle auto illegalOp = 3540816de16SRiver Riddle rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 355e2b71610SRahul Joshi rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 3560816de16SRiver Riddle illegalOp); 3570816de16SRiver Riddle rewriter.updateRootInPlace(op, [] {}); 3580816de16SRiver Riddle return success(); 3590816de16SRiver Riddle } 3600816de16SRiver Riddle }; 3610816de16SRiver Riddle 362df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block. 363df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern { 364df48026bSAlex Zinenko TestUndoBlockErase(MLIRContext *ctx) 365df48026bSAlex Zinenko : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 366df48026bSAlex Zinenko 367df48026bSAlex Zinenko LogicalResult 368df48026bSAlex Zinenko matchAndRewrite(Operation *op, ArrayRef<Value> operands, 369df48026bSAlex Zinenko ConversionPatternRewriter &rewriter) const final { 370df48026bSAlex Zinenko Block *secondBlock = &*std::next(op->getRegion(0).begin()); 371df48026bSAlex Zinenko rewriter.setInsertionPointToStart(secondBlock); 372df48026bSAlex Zinenko rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 373df48026bSAlex Zinenko rewriter.eraseBlock(secondBlock); 374df48026bSAlex Zinenko rewriter.updateRootInPlace(op, [] {}); 375df48026bSAlex Zinenko return success(); 376df48026bSAlex Zinenko } 377df48026bSAlex Zinenko }; 378df48026bSAlex Zinenko 379fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 380fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing 381fec6c5acSUday Bondhugula 382fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion. 383fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern { 384fec6c5acSUday Bondhugula TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 38576f3c2f3SRiver Riddle : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 386fec6c5acSUday Bondhugula LogicalResult 387fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 388fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override { 389fec6c5acSUday Bondhugula Region ®ion = op->getRegion(0); 390fec6c5acSUday Bondhugula Block *entry = ®ion.front(); 391fec6c5acSUday Bondhugula 392fec6c5acSUday Bondhugula // Convert the original entry arguments. 3938d67d187SRiver Riddle TypeConverter &converter = *getTypeConverter(); 394fec6c5acSUday Bondhugula TypeConverter::SignatureConversion result(entry->getNumArguments()); 3958d67d187SRiver Riddle if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 3968d67d187SRiver Riddle result)) || 3978d67d187SRiver Riddle failed(rewriter.convertRegionTypes(®ion, converter, &result))) 398fec6c5acSUday Bondhugula return failure(); 399fec6c5acSUday Bondhugula 400fec6c5acSUday Bondhugula // Convert the region signature and just drop the operation. 401fec6c5acSUday Bondhugula rewriter.eraseOp(op); 402fec6c5acSUday Bondhugula return success(); 403fec6c5acSUday Bondhugula } 404fec6c5acSUday Bondhugula }; 405fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation. 406fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern { 407fec6c5acSUday Bondhugula TestPassthroughInvalidOp(MLIRContext *ctx) 408fec6c5acSUday Bondhugula : ConversionPattern("test.invalid", 1, ctx) {} 409fec6c5acSUday Bondhugula LogicalResult 410fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 411fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 412fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 413fec6c5acSUday Bondhugula llvm::None); 414fec6c5acSUday Bondhugula return success(); 415fec6c5acSUday Bondhugula } 416fec6c5acSUday Bondhugula }; 417fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value. 418fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern { 419fec6c5acSUday Bondhugula TestSplitReturnType(MLIRContext *ctx) 420fec6c5acSUday Bondhugula : ConversionPattern("test.return", 1, ctx) {} 421fec6c5acSUday Bondhugula LogicalResult 422fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 423fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 424fec6c5acSUday Bondhugula // Check for a return of F32. 425fec6c5acSUday Bondhugula if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 426fec6c5acSUday Bondhugula return failure(); 427fec6c5acSUday Bondhugula 428fec6c5acSUday Bondhugula // Check if the first operation is a cast operation, if it is we use the 429fec6c5acSUday Bondhugula // results directly. 430fec6c5acSUday Bondhugula auto *defOp = operands[0].getDefiningOp(); 431fec6c5acSUday Bondhugula if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 432fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); 433fec6c5acSUday Bondhugula return success(); 434fec6c5acSUday Bondhugula } 435fec6c5acSUday Bondhugula 436fec6c5acSUday Bondhugula // Otherwise, fail to match. 437fec6c5acSUday Bondhugula return failure(); 438fec6c5acSUday Bondhugula } 439fec6c5acSUday Bondhugula }; 440fec6c5acSUday Bondhugula 441fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 442fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing 443fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 444fec6c5acSUday Bondhugula TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 445fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {} 446fec6c5acSUday Bondhugula LogicalResult 447fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 448fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 449fec6c5acSUday Bondhugula // If the type is I32, change the type to F32. 450fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 451fec6c5acSUday Bondhugula return failure(); 452fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 453fec6c5acSUday Bondhugula return success(); 454fec6c5acSUday Bondhugula } 455fec6c5acSUday Bondhugula }; 456fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 457fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 458fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {} 459fec6c5acSUday Bondhugula LogicalResult 460fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 461fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 462fec6c5acSUday Bondhugula // If the type is F32, change the type to F64. 463fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isF32()) 464fec6c5acSUday Bondhugula return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 465fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 466fec6c5acSUday Bondhugula return success(); 467fec6c5acSUday Bondhugula } 468fec6c5acSUday Bondhugula }; 469fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 470fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 471fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 10, ctx) {} 472fec6c5acSUday Bondhugula LogicalResult 473fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 474fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 475fec6c5acSUday Bondhugula // Always convert to B16, even though it is not a legal type. This tests 476fec6c5acSUday Bondhugula // that values are unmapped correctly. 477fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 478fec6c5acSUday Bondhugula return success(); 479fec6c5acSUday Bondhugula } 480fec6c5acSUday Bondhugula }; 481fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern { 482fec6c5acSUday Bondhugula TestUpdateConsumerType(MLIRContext *ctx) 483fec6c5acSUday Bondhugula : ConversionPattern("test.type_consumer", 1, ctx) {} 484fec6c5acSUday Bondhugula LogicalResult 485fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 486fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 487fec6c5acSUday Bondhugula // Verify that the incoming operand has been successfully remapped to F64. 488fec6c5acSUday Bondhugula if (!operands[0].getType().isF64()) 489fec6c5acSUday Bondhugula return failure(); 490fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 491fec6c5acSUday Bondhugula return success(); 492fec6c5acSUday Bondhugula } 493fec6c5acSUday Bondhugula }; 494fec6c5acSUday Bondhugula 495fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 496fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing 497fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the 498fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the 499fec6c5acSUday Bondhugula /// temporary op. 500fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern { 501fec6c5acSUday Bondhugula TestNonRootReplacement(MLIRContext *ctx) 502fec6c5acSUday Bondhugula : RewritePattern("test.replace_non_root", 1, ctx) {} 503fec6c5acSUday Bondhugula 504fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op, 505fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final { 506fec6c5acSUday Bondhugula auto resultType = *op->result_type_begin(); 507fec6c5acSUday Bondhugula auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 508fec6c5acSUday Bondhugula auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 509fec6c5acSUday Bondhugula 510fec6c5acSUday Bondhugula rewriter.replaceOp(illegalOp, {legalOp}); 511fec6c5acSUday Bondhugula rewriter.replaceOp(op, {illegalOp}); 512fec6c5acSUday Bondhugula return success(); 513fec6c5acSUday Bondhugula } 514fec6c5acSUday Bondhugula }; 515bd1ccfe6SRiver Riddle 516bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===// 517bd1ccfe6SRiver Riddle // Recursive Rewrite Testing 518bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a 519bd1ccfe6SRiver Riddle /// bounded recursion. 520bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite 521bd1ccfe6SRiver Riddle : public OpRewritePattern<TestRecursiveRewriteOp> { 5222257e4a7SRiver Riddle using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 5232257e4a7SRiver Riddle 5242257e4a7SRiver Riddle void initialize() { 525b99bd771SRiver Riddle // The conversion target handles bounding the recursion of this pattern. 526b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 527b99bd771SRiver Riddle } 528bd1ccfe6SRiver Riddle 529bd1ccfe6SRiver Riddle LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 530bd1ccfe6SRiver Riddle PatternRewriter &rewriter) const final { 531bd1ccfe6SRiver Riddle // Decrement the depth of the op in-place. 532bd1ccfe6SRiver Riddle rewriter.updateRootInPlace(op, [&] { 5331ffc1aaaSChristian Sigg op->setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); 534bd1ccfe6SRiver Riddle }); 535bd1ccfe6SRiver Riddle return success(); 536bd1ccfe6SRiver Riddle } 537bd1ccfe6SRiver Riddle }; 5385d5df06aSAlex Zinenko 5395d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite 5405d5df06aSAlex Zinenko : public OpRewritePattern<IllegalOpWithRegionAnchor> { 5415d5df06aSAlex Zinenko using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 5425d5df06aSAlex Zinenko 5435d5df06aSAlex Zinenko LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 5445d5df06aSAlex Zinenko PatternRewriter &rewriter) const final { 5455d5df06aSAlex Zinenko // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 5465d5df06aSAlex Zinenko rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 5475d5df06aSAlex Zinenko return success(); 5485d5df06aSAlex Zinenko }; 5495d5df06aSAlex Zinenko }; 550a360a978SMehdi Amini 551a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer. 552a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 553a360a978SMehdi Amini using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 554a360a978SMehdi Amini 555a360a978SMehdi Amini LogicalResult matchAndRewrite(BlackHoleOp op, 556a360a978SMehdi Amini PatternRewriter &rewriter) const final { 557a360a978SMehdi Amini Operation *producer = op.getOperand().getDefiningOp(); 558a360a978SMehdi Amini // Always erase the user before the producer, the framework should handle 559a360a978SMehdi Amini // this correctly. 560a360a978SMehdi Amini rewriter.eraseOp(op); 561a360a978SMehdi Amini rewriter.eraseOp(producer); 562a360a978SMehdi Amini return success(); 563a360a978SMehdi Amini }; 564a360a978SMehdi Amini }; 565fec6c5acSUday Bondhugula } // namespace 566fec6c5acSUday Bondhugula 567fec6c5acSUday Bondhugula namespace { 568fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter { 569fec6c5acSUday Bondhugula using TypeConverter::TypeConverter; 5705c5dafc5SAlex Zinenko TestTypeConverter() { 5715c5dafc5SAlex Zinenko addConversion(convertType); 5724589dd92SRiver Riddle addArgumentMaterialization(materializeCast); 5734589dd92SRiver Riddle addSourceMaterialization(materializeCast); 5747cc79984SVladislav Vinogradov 5757cc79984SVladislav Vinogradov /// Materialize the cast for one-to-one conversion from i64 to f64. 5767cc79984SVladislav Vinogradov const auto materializeOneToOneCast = 5777cc79984SVladislav Vinogradov [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, 5787cc79984SVladislav Vinogradov Location loc) -> Optional<Value> { 5797cc79984SVladislav Vinogradov if (resultType.getWidth() == 42 && inputs.size() == 1) 5807cc79984SVladislav Vinogradov return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 5817cc79984SVladislav Vinogradov return llvm::None; 5827cc79984SVladislav Vinogradov }; 5837cc79984SVladislav Vinogradov addArgumentMaterialization(materializeOneToOneCast); 5845c5dafc5SAlex Zinenko } 585fec6c5acSUday Bondhugula 586fec6c5acSUday Bondhugula static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 587fec6c5acSUday Bondhugula // Drop I16 types. 588fec6c5acSUday Bondhugula if (t.isSignlessInteger(16)) 589fec6c5acSUday Bondhugula return success(); 590fec6c5acSUday Bondhugula 591fec6c5acSUday Bondhugula // Convert I64 to F64. 592fec6c5acSUday Bondhugula if (t.isSignlessInteger(64)) { 593fec6c5acSUday Bondhugula results.push_back(FloatType::getF64(t.getContext())); 594fec6c5acSUday Bondhugula return success(); 595fec6c5acSUday Bondhugula } 596fec6c5acSUday Bondhugula 5975c5dafc5SAlex Zinenko // Convert I42 to I43. 5985c5dafc5SAlex Zinenko if (t.isInteger(42)) { 5991b97cdf8SRiver Riddle results.push_back(IntegerType::get(t.getContext(), 43)); 6005c5dafc5SAlex Zinenko return success(); 6015c5dafc5SAlex Zinenko } 6025c5dafc5SAlex Zinenko 603fec6c5acSUday Bondhugula // Split F32 into F16,F16. 604fec6c5acSUday Bondhugula if (t.isF32()) { 605fec6c5acSUday Bondhugula results.assign(2, FloatType::getF16(t.getContext())); 606fec6c5acSUday Bondhugula return success(); 607fec6c5acSUday Bondhugula } 608fec6c5acSUday Bondhugula 609fec6c5acSUday Bondhugula // Otherwise, convert the type directly. 610fec6c5acSUday Bondhugula results.push_back(t); 611fec6c5acSUday Bondhugula return success(); 612fec6c5acSUday Bondhugula } 613fec6c5acSUday Bondhugula 6145c5dafc5SAlex Zinenko /// Hook for materializing a conversion. This is necessary because we generate 6155c5dafc5SAlex Zinenko /// 1->N type mappings. 6164589dd92SRiver Riddle static Optional<Value> materializeCast(OpBuilder &builder, Type resultType, 6174589dd92SRiver Riddle ValueRange inputs, Location loc) { 6185c5dafc5SAlex Zinenko if (inputs.size() == 1) 6195c5dafc5SAlex Zinenko return inputs[0]; 6204589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 6215c5dafc5SAlex Zinenko } 622fec6c5acSUday Bondhugula }; 623fec6c5acSUday Bondhugula 624fec6c5acSUday Bondhugula struct TestLegalizePatternDriver 62580aca1eaSRiver Riddle : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> { 626b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-legalize-patterns"; } 627b5e22e6dSMehdi Amini StringRef getDescription() const final { 628b5e22e6dSMehdi Amini return "Run test dialect legalization patterns"; 629b5e22e6dSMehdi Amini } 630fec6c5acSUday Bondhugula /// The mode of conversion to use with the driver. 631fec6c5acSUday Bondhugula enum class ConversionMode { Analysis, Full, Partial }; 632fec6c5acSUday Bondhugula 633fec6c5acSUday Bondhugula TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 634fec6c5acSUday Bondhugula 635722f909fSRiver Riddle void runOnOperation() override { 636fec6c5acSUday Bondhugula TestTypeConverter converter; 637dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 6381d909c9aSChris Lattner populateWithGenerated(patterns); 639dc4e913bSChris Lattner patterns 640dc4e913bSChris Lattner .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 641dc4e913bSChris Lattner TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, 642dc4e913bSChris Lattner TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, 643df48026bSAlex Zinenko TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 644fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 6455d5df06aSAlex Zinenko TestNonRootReplacement, TestBoundedRecursiveRewrite, 646a360a978SMehdi Amini TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( 647a360a978SMehdi Amini &getContext()); 648dc4e913bSChris Lattner patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); 6493a506b31SChris Lattner mlir::populateFuncOpTypeConversionPattern(patterns, converter); 6503a506b31SChris Lattner mlir::populateCallOpTypeConversionPattern(patterns, converter); 651fec6c5acSUday Bondhugula 652fec6c5acSUday Bondhugula // Define the conversion target used for the test. 653fec6c5acSUday Bondhugula ConversionTarget target(getContext()); 654973ddb7dSMehdi Amini target.addLegalOp<ModuleOp>(); 655f27f1e8cSAlex Zinenko target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp, 656f27f1e8cSAlex Zinenko TerminatorOp>(); 657fec6c5acSUday Bondhugula target 658fec6c5acSUday Bondhugula .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 659fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 660fec6c5acSUday Bondhugula // Don't allow F32 operands. 661fec6c5acSUday Bondhugula return llvm::none_of(op.getOperandTypes(), 662fec6c5acSUday Bondhugula [](Type type) { return type.isF32(); }); 663fec6c5acSUday Bondhugula }); 6648d67d187SRiver Riddle target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 6658d67d187SRiver Riddle return converter.isSignatureLegal(op.getType()) && 6668d67d187SRiver Riddle converter.isLegal(&op.getBody()); 6678d67d187SRiver Riddle }); 668fec6c5acSUday Bondhugula 669fec6c5acSUday Bondhugula // Expect the type_producer/type_consumer operations to only operate on f64. 670fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeProducerOp>( 671fec6c5acSUday Bondhugula [](TestTypeProducerOp op) { return op.getType().isF64(); }); 672fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 673fec6c5acSUday Bondhugula return op.getOperand().getType().isF64(); 674fec6c5acSUday Bondhugula }); 675fec6c5acSUday Bondhugula 676fec6c5acSUday Bondhugula // Check support for marking certain operations as recursively legal. 677fec6c5acSUday Bondhugula target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) { 678fec6c5acSUday Bondhugula return static_cast<bool>( 679fec6c5acSUday Bondhugula op->getAttrOfType<UnitAttr>("test.recursively_legal")); 680fec6c5acSUday Bondhugula }); 681fec6c5acSUday Bondhugula 682bd1ccfe6SRiver Riddle // Mark the bound recursion operation as dynamically legal. 683bd1ccfe6SRiver Riddle target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 684bd1ccfe6SRiver Riddle [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); 685bd1ccfe6SRiver Riddle 686fec6c5acSUday Bondhugula // Handle a partial conversion. 687fec6c5acSUday Bondhugula if (mode == ConversionMode::Partial) { 6888de482eaSLucy Fox DenseSet<Operation *> unlegalizedOps; 6893fffffa8SRiver Riddle (void)applyPartialConversion(getOperation(), target, std::move(patterns), 6908de482eaSLucy Fox &unlegalizedOps); 6918de482eaSLucy Fox // Emit remarks for each legalizable operation. 6928de482eaSLucy Fox for (auto *op : unlegalizedOps) 6938de482eaSLucy Fox op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 694fec6c5acSUday Bondhugula return; 695fec6c5acSUday Bondhugula } 696fec6c5acSUday Bondhugula 697fec6c5acSUday Bondhugula // Handle a full conversion. 698fec6c5acSUday Bondhugula if (mode == ConversionMode::Full) { 699fec6c5acSUday Bondhugula // Check support for marking unknown operations as dynamically legal. 700fec6c5acSUday Bondhugula target.markUnknownOpDynamicallyLegal([](Operation *op) { 701fec6c5acSUday Bondhugula return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 702fec6c5acSUday Bondhugula }); 703fec6c5acSUday Bondhugula 7043fffffa8SRiver Riddle (void)applyFullConversion(getOperation(), target, std::move(patterns)); 705fec6c5acSUday Bondhugula return; 706fec6c5acSUday Bondhugula } 707fec6c5acSUday Bondhugula 708fec6c5acSUday Bondhugula // Otherwise, handle an analysis conversion. 709fec6c5acSUday Bondhugula assert(mode == ConversionMode::Analysis); 710fec6c5acSUday Bondhugula 711fec6c5acSUday Bondhugula // Analyze the convertible operations. 712fec6c5acSUday Bondhugula DenseSet<Operation *> legalizedOps; 7133fffffa8SRiver Riddle if (failed(applyAnalysisConversion(getOperation(), target, 7143fffffa8SRiver Riddle std::move(patterns), legalizedOps))) 715fec6c5acSUday Bondhugula return signalPassFailure(); 716fec6c5acSUday Bondhugula 717fec6c5acSUday Bondhugula // Emit remarks for each legalizable operation. 718fec6c5acSUday Bondhugula for (auto *op : legalizedOps) 719fec6c5acSUday Bondhugula op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 720fec6c5acSUday Bondhugula } 721fec6c5acSUday Bondhugula 722fec6c5acSUday Bondhugula /// The mode of conversion to use. 723fec6c5acSUday Bondhugula ConversionMode mode; 724fec6c5acSUday Bondhugula }; 725fec6c5acSUday Bondhugula } // end anonymous namespace 726fec6c5acSUday Bondhugula 727fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 728fec6c5acSUday Bondhugula legalizerConversionMode( 729fec6c5acSUday Bondhugula "test-legalize-mode", 730fec6c5acSUday Bondhugula llvm::cl::desc("The legalization mode to use with the test driver"), 731fec6c5acSUday Bondhugula llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 732fec6c5acSUday Bondhugula llvm::cl::values( 733fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 734fec6c5acSUday Bondhugula "analysis", "Perform an analysis conversion"), 735fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 736fec6c5acSUday Bondhugula "Perform a full conversion"), 737fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 738fec6c5acSUday Bondhugula "partial", "Perform a partial conversion"))); 739fec6c5acSUday Bondhugula 740fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 741fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used 7425aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using 743fec6c5acSUday Bondhugula // ConversionPatternRewriter. 744fec6c5acSUday Bondhugula namespace { 745fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 746fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 747fec6c5acSUday Bondhugula /// operand twice. 748fec6c5acSUday Bondhugula /// 749fec6c5acSUday Bondhugula /// Example: 750fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 751fec6c5acSUday Bondhugula /// is replaced with: 752fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 753fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter 754fec6c5acSUday Bondhugula : public OpConversionPattern<OneVResOneVOperandOp1> { 755fec6c5acSUday Bondhugula using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 756fec6c5acSUday Bondhugula 757fec6c5acSUday Bondhugula LogicalResult 758fec6c5acSUday Bondhugula matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands, 759fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override { 760fec6c5acSUday Bondhugula auto origOps = op.getOperands(); 761fec6c5acSUday Bondhugula assert(std::distance(origOps.begin(), origOps.end()) == 1 && 762fec6c5acSUday Bondhugula "One operand expected"); 763fec6c5acSUday Bondhugula Value origOp = *origOps.begin(); 764fec6c5acSUday Bondhugula SmallVector<Value, 2> remappedOperands; 765fec6c5acSUday Bondhugula // Replicate the remapped original operand twice. Note that we don't used 766fec6c5acSUday Bondhugula // the remapped 'operand' since the goal is testing 'getRemappedValue'. 767fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 768fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 769fec6c5acSUday Bondhugula 770fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 771fec6c5acSUday Bondhugula remappedOperands); 772fec6c5acSUday Bondhugula return success(); 773fec6c5acSUday Bondhugula } 774fec6c5acSUday Bondhugula }; 775fec6c5acSUday Bondhugula 77680aca1eaSRiver Riddle struct TestRemappedValue 77780aca1eaSRiver Riddle : public mlir::PassWrapper<TestRemappedValue, FunctionPass> { 778b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-remapped-value"; } 779b5e22e6dSMehdi Amini StringRef getDescription() const final { 780b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter"; 781b5e22e6dSMehdi Amini } 782fec6c5acSUday Bondhugula void runOnFunction() override { 783dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 784dc4e913bSChris Lattner patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 785fec6c5acSUday Bondhugula 786fec6c5acSUday Bondhugula mlir::ConversionTarget target(getContext()); 787973ddb7dSMehdi Amini target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>(); 788fec6c5acSUday Bondhugula // We make OneVResOneVOperandOp1 legal only when it has more that one 789fec6c5acSUday Bondhugula // operand. This will trigger the conversion that will replace one-operand 790fec6c5acSUday Bondhugula // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 791fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 792fec6c5acSUday Bondhugula [](Operation *op) -> bool { 793fec6c5acSUday Bondhugula return std::distance(op->operand_begin(), op->operand_end()) > 1; 794fec6c5acSUday Bondhugula }); 795fec6c5acSUday Bondhugula 7963fffffa8SRiver Riddle if (failed(mlir::applyFullConversion(getFunction(), target, 7973fffffa8SRiver Riddle std::move(patterns)))) { 798fec6c5acSUday Bondhugula signalPassFailure(); 799fec6c5acSUday Bondhugula } 800fec6c5acSUday Bondhugula } 801fec6c5acSUday Bondhugula }; 802fec6c5acSUday Bondhugula } // end anonymous namespace 803fec6c5acSUday Bondhugula 80480d7ac3bSRiver Riddle //===----------------------------------------------------------------------===// 80580d7ac3bSRiver Riddle // Test patterns without a specific root operation kind 80680d7ac3bSRiver Riddle //===----------------------------------------------------------------------===// 80780d7ac3bSRiver Riddle 80880d7ac3bSRiver Riddle namespace { 80980d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect. 81080d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern { 81176f3c2f3SRiver Riddle RemoveTestDialectOps(MLIRContext *context) 81276f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 81380d7ac3bSRiver Riddle 81480d7ac3bSRiver Riddle LogicalResult matchAndRewrite(Operation *op, 81580d7ac3bSRiver Riddle PatternRewriter &rewriter) const override { 81680d7ac3bSRiver Riddle if (!isa<TestDialect>(op->getDialect())) 81780d7ac3bSRiver Riddle return failure(); 81880d7ac3bSRiver Riddle rewriter.eraseOp(op); 81980d7ac3bSRiver Riddle return success(); 82080d7ac3bSRiver Riddle } 82180d7ac3bSRiver Riddle }; 82280d7ac3bSRiver Riddle 82380d7ac3bSRiver Riddle struct TestUnknownRootOpDriver 82480d7ac3bSRiver Riddle : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> { 825b5e22e6dSMehdi Amini StringRef getArgument() const final { 826b5e22e6dSMehdi Amini return "test-legalize-unknown-root-patterns"; 827b5e22e6dSMehdi Amini } 828b5e22e6dSMehdi Amini StringRef getDescription() const final { 829b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter"; 830b5e22e6dSMehdi Amini } 83180d7ac3bSRiver Riddle void runOnFunction() override { 832dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 83376f3c2f3SRiver Riddle patterns.add<RemoveTestDialectOps>(&getContext()); 83480d7ac3bSRiver Riddle 83580d7ac3bSRiver Riddle mlir::ConversionTarget target(getContext()); 83680d7ac3bSRiver Riddle target.addIllegalDialect<TestDialect>(); 8373fffffa8SRiver Riddle if (failed( 8383fffffa8SRiver Riddle applyPartialConversion(getFunction(), target, std::move(patterns)))) 83980d7ac3bSRiver Riddle signalPassFailure(); 84080d7ac3bSRiver Riddle } 84180d7ac3bSRiver Riddle }; 84280d7ac3bSRiver Riddle } // end anonymous namespace 84380d7ac3bSRiver Riddle 8444589dd92SRiver Riddle //===----------------------------------------------------------------------===// 8454589dd92SRiver Riddle // Test type conversions 8464589dd92SRiver Riddle //===----------------------------------------------------------------------===// 8474589dd92SRiver Riddle 8484589dd92SRiver Riddle namespace { 8494589dd92SRiver Riddle struct TestTypeConversionProducer 8504589dd92SRiver Riddle : public OpConversionPattern<TestTypeProducerOp> { 8514589dd92SRiver Riddle using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 8524589dd92SRiver Riddle LogicalResult 8534589dd92SRiver Riddle matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands, 8544589dd92SRiver Riddle ConversionPatternRewriter &rewriter) const final { 8554589dd92SRiver Riddle Type resultType = op.getType(); 8564589dd92SRiver Riddle if (resultType.isa<FloatType>()) 8574589dd92SRiver Riddle resultType = rewriter.getF64Type(); 8584589dd92SRiver Riddle else if (resultType.isInteger(16)) 8594589dd92SRiver Riddle resultType = rewriter.getIntegerType(64); 8604589dd92SRiver Riddle else 8614589dd92SRiver Riddle return failure(); 8624589dd92SRiver Riddle 8634589dd92SRiver Riddle rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 8644589dd92SRiver Riddle return success(); 8654589dd92SRiver Riddle } 8664589dd92SRiver Riddle }; 8674589dd92SRiver Riddle 8680409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo 8690409eb28SAlex Zinenko /// mechanism. 8700409eb28SAlex Zinenko struct TestSignatureConversionUndo 8710409eb28SAlex Zinenko : public OpConversionPattern<TestSignatureConversionUndoOp> { 8720409eb28SAlex Zinenko using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 8730409eb28SAlex Zinenko 8740409eb28SAlex Zinenko LogicalResult 8750409eb28SAlex Zinenko matchAndRewrite(TestSignatureConversionUndoOp op, ArrayRef<Value> operands, 8760409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final { 8770409eb28SAlex Zinenko (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 8780409eb28SAlex Zinenko return failure(); 8790409eb28SAlex Zinenko } 8800409eb28SAlex Zinenko }; 8810409eb28SAlex Zinenko 8820409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op 8830409eb28SAlex Zinenko /// pattern that is used to trigger target materialization. 8840409eb28SAlex Zinenko struct TestTypeConsumerForward 8850409eb28SAlex Zinenko : public OpConversionPattern<TestTypeConsumerOp> { 8860409eb28SAlex Zinenko using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 8870409eb28SAlex Zinenko 8880409eb28SAlex Zinenko LogicalResult 8890409eb28SAlex Zinenko matchAndRewrite(TestTypeConsumerOp op, ArrayRef<Value> operands, 8900409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final { 8910409eb28SAlex Zinenko rewriter.updateRootInPlace(op, [&] { op->setOperands(operands); }); 8920409eb28SAlex Zinenko return success(); 8930409eb28SAlex Zinenko } 8940409eb28SAlex Zinenko }; 8950409eb28SAlex Zinenko 8965b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer 8975b91060dSAlex Zinenko : public OpRewritePattern<TestAnotherTypeProducerOp> { 8985b91060dSAlex Zinenko using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 8995b91060dSAlex Zinenko 9005b91060dSAlex Zinenko LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 9015b91060dSAlex Zinenko PatternRewriter &rewriter) const final { 9025b91060dSAlex Zinenko rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 9035b91060dSAlex Zinenko return success(); 9045b91060dSAlex Zinenko } 9055b91060dSAlex Zinenko }; 9065b91060dSAlex Zinenko 9074589dd92SRiver Riddle struct TestTypeConversionDriver 9084589dd92SRiver Riddle : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> { 909f9dc2b70SMehdi Amini void getDependentDialects(DialectRegistry ®istry) const override { 910f9dc2b70SMehdi Amini registry.insert<TestDialect>(); 911f9dc2b70SMehdi Amini } 912b5e22e6dSMehdi Amini StringRef getArgument() const final { 913b5e22e6dSMehdi Amini return "test-legalize-type-conversion"; 914b5e22e6dSMehdi Amini } 915b5e22e6dSMehdi Amini StringRef getDescription() const final { 916b5e22e6dSMehdi Amini return "Test various type conversion functionalities in DialectConversion"; 917b5e22e6dSMehdi Amini } 918f9dc2b70SMehdi Amini 9194589dd92SRiver Riddle void runOnOperation() override { 9204589dd92SRiver Riddle // Initialize the type converter. 9214589dd92SRiver Riddle TypeConverter converter; 9224589dd92SRiver Riddle 9234589dd92SRiver Riddle /// Add the legal set of type conversions. 9244589dd92SRiver Riddle converter.addConversion([](Type type) -> Type { 9254589dd92SRiver Riddle // Treat F64 as legal. 9264589dd92SRiver Riddle if (type.isF64()) 9274589dd92SRiver Riddle return type; 9284589dd92SRiver Riddle // Allow converting BF16/F16/F32 to F64. 9294589dd92SRiver Riddle if (type.isBF16() || type.isF16() || type.isF32()) 9304589dd92SRiver Riddle return FloatType::getF64(type.getContext()); 9314589dd92SRiver Riddle // Otherwise, the type is illegal. 9324589dd92SRiver Riddle return nullptr; 9334589dd92SRiver Riddle }); 9344589dd92SRiver Riddle converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 9354589dd92SRiver Riddle // Drop all integer types. 9364589dd92SRiver Riddle return success(); 9374589dd92SRiver Riddle }); 9384589dd92SRiver Riddle 9394589dd92SRiver Riddle /// Add the legal set of type materializations. 9404589dd92SRiver Riddle converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 9414589dd92SRiver Riddle ValueRange inputs, 9424589dd92SRiver Riddle Location loc) -> Value { 9434589dd92SRiver Riddle // Allow casting from F64 back to F32. 9444589dd92SRiver Riddle if (!resultType.isF16() && inputs.size() == 1 && 9454589dd92SRiver Riddle inputs[0].getType().isF64()) 9464589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 9474589dd92SRiver Riddle // Allow producing an i32 or i64 from nothing. 9484589dd92SRiver Riddle if ((resultType.isInteger(32) || resultType.isInteger(64)) && 9494589dd92SRiver Riddle inputs.empty()) 9504589dd92SRiver Riddle return builder.create<TestTypeProducerOp>(loc, resultType); 9514589dd92SRiver Riddle // Allow producing an i64 from an integer. 9524589dd92SRiver Riddle if (resultType.isa<IntegerType>() && inputs.size() == 1 && 9534589dd92SRiver Riddle inputs[0].getType().isa<IntegerType>()) 9544589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 9554589dd92SRiver Riddle // Otherwise, fail. 9564589dd92SRiver Riddle return nullptr; 9574589dd92SRiver Riddle }); 9584589dd92SRiver Riddle 9594589dd92SRiver Riddle // Initialize the conversion target. 9604589dd92SRiver Riddle mlir::ConversionTarget target(getContext()); 9614589dd92SRiver Riddle target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 9624589dd92SRiver Riddle return op.getType().isF64() || op.getType().isInteger(64); 9634589dd92SRiver Riddle }); 9644589dd92SRiver Riddle target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 9654589dd92SRiver Riddle return converter.isSignatureLegal(op.getType()) && 9664589dd92SRiver Riddle converter.isLegal(&op.getBody()); 9674589dd92SRiver Riddle }); 9684589dd92SRiver Riddle target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 9694589dd92SRiver Riddle // Allow casts from F64 to F32. 9704589dd92SRiver Riddle return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 9714589dd92SRiver Riddle }); 9724589dd92SRiver Riddle 9734589dd92SRiver Riddle // Initialize the set of rewrite patterns. 974dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 975dc4e913bSChris Lattner patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, 9760409eb28SAlex Zinenko TestSignatureConversionUndo>(converter, &getContext()); 977dc4e913bSChris Lattner patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 9783a506b31SChris Lattner mlir::populateFuncOpTypeConversionPattern(patterns, converter); 9794589dd92SRiver Riddle 9803fffffa8SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 9813fffffa8SRiver Riddle std::move(patterns)))) 9824589dd92SRiver Riddle signalPassFailure(); 9834589dd92SRiver Riddle } 9844589dd92SRiver Riddle }; 9854589dd92SRiver Riddle } // end anonymous namespace 9864589dd92SRiver Riddle 987c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 988c8fb6ee3SRiver Riddle // Test Block Merging 989c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 990c8fb6ee3SRiver Riddle 991e888886cSMaheshRavishankar namespace { 992e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged. 993e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 994e888886cSMaheshRavishankar using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 995e888886cSMaheshRavishankar 996e888886cSMaheshRavishankar LogicalResult 997e888886cSMaheshRavishankar matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands, 998e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 999e888886cSMaheshRavishankar Block &firstBlock = op.body().front(); 1000e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator(); 1001e888886cSMaheshRavishankar Block *secondBlock = &*(std::next(op.body().begin())); 1002e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands(); 1003e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands); 1004e888886cSMaheshRavishankar rewriter.eraseOp(branchOp); 1005e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1006e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {}); 1007e888886cSMaheshRavishankar return success(); 1008e888886cSMaheshRavishankar } 1009e888886cSMaheshRavishankar }; 1010e888886cSMaheshRavishankar 1011e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged. 1012e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern { 1013e888886cSMaheshRavishankar TestUndoBlocksMerge(MLIRContext *ctx) 1014e888886cSMaheshRavishankar : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 1015e888886cSMaheshRavishankar LogicalResult 1016e888886cSMaheshRavishankar matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1017e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 1018e888886cSMaheshRavishankar Block &firstBlock = op->getRegion(0).front(); 1019e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator(); 1020e888886cSMaheshRavishankar Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 1021e888886cSMaheshRavishankar rewriter.setInsertionPointToStart(secondBlock); 1022e888886cSMaheshRavishankar rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 1023e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands(); 1024e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands); 1025e888886cSMaheshRavishankar rewriter.eraseOp(branchOp); 1026e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 1027e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {}); 1028e888886cSMaheshRavishankar return success(); 1029e888886cSMaheshRavishankar } 1030e888886cSMaheshRavishankar }; 1031e888886cSMaheshRavishankar 1032e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both 1033e888886cSMaheshRavishankar /// ops can have a single block. 1034e888886cSMaheshRavishankar struct TestMergeSingleBlockOps 1035e888886cSMaheshRavishankar : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 1036e888886cSMaheshRavishankar using OpConversionPattern< 1037e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp>::OpConversionPattern; 1038e888886cSMaheshRavishankar 1039e888886cSMaheshRavishankar LogicalResult 1040e888886cSMaheshRavishankar matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands, 1041e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 1042e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp parentOp = 10430bf4a82aSChristian Sigg op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1044e888886cSMaheshRavishankar if (!parentOp) 1045e888886cSMaheshRavishankar return failure(); 1046e888886cSMaheshRavishankar Block &innerBlock = op.region().front(); 1047e888886cSMaheshRavishankar TerminatorOp innerTerminator = 1048e888886cSMaheshRavishankar cast<TerminatorOp>(innerBlock.getTerminator()); 10499c7b0c4aSRahul Joshi rewriter.mergeBlockBefore(&innerBlock, op); 1050e888886cSMaheshRavishankar rewriter.eraseOp(innerTerminator); 1051e888886cSMaheshRavishankar rewriter.eraseOp(op); 1052e888886cSMaheshRavishankar rewriter.updateRootInPlace(op, [] {}); 1053e888886cSMaheshRavishankar return success(); 1054e888886cSMaheshRavishankar } 1055e888886cSMaheshRavishankar }; 1056e888886cSMaheshRavishankar 1057e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver 1058e888886cSMaheshRavishankar : public PassWrapper<TestMergeBlocksPatternDriver, 1059e888886cSMaheshRavishankar OperationPass<ModuleOp>> { 1060b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-merge-blocks"; } 1061b5e22e6dSMehdi Amini StringRef getDescription() const final { 1062b5e22e6dSMehdi Amini return "Test Merging operation in ConversionPatternRewriter"; 1063b5e22e6dSMehdi Amini } 1064e888886cSMaheshRavishankar void runOnOperation() override { 1065e888886cSMaheshRavishankar MLIRContext *context = &getContext(); 1066dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context); 1067dc4e913bSChris Lattner patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 1068e888886cSMaheshRavishankar context); 1069e888886cSMaheshRavishankar ConversionTarget target(*context); 1070973ddb7dSMehdi Amini target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 1071973ddb7dSMehdi Amini TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 1072e888886cSMaheshRavishankar target.addIllegalOp<ILLegalOpF>(); 1073e888886cSMaheshRavishankar 1074e888886cSMaheshRavishankar /// Expect the op to have a single block after legalization. 1075e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestMergeBlocksOp>( 1076e888886cSMaheshRavishankar [&](TestMergeBlocksOp op) -> bool { 1077e888886cSMaheshRavishankar return llvm::hasSingleElement(op.body()); 1078e888886cSMaheshRavishankar }); 1079e888886cSMaheshRavishankar 1080e888886cSMaheshRavishankar /// Only allow `test.br` within test.merge_blocks op. 1081e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 10820bf4a82aSChristian Sigg return op->getParentOfType<TestMergeBlocksOp>(); 1083e888886cSMaheshRavishankar }); 1084e888886cSMaheshRavishankar 1085e888886cSMaheshRavishankar /// Expect that all nested test.SingleBlockImplicitTerminator ops are 1086e888886cSMaheshRavishankar /// inlined. 1087e888886cSMaheshRavishankar target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 1088e888886cSMaheshRavishankar [&](SingleBlockImplicitTerminatorOp op) -> bool { 10890bf4a82aSChristian Sigg return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 1090e888886cSMaheshRavishankar }); 1091e888886cSMaheshRavishankar 1092e888886cSMaheshRavishankar DenseSet<Operation *> unlegalizedOps; 10933fffffa8SRiver Riddle (void)applyPartialConversion(getOperation(), target, std::move(patterns), 1094e888886cSMaheshRavishankar &unlegalizedOps); 1095e888886cSMaheshRavishankar for (auto *op : unlegalizedOps) 1096e888886cSMaheshRavishankar op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1097e888886cSMaheshRavishankar } 1098e888886cSMaheshRavishankar }; 1099e888886cSMaheshRavishankar } // namespace 1100e888886cSMaheshRavishankar 11014589dd92SRiver Riddle //===----------------------------------------------------------------------===// 1102c8fb6ee3SRiver Riddle // Test Selective Replacement 1103c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 1104c8fb6ee3SRiver Riddle 1105c8fb6ee3SRiver Riddle namespace { 1106c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both 1107c8fb6ee3SRiver Riddle /// ops can have a single block. 1108c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 1109c8fb6ee3SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern; 1110c8fb6ee3SRiver Riddle 1111c8fb6ee3SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op, 1112c8fb6ee3SRiver Riddle PatternRewriter &rewriter) const final { 1113c8fb6ee3SRiver Riddle if (op.getNumOperands() != 2) 1114c8fb6ee3SRiver Riddle return failure(); 1115c8fb6ee3SRiver Riddle OperandRange operands = op.getOperands(); 1116c8fb6ee3SRiver Riddle 1117c8fb6ee3SRiver Riddle // Replace non-terminator uses with the first operand. 1118c8fb6ee3SRiver Riddle rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) { 1119fe7c0d90SRiver Riddle return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 1120c8fb6ee3SRiver Riddle }); 1121c8fb6ee3SRiver Riddle // Replace everything else with the second operand if the operation isn't 1122c8fb6ee3SRiver Riddle // dead. 1123c8fb6ee3SRiver Riddle rewriter.replaceOp(op, op.getOperand(1)); 1124c8fb6ee3SRiver Riddle return success(); 1125c8fb6ee3SRiver Riddle } 1126c8fb6ee3SRiver Riddle }; 1127c8fb6ee3SRiver Riddle 1128c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver 1129c8fb6ee3SRiver Riddle : public PassWrapper<TestSelectiveReplacementPatternDriver, 1130c8fb6ee3SRiver Riddle OperationPass<>> { 1131b5e22e6dSMehdi Amini StringRef getArgument() const final { 1132b5e22e6dSMehdi Amini return "test-pattern-selective-replacement"; 1133b5e22e6dSMehdi Amini } 1134b5e22e6dSMehdi Amini StringRef getDescription() const final { 1135b5e22e6dSMehdi Amini return "Test selective replacement in the PatternRewriter"; 1136b5e22e6dSMehdi Amini } 1137c8fb6ee3SRiver Riddle void runOnOperation() override { 1138c8fb6ee3SRiver Riddle MLIRContext *context = &getContext(); 1139dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context); 1140dc4e913bSChris Lattner patterns.add<TestSelectiveOpReplacementPattern>(context); 1141e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 1142c8fb6ee3SRiver Riddle std::move(patterns)); 1143c8fb6ee3SRiver Riddle } 1144c8fb6ee3SRiver Riddle }; 1145c8fb6ee3SRiver Riddle } // namespace 1146c8fb6ee3SRiver Riddle 1147c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 11484589dd92SRiver Riddle // PassRegistration 11494589dd92SRiver Riddle //===----------------------------------------------------------------------===// 11504589dd92SRiver Riddle 1151fec6c5acSUday Bondhugula namespace mlir { 115272c65b69SAlexander Belyaev namespace test { 1153fec6c5acSUday Bondhugula void registerPatternsTestPass() { 1154b5e22e6dSMehdi Amini PassRegistration<TestReturnTypeDriver>(); 1155fec6c5acSUday Bondhugula 1156b5e22e6dSMehdi Amini PassRegistration<TestDerivedAttributeDriver>(); 11579ba37b3bSJacques Pienaar 1158b5e22e6dSMehdi Amini PassRegistration<TestPatternDriver>(); 1159fec6c5acSUday Bondhugula 1160b5e22e6dSMehdi Amini PassRegistration<TestLegalizePatternDriver>([] { 1161b5e22e6dSMehdi Amini return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 1162fec6c5acSUday Bondhugula }); 1163fec6c5acSUday Bondhugula 1164b5e22e6dSMehdi Amini PassRegistration<TestRemappedValue>(); 116580d7ac3bSRiver Riddle 1166b5e22e6dSMehdi Amini PassRegistration<TestUnknownRootOpDriver>(); 11674589dd92SRiver Riddle 1168b5e22e6dSMehdi Amini PassRegistration<TestTypeConversionDriver>(); 1169e888886cSMaheshRavishankar 1170b5e22e6dSMehdi Amini PassRegistration<TestMergeBlocksPatternDriver>(); 1171b5e22e6dSMehdi Amini PassRegistration<TestSelectiveReplacementPatternDriver>(); 1172fec6c5acSUday Bondhugula } 117372c65b69SAlexander Belyaev } // namespace test 1174fec6c5acSUday Bondhugula } // namespace mlir 1175