1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms.h" 17 #include "mlir/Dialect/SCF/Utils/Utils.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 #include "llvm/ADT/SetVector.h" 24 25 using namespace mlir; 26 27 namespace { 28 struct TestSCFForUtilsPass 29 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { 30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) 31 32 StringRef getArgument() const final { return "test-scf-for-utils"; } 33 StringRef getDescription() const final { return "test scf.for utils"; } 34 explicit TestSCFForUtilsPass() = default; 35 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} 36 37 Option<bool> testReplaceWithNewYields{ 38 *this, "test-replace-with-new-yields", 39 llvm::cl::desc("Test replacing a loop with a new loop that returns new " 40 "additional yeild values"), 41 llvm::cl::init(false)}; 42 43 void runOnOperation() override { 44 func::FuncOp func = getOperation(); 45 SmallVector<scf::ForOp, 4> toErase; 46 47 if (testReplaceWithNewYields) { 48 func.walk([&](scf::ForOp forOp) { 49 if (forOp.getNumResults() == 0) 50 return; 51 auto newInitValues = forOp.getInitArgs(); 52 if (newInitValues.empty()) 53 return; 54 NewYieldValueFn fn = [&](OpBuilder &b, Location loc, 55 ArrayRef<BlockArgument> newBBArgs) { 56 Block *block = newBBArgs.front().getOwner(); 57 SmallVector<Value> newYieldValues; 58 for (auto yieldVal : 59 cast<scf::YieldOp>(block->getTerminator()).getResults()) { 60 newYieldValues.push_back( 61 b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); 62 } 63 return newYieldValues; 64 }; 65 OpBuilder b(forOp); 66 replaceLoopWithNewYields(b, forOp, newInitValues, fn); 67 }); 68 } 69 } 70 }; 71 72 struct TestSCFIfUtilsPass 73 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { 74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) 75 76 StringRef getArgument() const final { return "test-scf-if-utils"; } 77 StringRef getDescription() const final { return "test scf.if utils"; } 78 explicit TestSCFIfUtilsPass() = default; 79 80 void runOnOperation() override { 81 int count = 0; 82 getOperation().walk([&](scf::IfOp ifOp) { 83 auto strCount = std::to_string(count++); 84 func::FuncOp thenFn, elseFn; 85 OpBuilder b(ifOp); 86 IRRewriter rewriter(b); 87 if (failed(outlineIfOp(rewriter, ifOp, &thenFn, 88 std::string("outlined_then") + strCount, &elseFn, 89 std::string("outlined_else") + strCount))) { 90 this->signalPassFailure(); 91 return WalkResult::interrupt(); 92 } 93 return WalkResult::advance(); 94 }); 95 } 96 }; 97 98 static const StringLiteral kTestPipeliningLoopMarker = 99 "__test_pipelining_loop__"; 100 static const StringLiteral kTestPipeliningStageMarker = 101 "__test_pipelining_stage__"; 102 /// Marker to express the order in which operations should be after 103 /// pipelining. 104 static const StringLiteral kTestPipeliningOpOrderMarker = 105 "__test_pipelining_op_order__"; 106 107 static const StringLiteral kTestPipeliningAnnotationPart = 108 "__test_pipelining_part"; 109 static const StringLiteral kTestPipeliningAnnotationIteration = 110 "__test_pipelining_iteration"; 111 112 struct TestSCFPipeliningPass 113 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { 114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) 115 116 TestSCFPipeliningPass() = default; 117 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} 118 StringRef getArgument() const final { return "test-scf-pipelining"; } 119 StringRef getDescription() const final { return "test scf.forOp pipelining"; } 120 121 Option<bool> annotatePipeline{ 122 *this, "annotate", 123 llvm::cl::desc("Annote operations during loop pipelining transformation"), 124 llvm::cl::init(false)}; 125 126 static void 127 getSchedule(scf::ForOp forOp, 128 std::vector<std::pair<Operation *, unsigned>> &schedule) { 129 if (!forOp->hasAttr(kTestPipeliningLoopMarker)) 130 return; 131 schedule.resize(forOp.getBody()->getOperations().size() - 1); 132 forOp.walk([&schedule](Operation *op) { 133 auto attrStage = 134 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); 135 auto attrCycle = 136 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); 137 if (attrCycle && attrStage) { 138 schedule[attrCycle.getInt()] = 139 std::make_pair(op, unsigned(attrStage.getInt())); 140 } 141 }); 142 } 143 144 static void annotate(Operation *op, 145 mlir::scf::PipeliningOption::PipelinerPart part, 146 unsigned iteration) { 147 OpBuilder b(op); 148 switch (part) { 149 case mlir::scf::PipeliningOption::PipelinerPart::Prologue: 150 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); 151 break; 152 case mlir::scf::PipeliningOption::PipelinerPart::Kernel: 153 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); 154 break; 155 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: 156 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); 157 break; 158 } 159 op->setAttr(kTestPipeliningAnnotationIteration, 160 b.getI32IntegerAttr(iteration)); 161 } 162 163 void getDependentDialects(DialectRegistry ®istry) const override { 164 registry.insert<arith::ArithmeticDialect>(); 165 } 166 167 void runOnOperation() override { 168 RewritePatternSet patterns(&getContext()); 169 mlir::scf::PipeliningOption options; 170 options.getScheduleFn = getSchedule; 171 if (annotatePipeline) 172 options.annotateFn = annotate; 173 scf::populateSCFLoopPipeliningPatterns(patterns, options); 174 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 175 getOperation().walk([](Operation *op) { 176 // Clean up the markers. 177 op->removeAttr(kTestPipeliningStageMarker); 178 op->removeAttr(kTestPipeliningOpOrderMarker); 179 }); 180 } 181 }; 182 } // namespace 183 184 namespace mlir { 185 namespace test { 186 void registerTestSCFUtilsPass() { 187 PassRegistration<TestSCFForUtilsPass>(); 188 PassRegistration<TestSCFIfUtilsPass>(); 189 PassRegistration<TestSCFPipeliningPass>(); 190 } 191 } // namespace test 192 } // namespace mlir 193