1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Dialect/SCF/Utils/Utils.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 struct TestSCFForUtilsPass
29     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
30   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
31 
32   StringRef getArgument() const final { return "test-scf-for-utils"; }
33   StringRef getDescription() const final { return "test scf.for utils"; }
34   explicit TestSCFForUtilsPass() = default;
35 
36   void runOnOperation() override {
37     func::FuncOp func = getOperation();
38     SmallVector<scf::ForOp, 4> toErase;
39 
40     func.walk([&](Operation *fakeRead) {
41       if (fakeRead->getName().getStringRef() != "fake_read")
42         return;
43       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
44       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
45       auto loop = fakeRead->getParentOfType<scf::ForOp>();
46 
47       OpBuilder b(loop);
48       loop.moveOutOfLoop(fakeRead);
49       fakeWrite->moveAfter(loop);
50       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
51                                         fakeCompute->getResult(0));
52       fakeCompute->getResult(0).replaceAllUsesWith(
53           newLoop.getResults().take_back()[0]);
54       toErase.push_back(loop);
55     });
56     for (auto loop : llvm::reverse(toErase))
57       loop.erase();
58   }
59 };
60 
61 struct TestSCFIfUtilsPass
62     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
63   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
64 
65   StringRef getArgument() const final { return "test-scf-if-utils"; }
66   StringRef getDescription() const final { return "test scf.if utils"; }
67   explicit TestSCFIfUtilsPass() = default;
68 
69   void runOnOperation() override {
70     int count = 0;
71     getOperation().walk([&](scf::IfOp ifOp) {
72       auto strCount = std::to_string(count++);
73       func::FuncOp thenFn, elseFn;
74       OpBuilder b(ifOp);
75       IRRewriter rewriter(b);
76       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
77                              std::string("outlined_then") + strCount, &elseFn,
78                              std::string("outlined_else") + strCount))) {
79         this->signalPassFailure();
80         return WalkResult::interrupt();
81       }
82       return WalkResult::advance();
83     });
84   }
85 };
86 
87 static const StringLiteral kTestPipeliningLoopMarker =
88     "__test_pipelining_loop__";
89 static const StringLiteral kTestPipeliningStageMarker =
90     "__test_pipelining_stage__";
91 /// Marker to express the order in which operations should be after pipelining.
92 static const StringLiteral kTestPipeliningOpOrderMarker =
93     "__test_pipelining_op_order__";
94 
95 static const StringLiteral kTestPipeliningAnnotationPart =
96     "__test_pipelining_part";
97 static const StringLiteral kTestPipeliningAnnotationIteration =
98     "__test_pipelining_iteration";
99 
100 struct TestSCFPipeliningPass
101     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
102   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
103 
104   TestSCFPipeliningPass() = default;
105   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
106   StringRef getArgument() const final { return "test-scf-pipelining"; }
107   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
108 
109   Option<bool> annotatePipeline{
110       *this, "annotate",
111       llvm::cl::desc("Annote operations during loop pipelining transformation"),
112       llvm::cl::init(false)};
113 
114   static void
115   getSchedule(scf::ForOp forOp,
116               std::vector<std::pair<Operation *, unsigned>> &schedule) {
117     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
118       return;
119     schedule.resize(forOp.getBody()->getOperations().size() - 1);
120     forOp.walk([&schedule](Operation *op) {
121       auto attrStage =
122           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
123       auto attrCycle =
124           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
125       if (attrCycle && attrStage) {
126         schedule[attrCycle.getInt()] =
127             std::make_pair(op, unsigned(attrStage.getInt()));
128       }
129     });
130   }
131 
132   static void annotate(Operation *op,
133                        mlir::scf::PipeliningOption::PipelinerPart part,
134                        unsigned iteration) {
135     OpBuilder b(op);
136     switch (part) {
137     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
138       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
139       break;
140     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
141       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
142       break;
143     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
144       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
145       break;
146     }
147     op->setAttr(kTestPipeliningAnnotationIteration,
148                 b.getI32IntegerAttr(iteration));
149   }
150 
151   void getDependentDialects(DialectRegistry &registry) const override {
152     registry.insert<arith::ArithmeticDialect>();
153   }
154 
155   void runOnOperation() override {
156     RewritePatternSet patterns(&getContext());
157     mlir::scf::PipeliningOption options;
158     options.getScheduleFn = getSchedule;
159     if (annotatePipeline)
160       options.annotateFn = annotate;
161     scf::populateSCFLoopPipeliningPatterns(patterns, options);
162     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
163     getOperation().walk([](Operation *op) {
164       // Clean up the markers.
165       op->removeAttr(kTestPipeliningStageMarker);
166       op->removeAttr(kTestPipeliningOpOrderMarker);
167     });
168   }
169 };
170 } // namespace
171 
172 namespace mlir {
173 namespace test {
174 void registerTestSCFUtilsPass() {
175   PassRegistration<TestSCFForUtilsPass>();
176   PassRegistration<TestSCFIfUtilsPass>();
177   PassRegistration<TestSCFPipeliningPass>();
178 }
179 } // namespace test
180 } // namespace mlir
181