1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17 #include "mlir/Dialect/SCF/Utils/Utils.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 struct TestSCFForUtilsPass
29     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFForUtilsPass30   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
31 
32   StringRef getArgument() const final { return "test-scf-for-utils"; }
getDescription__anon3ae6e7e50111::TestSCFForUtilsPass33   StringRef getDescription() const final { return "test scf.for utils"; }
34   explicit TestSCFForUtilsPass() = default;
TestSCFForUtilsPass__anon3ae6e7e50111::TestSCFForUtilsPass35   TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
36 
37   Option<bool> testReplaceWithNewYields{
38       *this, "test-replace-with-new-yields",
39       llvm::cl::desc("Test replacing a loop with a new loop that returns new "
40                      "additional yeild values"),
41       llvm::cl::init(false)};
42 
runOnOperation__anon3ae6e7e50111::TestSCFForUtilsPass43   void runOnOperation() override {
44     func::FuncOp func = getOperation();
45     SmallVector<scf::ForOp, 4> toErase;
46 
47     if (testReplaceWithNewYields) {
48       func.walk([&](scf::ForOp forOp) {
49         if (forOp.getNumResults() == 0)
50           return;
51         auto newInitValues = forOp.getInitArgs();
52         if (newInitValues.empty())
53           return;
54         NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
55                                  ArrayRef<BlockArgument> newBBArgs) {
56           Block *block = newBBArgs.front().getOwner();
57           SmallVector<Value> newYieldValues;
58           for (auto yieldVal :
59                cast<scf::YieldOp>(block->getTerminator()).getResults()) {
60             newYieldValues.push_back(
61                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
62           }
63           return newYieldValues;
64         };
65         OpBuilder b(forOp);
66         replaceLoopWithNewYields(b, forOp, newInitValues, fn);
67       });
68     }
69   }
70 };
71 
72 struct TestSCFIfUtilsPass
73     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFIfUtilsPass74   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
75 
76   StringRef getArgument() const final { return "test-scf-if-utils"; }
getDescription__anon3ae6e7e50111::TestSCFIfUtilsPass77   StringRef getDescription() const final { return "test scf.if utils"; }
78   explicit TestSCFIfUtilsPass() = default;
79 
runOnOperation__anon3ae6e7e50111::TestSCFIfUtilsPass80   void runOnOperation() override {
81     int count = 0;
82     getOperation().walk([&](scf::IfOp ifOp) {
83       auto strCount = std::to_string(count++);
84       func::FuncOp thenFn, elseFn;
85       OpBuilder b(ifOp);
86       IRRewriter rewriter(b);
87       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
88                              std::string("outlined_then") + strCount, &elseFn,
89                              std::string("outlined_else") + strCount))) {
90         this->signalPassFailure();
91         return WalkResult::interrupt();
92       }
93       return WalkResult::advance();
94     });
95   }
96 };
97 
98 static const StringLiteral kTestPipeliningLoopMarker =
99     "__test_pipelining_loop__";
100 static const StringLiteral kTestPipeliningStageMarker =
101     "__test_pipelining_stage__";
102 /// Marker to express the order in which operations should be after
103 /// pipelining.
104 static const StringLiteral kTestPipeliningOpOrderMarker =
105     "__test_pipelining_op_order__";
106 
107 static const StringLiteral kTestPipeliningAnnotationPart =
108     "__test_pipelining_part";
109 static const StringLiteral kTestPipeliningAnnotationIteration =
110     "__test_pipelining_iteration";
111 
112 struct TestSCFPipeliningPass
113     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
114   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
115 
116   TestSCFPipeliningPass() = default;
TestSCFPipeliningPass__anon3ae6e7e50111::TestSCFPipeliningPass117   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
getArgument__anon3ae6e7e50111::TestSCFPipeliningPass118   StringRef getArgument() const final { return "test-scf-pipelining"; }
getDescription__anon3ae6e7e50111::TestSCFPipeliningPass119   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
120 
121   Option<bool> annotatePipeline{
122       *this, "annotate",
123       llvm::cl::desc("Annote operations during loop pipelining transformation"),
124       llvm::cl::init(false)};
125 
126   Option<bool> noEpiloguePeeling{
127       *this, "no-epilogue-peeling",
128       llvm::cl::desc("Use predicates instead of peeling the epilogue."),
129       llvm::cl::init(false)};
130 
131   static void
getSchedule__anon3ae6e7e50111::TestSCFPipeliningPass132   getSchedule(scf::ForOp forOp,
133               std::vector<std::pair<Operation *, unsigned>> &schedule) {
134     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
135       return;
136     schedule.resize(forOp.getBody()->getOperations().size() - 1);
137     forOp.walk([&schedule](Operation *op) {
138       auto attrStage =
139           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
140       auto attrCycle =
141           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
142       if (attrCycle && attrStage) {
143         schedule[attrCycle.getInt()] =
144             std::make_pair(op, unsigned(attrStage.getInt()));
145       }
146     });
147   }
148 
149   /// Helper to generate "predicated" version of `op`. For simplicity we just
150   /// wrap the operation in a scf.ifOp operation.
predicateOp__anon3ae6e7e50111::TestSCFPipeliningPass151   static Operation *predicateOp(Operation *op, Value pred,
152                                 PatternRewriter &rewriter) {
153     Location loc = op->getLoc();
154     auto ifOp =
155         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
156     // True branch.
157     op->moveBefore(&ifOp.getThenRegion().front(),
158                    ifOp.getThenRegion().front().end());
159     rewriter.setInsertionPointAfter(op);
160     rewriter.create<scf::YieldOp>(loc, op->getResults());
161     // False branch.
162     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
163     SmallVector<Value> zeros;
164     for (Type type : op->getResultTypes()) {
165       zeros.push_back(
166           rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
167     }
168     rewriter.create<scf::YieldOp>(loc, zeros);
169     return ifOp.getOperation();
170   }
171 
annotate__anon3ae6e7e50111::TestSCFPipeliningPass172   static void annotate(Operation *op,
173                        mlir::scf::PipeliningOption::PipelinerPart part,
174                        unsigned iteration) {
175     OpBuilder b(op);
176     switch (part) {
177     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
178       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
179       break;
180     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
181       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
182       break;
183     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
184       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
185       break;
186     }
187     op->setAttr(kTestPipeliningAnnotationIteration,
188                 b.getI32IntegerAttr(iteration));
189   }
190 
getDependentDialects__anon3ae6e7e50111::TestSCFPipeliningPass191   void getDependentDialects(DialectRegistry &registry) const override {
192     registry.insert<arith::ArithmeticDialect>();
193   }
194 
runOnOperation__anon3ae6e7e50111::TestSCFPipeliningPass195   void runOnOperation() override {
196     RewritePatternSet patterns(&getContext());
197     mlir::scf::PipeliningOption options;
198     options.getScheduleFn = getSchedule;
199     if (annotatePipeline)
200       options.annotateFn = annotate;
201     if (noEpiloguePeeling) {
202       options.peelEpilogue = false;
203       options.predicateFn = predicateOp;
204     }
205     scf::populateSCFLoopPipeliningPatterns(patterns, options);
206     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
207     getOperation().walk([](Operation *op) {
208       // Clean up the markers.
209       op->removeAttr(kTestPipeliningStageMarker);
210       op->removeAttr(kTestPipeliningOpOrderMarker);
211     });
212   }
213 };
214 } // namespace
215 
216 namespace mlir {
217 namespace test {
registerTestSCFUtilsPass()218 void registerTestSCFUtilsPass() {
219   PassRegistration<TestSCFForUtilsPass>();
220   PassRegistration<TestSCFIfUtilsPass>();
221   PassRegistration<TestSCFPipeliningPass>();
222 }
223 } // namespace test
224 } // namespace mlir
225