1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms.h" 17 #include "mlir/Dialect/SCF/Utils/Utils.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 #include "llvm/ADT/SetVector.h" 24 25 using namespace mlir; 26 27 namespace { 28 struct TestSCFForUtilsPass 29 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { 30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) 31 32 StringRef getArgument() const final { return "test-scf-for-utils"; } 33 StringRef getDescription() const final { return "test scf.for utils"; } 34 explicit TestSCFForUtilsPass() = default; 35 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} 36 37 Option<bool> testReplaceWithNewYields{ 38 *this, "test-replace-with-new-yields", 39 llvm::cl::desc("Test replacing a loop with a new loop that returns new " 40 "additional yeild values"), 41 llvm::cl::init(false)}; 42 43 void runOnOperation() override { 44 func::FuncOp func = getOperation(); 45 SmallVector<scf::ForOp, 4> toErase; 46 47 if (testReplaceWithNewYields) { 48 func.walk([&](scf::ForOp forOp) { 49 if (forOp.getNumResults() == 0) 50 return; 51 auto newInitValues = forOp.getInitArgs(); 52 if (newInitValues.empty()) 53 return; 54 NewYieldValueFn fn = [&](OpBuilder &b, Location loc, 55 ArrayRef<BlockArgument> newBBArgs) { 56 Block *block = newBBArgs.front().getOwner(); 57 SmallVector<Value> newYieldValues; 58 for (auto yieldVal : 59 cast<scf::YieldOp>(block->getTerminator()).getResults()) { 60 newYieldValues.push_back( 61 b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); 62 } 63 return newYieldValues; 64 }; 65 OpBuilder b(forOp); 66 replaceLoopWithNewYields(b, forOp, newInitValues, fn); 67 }); 68 } 69 } 70 }; 71 72 struct TestSCFIfUtilsPass 73 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { 74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) 75 76 StringRef getArgument() const final { return "test-scf-if-utils"; } 77 StringRef getDescription() const final { return "test scf.if utils"; } 78 explicit TestSCFIfUtilsPass() = default; 79 80 void runOnOperation() override { 81 int count = 0; 82 getOperation().walk([&](scf::IfOp ifOp) { 83 auto strCount = std::to_string(count++); 84 func::FuncOp thenFn, elseFn; 85 OpBuilder b(ifOp); 86 IRRewriter rewriter(b); 87 if (failed(outlineIfOp(rewriter, ifOp, &thenFn, 88 std::string("outlined_then") + strCount, &elseFn, 89 std::string("outlined_else") + strCount))) { 90 this->signalPassFailure(); 91 return WalkResult::interrupt(); 92 } 93 return WalkResult::advance(); 94 }); 95 } 96 }; 97 98 static const StringLiteral kTestPipeliningLoopMarker = 99 "__test_pipelining_loop__"; 100 static const StringLiteral kTestPipeliningStageMarker = 101 "__test_pipelining_stage__"; 102 /// Marker to express the order in which operations should be after 103 /// pipelining. 104 static const StringLiteral kTestPipeliningOpOrderMarker = 105 "__test_pipelining_op_order__"; 106 107 static const StringLiteral kTestPipeliningAnnotationPart = 108 "__test_pipelining_part"; 109 static const StringLiteral kTestPipeliningAnnotationIteration = 110 "__test_pipelining_iteration"; 111 112 struct TestSCFPipeliningPass 113 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { 114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) 115 116 TestSCFPipeliningPass() = default; 117 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} 118 StringRef getArgument() const final { return "test-scf-pipelining"; } 119 StringRef getDescription() const final { return "test scf.forOp pipelining"; } 120 121 Option<bool> annotatePipeline{ 122 *this, "annotate", 123 llvm::cl::desc("Annote operations during loop pipelining transformation"), 124 llvm::cl::init(false)}; 125 126 Option<bool> noEpiloguePeeling{ 127 *this, "no-epilogue-peeling", 128 llvm::cl::desc("Use predicates instead of peeling the epilogue."), 129 llvm::cl::init(false)}; 130 131 static void 132 getSchedule(scf::ForOp forOp, 133 std::vector<std::pair<Operation *, unsigned>> &schedule) { 134 if (!forOp->hasAttr(kTestPipeliningLoopMarker)) 135 return; 136 schedule.resize(forOp.getBody()->getOperations().size() - 1); 137 forOp.walk([&schedule](Operation *op) { 138 auto attrStage = 139 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); 140 auto attrCycle = 141 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); 142 if (attrCycle && attrStage) { 143 schedule[attrCycle.getInt()] = 144 std::make_pair(op, unsigned(attrStage.getInt())); 145 } 146 }); 147 } 148 149 /// Helper to generate "predicated" version of `op`. For simplicity we just 150 /// wrap the operation in a scf.ifOp operation. 151 static Operation *predicateOp(Operation *op, Value pred, 152 PatternRewriter &rewriter) { 153 Location loc = op->getLoc(); 154 auto ifOp = 155 rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); 156 // True branch. 157 op->moveBefore(&ifOp.getThenRegion().front(), 158 ifOp.getThenRegion().front().end()); 159 rewriter.setInsertionPointAfter(op); 160 rewriter.create<scf::YieldOp>(loc, op->getResults()); 161 // False branch. 162 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 163 SmallVector<Value> zeros; 164 for (Type type : op->getResultTypes()) { 165 zeros.push_back( 166 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type))); 167 } 168 rewriter.create<scf::YieldOp>(loc, zeros); 169 return ifOp.getOperation(); 170 } 171 172 static void annotate(Operation *op, 173 mlir::scf::PipeliningOption::PipelinerPart part, 174 unsigned iteration) { 175 OpBuilder b(op); 176 switch (part) { 177 case mlir::scf::PipeliningOption::PipelinerPart::Prologue: 178 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); 179 break; 180 case mlir::scf::PipeliningOption::PipelinerPart::Kernel: 181 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); 182 break; 183 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: 184 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); 185 break; 186 } 187 op->setAttr(kTestPipeliningAnnotationIteration, 188 b.getI32IntegerAttr(iteration)); 189 } 190 191 void getDependentDialects(DialectRegistry ®istry) const override { 192 registry.insert<arith::ArithmeticDialect>(); 193 } 194 195 void runOnOperation() override { 196 RewritePatternSet patterns(&getContext()); 197 mlir::scf::PipeliningOption options; 198 options.getScheduleFn = getSchedule; 199 if (annotatePipeline) 200 options.annotateFn = annotate; 201 if (noEpiloguePeeling) { 202 options.peelEpilogue = false; 203 options.predicateFn = predicateOp; 204 } 205 scf::populateSCFLoopPipeliningPatterns(patterns, options); 206 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 207 getOperation().walk([](Operation *op) { 208 // Clean up the markers. 209 op->removeAttr(kTestPipeliningStageMarker); 210 op->removeAttr(kTestPipeliningOpOrderMarker); 211 }); 212 } 213 }; 214 } // namespace 215 216 namespace mlir { 217 namespace test { 218 void registerTestSCFUtilsPass() { 219 PassRegistration<TestSCFForUtilsPass>(); 220 PassRegistration<TestSCFIfUtilsPass>(); 221 PassRegistration<TestSCFPipeliningPass>(); 222 } 223 } // namespace test 224 } // namespace mlir 225