1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Dialect/SCF/Utils/Utils.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 struct TestSCFForUtilsPass
29     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
30   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
31 
32   StringRef getArgument() const final { return "test-scf-for-utils"; }
33   StringRef getDescription() const final { return "test scf.for utils"; }
34   explicit TestSCFForUtilsPass() = default;
35   TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
36 
37   Option<bool> testReplaceWithNewYields{
38       *this, "test-replace-with-new-yields",
39       llvm::cl::desc("Test replacing a loop with a new loop that returns new "
40                      "additional yeild values"),
41       llvm::cl::init(false)};
42 
43   void runOnOperation() override {
44     func::FuncOp func = getOperation();
45     SmallVector<scf::ForOp, 4> toErase;
46 
47     if (testReplaceWithNewYields) {
48       func.walk([&](scf::ForOp forOp) {
49         if (forOp.getNumResults() == 0)
50           return;
51         auto newInitValues = forOp.getInitArgs();
52         if (newInitValues.empty())
53           return;
54         NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
55                                  ArrayRef<BlockArgument> newBBArgs) {
56           Block *block = newBBArgs.front().getOwner();
57           SmallVector<Value> newYieldValues;
58           for (auto yieldVal :
59                cast<scf::YieldOp>(block->getTerminator()).getResults()) {
60             newYieldValues.push_back(
61                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
62           }
63           return newYieldValues;
64         };
65         OpBuilder b(forOp);
66         replaceLoopWithNewYields(b, forOp, newInitValues, fn);
67       });
68     }
69   }
70 };
71 
72 struct TestSCFIfUtilsPass
73     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
74   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
75 
76   StringRef getArgument() const final { return "test-scf-if-utils"; }
77   StringRef getDescription() const final { return "test scf.if utils"; }
78   explicit TestSCFIfUtilsPass() = default;
79 
80   void runOnOperation() override {
81     int count = 0;
82     getOperation().walk([&](scf::IfOp ifOp) {
83       auto strCount = std::to_string(count++);
84       func::FuncOp thenFn, elseFn;
85       OpBuilder b(ifOp);
86       IRRewriter rewriter(b);
87       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
88                              std::string("outlined_then") + strCount, &elseFn,
89                              std::string("outlined_else") + strCount))) {
90         this->signalPassFailure();
91         return WalkResult::interrupt();
92       }
93       return WalkResult::advance();
94     });
95   }
96 };
97 
98 static const StringLiteral kTestPipeliningLoopMarker =
99     "__test_pipelining_loop__";
100 static const StringLiteral kTestPipeliningStageMarker =
101     "__test_pipelining_stage__";
102 /// Marker to express the order in which operations should be after
103 /// pipelining.
104 static const StringLiteral kTestPipeliningOpOrderMarker =
105     "__test_pipelining_op_order__";
106 
107 static const StringLiteral kTestPipeliningAnnotationPart =
108     "__test_pipelining_part";
109 static const StringLiteral kTestPipeliningAnnotationIteration =
110     "__test_pipelining_iteration";
111 
112 struct TestSCFPipeliningPass
113     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
114   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
115 
116   TestSCFPipeliningPass() = default;
117   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
118   StringRef getArgument() const final { return "test-scf-pipelining"; }
119   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
120 
121   Option<bool> annotatePipeline{
122       *this, "annotate",
123       llvm::cl::desc("Annote operations during loop pipelining transformation"),
124       llvm::cl::init(false)};
125 
126   static void
127   getSchedule(scf::ForOp forOp,
128               std::vector<std::pair<Operation *, unsigned>> &schedule) {
129     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
130       return;
131     schedule.resize(forOp.getBody()->getOperations().size() - 1);
132     forOp.walk([&schedule](Operation *op) {
133       auto attrStage =
134           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
135       auto attrCycle =
136           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
137       if (attrCycle && attrStage) {
138         schedule[attrCycle.getInt()] =
139             std::make_pair(op, unsigned(attrStage.getInt()));
140       }
141     });
142   }
143 
144   static void annotate(Operation *op,
145                        mlir::scf::PipeliningOption::PipelinerPart part,
146                        unsigned iteration) {
147     OpBuilder b(op);
148     switch (part) {
149     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
150       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
151       break;
152     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
153       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
154       break;
155     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
156       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
157       break;
158     }
159     op->setAttr(kTestPipeliningAnnotationIteration,
160                 b.getI32IntegerAttr(iteration));
161   }
162 
163   void getDependentDialects(DialectRegistry &registry) const override {
164     registry.insert<arith::ArithmeticDialect>();
165   }
166 
167   void runOnOperation() override {
168     RewritePatternSet patterns(&getContext());
169     mlir::scf::PipeliningOption options;
170     options.getScheduleFn = getSchedule;
171     if (annotatePipeline)
172       options.annotateFn = annotate;
173     scf::populateSCFLoopPipeliningPatterns(patterns, options);
174     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
175     getOperation().walk([](Operation *op) {
176       // Clean up the markers.
177       op->removeAttr(kTestPipeliningStageMarker);
178       op->removeAttr(kTestPipeliningOpOrderMarker);
179     });
180   }
181 };
182 } // namespace
183 
184 namespace mlir {
185 namespace test {
186 void registerTestSCFUtilsPass() {
187   PassRegistration<TestSCFForUtilsPass>();
188   PassRegistration<TestSCFIfUtilsPass>();
189   PassRegistration<TestSCFPipeliningPass>();
190 }
191 } // namespace test
192 } // namespace mlir
193