1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements passes to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 class TestSCFForUtilsPass
29     : public PassWrapper<TestSCFForUtilsPass, OperationPass<FuncOp>> {
30 public:
31   StringRef getArgument() const final { return "test-scf-for-utils"; }
32   StringRef getDescription() const final { return "test scf.for utils"; }
33   explicit TestSCFForUtilsPass() = default;
34 
35   void runOnOperation() override {
36     FuncOp func = getOperation();
37     SmallVector<scf::ForOp, 4> toErase;
38 
39     func.walk([&](Operation *fakeRead) {
40       if (fakeRead->getName().getStringRef() != "fake_read")
41         return;
42       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
43       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
44       auto loop = fakeRead->getParentOfType<scf::ForOp>();
45 
46       OpBuilder b(loop);
47       (void)loop.moveOutOfLoop({fakeRead});
48       fakeWrite->moveAfter(loop);
49       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
50                                         fakeCompute->getResult(0));
51       fakeCompute->getResult(0).replaceAllUsesWith(
52           newLoop.getResults().take_back()[0]);
53       toErase.push_back(loop);
54     });
55     for (auto loop : llvm::reverse(toErase))
56       loop.erase();
57   }
58 };
59 
60 class TestSCFIfUtilsPass
61     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
62 public:
63   StringRef getArgument() const final { return "test-scf-if-utils"; }
64   StringRef getDescription() const final { return "test scf.if utils"; }
65   explicit TestSCFIfUtilsPass() = default;
66 
67   void runOnOperation() override {
68     int count = 0;
69     getOperation().walk([&](scf::IfOp ifOp) {
70       auto strCount = std::to_string(count++);
71       FuncOp thenFn, elseFn;
72       OpBuilder b(ifOp);
73       IRRewriter rewriter(b);
74       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
75                              std::string("outlined_then") + strCount, &elseFn,
76                              std::string("outlined_else") + strCount))) {
77         this->signalPassFailure();
78         return WalkResult::interrupt();
79       }
80       return WalkResult::advance();
81     });
82   }
83 };
84 
85 static const StringLiteral kTestPipeliningLoopMarker =
86     "__test_pipelining_loop__";
87 static const StringLiteral kTestPipeliningStageMarker =
88     "__test_pipelining_stage__";
89 /// Marker to express the order in which operations should be after pipelining.
90 static const StringLiteral kTestPipeliningOpOrderMarker =
91     "__test_pipelining_op_order__";
92 
93 static const StringLiteral kTestPipeliningAnnotationPart =
94     "__test_pipelining_part";
95 static const StringLiteral kTestPipeliningAnnotationIteration =
96     "__test_pipelining_iteration";
97 
98 class TestSCFPipeliningPass
99     : public PassWrapper<TestSCFPipeliningPass, OperationPass<FuncOp>> {
100 public:
101   TestSCFPipeliningPass() = default;
102   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
103   StringRef getArgument() const final { return "test-scf-pipelining"; }
104   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
105 
106   Option<bool> annotatePipeline{
107       *this, "annotate",
108       llvm::cl::desc("Annote operations during loop pipelining transformation"),
109       llvm::cl::init(false)};
110 
111   static void
112   getSchedule(scf::ForOp forOp,
113               std::vector<std::pair<Operation *, unsigned>> &schedule) {
114     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
115       return;
116     schedule.resize(forOp.getBody()->getOperations().size() - 1);
117     forOp.walk([&schedule](Operation *op) {
118       auto attrStage =
119           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
120       auto attrCycle =
121           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
122       if (attrCycle && attrStage) {
123         schedule[attrCycle.getInt()] =
124             std::make_pair(op, unsigned(attrStage.getInt()));
125       }
126     });
127   }
128 
129   static void annotate(Operation *op,
130                        mlir::scf::PipeliningOption::PipelinerPart part,
131                        unsigned iteration) {
132     OpBuilder b(op);
133     switch (part) {
134     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
135       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
136       break;
137     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
138       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
139       break;
140     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
141       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
142       break;
143     }
144     op->setAttr(kTestPipeliningAnnotationIteration,
145                 b.getI32IntegerAttr(iteration));
146   }
147 
148   void getDependentDialects(DialectRegistry &registry) const override {
149     registry.insert<arith::ArithmeticDialect, StandardOpsDialect>();
150   }
151 
152   void runOnOperation() override {
153     RewritePatternSet patterns(&getContext());
154     mlir::scf::PipeliningOption options;
155     options.getScheduleFn = getSchedule;
156     if (annotatePipeline)
157       options.annotateFn = annotate;
158     scf::populateSCFLoopPipeliningPatterns(patterns, options);
159     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
160     getOperation().walk([](Operation *op) {
161       // Clean up the markers.
162       op->removeAttr(kTestPipeliningStageMarker);
163       op->removeAttr(kTestPipeliningOpOrderMarker);
164     });
165   }
166 };
167 } // namespace
168 
169 namespace mlir {
170 namespace test {
171 void registerTestSCFUtilsPass() {
172   PassRegistration<TestSCFForUtilsPass>();
173   PassRegistration<TestSCFIfUtilsPass>();
174   PassRegistration<TestSCFPipeliningPass>();
175 }
176 } // namespace test
177 } // namespace mlir
178