1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SCF/Transforms.h" 16 #include "mlir/Dialect/SCF/Utils/Utils.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 #include "llvm/ADT/SetVector.h" 23 24 using namespace mlir; 25 26 namespace { 27 class TestSCFForUtilsPass 28 : public PassWrapper<TestSCFForUtilsPass, OperationPass<FuncOp>> { 29 public: 30 StringRef getArgument() const final { return "test-scf-for-utils"; } 31 StringRef getDescription() const final { return "test scf.for utils"; } 32 explicit TestSCFForUtilsPass() = default; 33 34 void runOnOperation() override { 35 FuncOp func = getOperation(); 36 SmallVector<scf::ForOp, 4> toErase; 37 38 func.walk([&](Operation *fakeRead) { 39 if (fakeRead->getName().getStringRef() != "fake_read") 40 return; 41 auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); 42 auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); 43 auto loop = fakeRead->getParentOfType<scf::ForOp>(); 44 45 OpBuilder b(loop); 46 (void)loop.moveOutOfLoop({fakeRead}); 47 fakeWrite->moveAfter(loop); 48 auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), 49 fakeCompute->getResult(0)); 50 fakeCompute->getResult(0).replaceAllUsesWith( 51 newLoop.getResults().take_back()[0]); 52 toErase.push_back(loop); 53 }); 54 for (auto loop : llvm::reverse(toErase)) 55 loop.erase(); 56 } 57 }; 58 59 class TestSCFIfUtilsPass 60 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { 61 public: 62 StringRef getArgument() const final { return "test-scf-if-utils"; } 63 StringRef getDescription() const final { return "test scf.if utils"; } 64 explicit TestSCFIfUtilsPass() = default; 65 66 void runOnOperation() override { 67 int count = 0; 68 getOperation().walk([&](scf::IfOp ifOp) { 69 auto strCount = std::to_string(count++); 70 FuncOp thenFn, elseFn; 71 OpBuilder b(ifOp); 72 IRRewriter rewriter(b); 73 if (failed(outlineIfOp(rewriter, ifOp, &thenFn, 74 std::string("outlined_then") + strCount, &elseFn, 75 std::string("outlined_else") + strCount))) { 76 this->signalPassFailure(); 77 return WalkResult::interrupt(); 78 } 79 return WalkResult::advance(); 80 }); 81 } 82 }; 83 84 static const StringLiteral kTestPipeliningLoopMarker = 85 "__test_pipelining_loop__"; 86 static const StringLiteral kTestPipeliningStageMarker = 87 "__test_pipelining_stage__"; 88 /// Marker to express the order in which operations should be after pipelining. 89 static const StringLiteral kTestPipeliningOpOrderMarker = 90 "__test_pipelining_op_order__"; 91 92 static const StringLiteral kTestPipeliningAnnotationPart = 93 "__test_pipelining_part"; 94 static const StringLiteral kTestPipeliningAnnotationIteration = 95 "__test_pipelining_iteration"; 96 97 class TestSCFPipeliningPass 98 : public PassWrapper<TestSCFPipeliningPass, OperationPass<FuncOp>> { 99 public: 100 TestSCFPipeliningPass() = default; 101 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} 102 StringRef getArgument() const final { return "test-scf-pipelining"; } 103 StringRef getDescription() const final { return "test scf.forOp pipelining"; } 104 105 Option<bool> annotatePipeline{ 106 *this, "annotate", 107 llvm::cl::desc("Annote operations during loop pipelining transformation"), 108 llvm::cl::init(false)}; 109 110 static void 111 getSchedule(scf::ForOp forOp, 112 std::vector<std::pair<Operation *, unsigned>> &schedule) { 113 if (!forOp->hasAttr(kTestPipeliningLoopMarker)) 114 return; 115 schedule.resize(forOp.getBody()->getOperations().size() - 1); 116 forOp.walk([&schedule](Operation *op) { 117 auto attrStage = 118 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); 119 auto attrCycle = 120 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); 121 if (attrCycle && attrStage) { 122 schedule[attrCycle.getInt()] = 123 std::make_pair(op, unsigned(attrStage.getInt())); 124 } 125 }); 126 } 127 128 static void annotate(Operation *op, 129 mlir::scf::PipeliningOption::PipelinerPart part, 130 unsigned iteration) { 131 OpBuilder b(op); 132 switch (part) { 133 case mlir::scf::PipeliningOption::PipelinerPart::Prologue: 134 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); 135 break; 136 case mlir::scf::PipeliningOption::PipelinerPart::Kernel: 137 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); 138 break; 139 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: 140 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); 141 break; 142 } 143 op->setAttr(kTestPipeliningAnnotationIteration, 144 b.getI32IntegerAttr(iteration)); 145 } 146 147 void getDependentDialects(DialectRegistry ®istry) const override { 148 registry.insert<arith::ArithmeticDialect>(); 149 } 150 151 void runOnOperation() override { 152 RewritePatternSet patterns(&getContext()); 153 mlir::scf::PipeliningOption options; 154 options.getScheduleFn = getSchedule; 155 if (annotatePipeline) 156 options.annotateFn = annotate; 157 scf::populateSCFLoopPipeliningPatterns(patterns, options); 158 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 159 getOperation().walk([](Operation *op) { 160 // Clean up the markers. 161 op->removeAttr(kTestPipeliningStageMarker); 162 op->removeAttr(kTestPipeliningOpOrderMarker); 163 }); 164 } 165 }; 166 } // namespace 167 168 namespace mlir { 169 namespace test { 170 void registerTestSCFUtilsPass() { 171 PassRegistration<TestSCFForUtilsPass>(); 172 PassRegistration<TestSCFIfUtilsPass>(); 173 PassRegistration<TestSCFPipeliningPass>(); 174 } 175 } // namespace test 176 } // namespace mlir 177