1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils/Utils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 #include "llvm/ADT/SetVector.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 class TestSCFForUtilsPass
28     : public PassWrapper<TestSCFForUtilsPass, OperationPass<FuncOp>> {
29 public:
30   StringRef getArgument() const final { return "test-scf-for-utils"; }
31   StringRef getDescription() const final { return "test scf.for utils"; }
32   explicit TestSCFForUtilsPass() = default;
33 
34   void runOnOperation() override {
35     FuncOp func = getOperation();
36     SmallVector<scf::ForOp, 4> toErase;
37 
38     func.walk([&](Operation *fakeRead) {
39       if (fakeRead->getName().getStringRef() != "fake_read")
40         return;
41       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
42       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
43       auto loop = fakeRead->getParentOfType<scf::ForOp>();
44 
45       OpBuilder b(loop);
46       (void)loop.moveOutOfLoop({fakeRead});
47       fakeWrite->moveAfter(loop);
48       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
49                                         fakeCompute->getResult(0));
50       fakeCompute->getResult(0).replaceAllUsesWith(
51           newLoop.getResults().take_back()[0]);
52       toErase.push_back(loop);
53     });
54     for (auto loop : llvm::reverse(toErase))
55       loop.erase();
56   }
57 };
58 
59 class TestSCFIfUtilsPass
60     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
61 public:
62   StringRef getArgument() const final { return "test-scf-if-utils"; }
63   StringRef getDescription() const final { return "test scf.if utils"; }
64   explicit TestSCFIfUtilsPass() = default;
65 
66   void runOnOperation() override {
67     int count = 0;
68     getOperation().walk([&](scf::IfOp ifOp) {
69       auto strCount = std::to_string(count++);
70       FuncOp thenFn, elseFn;
71       OpBuilder b(ifOp);
72       IRRewriter rewriter(b);
73       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
74                              std::string("outlined_then") + strCount, &elseFn,
75                              std::string("outlined_else") + strCount))) {
76         this->signalPassFailure();
77         return WalkResult::interrupt();
78       }
79       return WalkResult::advance();
80     });
81   }
82 };
83 
84 static const StringLiteral kTestPipeliningLoopMarker =
85     "__test_pipelining_loop__";
86 static const StringLiteral kTestPipeliningStageMarker =
87     "__test_pipelining_stage__";
88 /// Marker to express the order in which operations should be after pipelining.
89 static const StringLiteral kTestPipeliningOpOrderMarker =
90     "__test_pipelining_op_order__";
91 
92 static const StringLiteral kTestPipeliningAnnotationPart =
93     "__test_pipelining_part";
94 static const StringLiteral kTestPipeliningAnnotationIteration =
95     "__test_pipelining_iteration";
96 
97 class TestSCFPipeliningPass
98     : public PassWrapper<TestSCFPipeliningPass, OperationPass<FuncOp>> {
99 public:
100   TestSCFPipeliningPass() = default;
101   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
102   StringRef getArgument() const final { return "test-scf-pipelining"; }
103   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
104 
105   Option<bool> annotatePipeline{
106       *this, "annotate",
107       llvm::cl::desc("Annote operations during loop pipelining transformation"),
108       llvm::cl::init(false)};
109 
110   static void
111   getSchedule(scf::ForOp forOp,
112               std::vector<std::pair<Operation *, unsigned>> &schedule) {
113     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
114       return;
115     schedule.resize(forOp.getBody()->getOperations().size() - 1);
116     forOp.walk([&schedule](Operation *op) {
117       auto attrStage =
118           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
119       auto attrCycle =
120           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
121       if (attrCycle && attrStage) {
122         schedule[attrCycle.getInt()] =
123             std::make_pair(op, unsigned(attrStage.getInt()));
124       }
125     });
126   }
127 
128   static void annotate(Operation *op,
129                        mlir::scf::PipeliningOption::PipelinerPart part,
130                        unsigned iteration) {
131     OpBuilder b(op);
132     switch (part) {
133     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
134       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
135       break;
136     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
137       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
138       break;
139     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
140       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
141       break;
142     }
143     op->setAttr(kTestPipeliningAnnotationIteration,
144                 b.getI32IntegerAttr(iteration));
145   }
146 
147   void getDependentDialects(DialectRegistry &registry) const override {
148     registry.insert<arith::ArithmeticDialect>();
149   }
150 
151   void runOnOperation() override {
152     RewritePatternSet patterns(&getContext());
153     mlir::scf::PipeliningOption options;
154     options.getScheduleFn = getSchedule;
155     if (annotatePipeline)
156       options.annotateFn = annotate;
157     scf::populateSCFLoopPipeliningPatterns(patterns, options);
158     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
159     getOperation().walk([](Operation *op) {
160       // Clean up the markers.
161       op->removeAttr(kTestPipeliningStageMarker);
162       op->removeAttr(kTestPipeliningOpOrderMarker);
163     });
164   }
165 };
166 } // namespace
167 
168 namespace mlir {
169 namespace test {
170 void registerTestSCFUtilsPass() {
171   PassRegistration<TestSCFForUtilsPass>();
172   PassRegistration<TestSCFIfUtilsPass>();
173   PassRegistration<TestSCFPipeliningPass>();
174 }
175 } // namespace test
176 } // namespace mlir
177