1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17 #include "mlir/Dialect/SCF/Utils/Utils.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23 #include "llvm/ADT/SetVector.h"
24
25 using namespace mlir;
26
27 namespace {
28 struct TestSCFForUtilsPass
29 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFForUtilsPass30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
31
32 StringRef getArgument() const final { return "test-scf-for-utils"; }
getDescription__anon3ae6e7e50111::TestSCFForUtilsPass33 StringRef getDescription() const final { return "test scf.for utils"; }
34 explicit TestSCFForUtilsPass() = default;
TestSCFForUtilsPass__anon3ae6e7e50111::TestSCFForUtilsPass35 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
36
37 Option<bool> testReplaceWithNewYields{
38 *this, "test-replace-with-new-yields",
39 llvm::cl::desc("Test replacing a loop with a new loop that returns new "
40 "additional yeild values"),
41 llvm::cl::init(false)};
42
runOnOperation__anon3ae6e7e50111::TestSCFForUtilsPass43 void runOnOperation() override {
44 func::FuncOp func = getOperation();
45 SmallVector<scf::ForOp, 4> toErase;
46
47 if (testReplaceWithNewYields) {
48 func.walk([&](scf::ForOp forOp) {
49 if (forOp.getNumResults() == 0)
50 return;
51 auto newInitValues = forOp.getInitArgs();
52 if (newInitValues.empty())
53 return;
54 NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
55 ArrayRef<BlockArgument> newBBArgs) {
56 Block *block = newBBArgs.front().getOwner();
57 SmallVector<Value> newYieldValues;
58 for (auto yieldVal :
59 cast<scf::YieldOp>(block->getTerminator()).getResults()) {
60 newYieldValues.push_back(
61 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
62 }
63 return newYieldValues;
64 };
65 OpBuilder b(forOp);
66 replaceLoopWithNewYields(b, forOp, newInitValues, fn);
67 });
68 }
69 }
70 };
71
72 struct TestSCFIfUtilsPass
73 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFIfUtilsPass74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
75
76 StringRef getArgument() const final { return "test-scf-if-utils"; }
getDescription__anon3ae6e7e50111::TestSCFIfUtilsPass77 StringRef getDescription() const final { return "test scf.if utils"; }
78 explicit TestSCFIfUtilsPass() = default;
79
runOnOperation__anon3ae6e7e50111::TestSCFIfUtilsPass80 void runOnOperation() override {
81 int count = 0;
82 getOperation().walk([&](scf::IfOp ifOp) {
83 auto strCount = std::to_string(count++);
84 func::FuncOp thenFn, elseFn;
85 OpBuilder b(ifOp);
86 IRRewriter rewriter(b);
87 if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
88 std::string("outlined_then") + strCount, &elseFn,
89 std::string("outlined_else") + strCount))) {
90 this->signalPassFailure();
91 return WalkResult::interrupt();
92 }
93 return WalkResult::advance();
94 });
95 }
96 };
97
98 static const StringLiteral kTestPipeliningLoopMarker =
99 "__test_pipelining_loop__";
100 static const StringLiteral kTestPipeliningStageMarker =
101 "__test_pipelining_stage__";
102 /// Marker to express the order in which operations should be after
103 /// pipelining.
104 static const StringLiteral kTestPipeliningOpOrderMarker =
105 "__test_pipelining_op_order__";
106
107 static const StringLiteral kTestPipeliningAnnotationPart =
108 "__test_pipelining_part";
109 static const StringLiteral kTestPipeliningAnnotationIteration =
110 "__test_pipelining_iteration";
111
112 struct TestSCFPipeliningPass
113 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
115
116 TestSCFPipeliningPass() = default;
TestSCFPipeliningPass__anon3ae6e7e50111::TestSCFPipeliningPass117 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
getArgument__anon3ae6e7e50111::TestSCFPipeliningPass118 StringRef getArgument() const final { return "test-scf-pipelining"; }
getDescription__anon3ae6e7e50111::TestSCFPipeliningPass119 StringRef getDescription() const final { return "test scf.forOp pipelining"; }
120
121 Option<bool> annotatePipeline{
122 *this, "annotate",
123 llvm::cl::desc("Annote operations during loop pipelining transformation"),
124 llvm::cl::init(false)};
125
126 Option<bool> noEpiloguePeeling{
127 *this, "no-epilogue-peeling",
128 llvm::cl::desc("Use predicates instead of peeling the epilogue."),
129 llvm::cl::init(false)};
130
131 static void
getSchedule__anon3ae6e7e50111::TestSCFPipeliningPass132 getSchedule(scf::ForOp forOp,
133 std::vector<std::pair<Operation *, unsigned>> &schedule) {
134 if (!forOp->hasAttr(kTestPipeliningLoopMarker))
135 return;
136 schedule.resize(forOp.getBody()->getOperations().size() - 1);
137 forOp.walk([&schedule](Operation *op) {
138 auto attrStage =
139 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
140 auto attrCycle =
141 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
142 if (attrCycle && attrStage) {
143 schedule[attrCycle.getInt()] =
144 std::make_pair(op, unsigned(attrStage.getInt()));
145 }
146 });
147 }
148
149 /// Helper to generate "predicated" version of `op`. For simplicity we just
150 /// wrap the operation in a scf.ifOp operation.
predicateOp__anon3ae6e7e50111::TestSCFPipeliningPass151 static Operation *predicateOp(Operation *op, Value pred,
152 PatternRewriter &rewriter) {
153 Location loc = op->getLoc();
154 auto ifOp =
155 rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
156 // True branch.
157 op->moveBefore(&ifOp.getThenRegion().front(),
158 ifOp.getThenRegion().front().end());
159 rewriter.setInsertionPointAfter(op);
160 rewriter.create<scf::YieldOp>(loc, op->getResults());
161 // False branch.
162 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
163 SmallVector<Value> zeros;
164 for (Type type : op->getResultTypes()) {
165 zeros.push_back(
166 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
167 }
168 rewriter.create<scf::YieldOp>(loc, zeros);
169 return ifOp.getOperation();
170 }
171
annotate__anon3ae6e7e50111::TestSCFPipeliningPass172 static void annotate(Operation *op,
173 mlir::scf::PipeliningOption::PipelinerPart part,
174 unsigned iteration) {
175 OpBuilder b(op);
176 switch (part) {
177 case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
178 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
179 break;
180 case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
181 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
182 break;
183 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
184 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
185 break;
186 }
187 op->setAttr(kTestPipeliningAnnotationIteration,
188 b.getI32IntegerAttr(iteration));
189 }
190
getDependentDialects__anon3ae6e7e50111::TestSCFPipeliningPass191 void getDependentDialects(DialectRegistry ®istry) const override {
192 registry.insert<arith::ArithmeticDialect>();
193 }
194
runOnOperation__anon3ae6e7e50111::TestSCFPipeliningPass195 void runOnOperation() override {
196 RewritePatternSet patterns(&getContext());
197 mlir::scf::PipeliningOption options;
198 options.getScheduleFn = getSchedule;
199 if (annotatePipeline)
200 options.annotateFn = annotate;
201 if (noEpiloguePeeling) {
202 options.peelEpilogue = false;
203 options.predicateFn = predicateOp;
204 }
205 scf::populateSCFLoopPipeliningPatterns(patterns, options);
206 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
207 getOperation().walk([](Operation *op) {
208 // Clean up the markers.
209 op->removeAttr(kTestPipeliningStageMarker);
210 op->removeAttr(kTestPipeliningOpOrderMarker);
211 });
212 }
213 };
214 } // namespace
215
216 namespace mlir {
217 namespace test {
registerTestSCFUtilsPass()218 void registerTestSCFUtilsPass() {
219 PassRegistration<TestSCFForUtilsPass>();
220 PassRegistration<TestSCFIfUtilsPass>();
221 PassRegistration<TestSCFPipeliningPass>();
222 }
223 } // namespace test
224 } // namespace mlir
225