1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "mlir/Transforms/Passes.h"
23 
24 #include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 class TestSCFForUtilsPass
30     : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
31 public:
32   StringRef getArgument() const final { return "test-scf-for-utils"; }
33   StringRef getDescription() const final { return "test scf.for utils"; }
34   explicit TestSCFForUtilsPass() = default;
35 
36   void runOnFunction() override {
37     FuncOp func = getFunction();
38     SmallVector<scf::ForOp, 4> toErase;
39 
40     func.walk([&](Operation *fakeRead) {
41       if (fakeRead->getName().getStringRef() != "fake_read")
42         return;
43       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
44       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
45       auto loop = fakeRead->getParentOfType<scf::ForOp>();
46 
47       OpBuilder b(loop);
48       (void)loop.moveOutOfLoop({fakeRead});
49       fakeWrite->moveAfter(loop);
50       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
51                                         fakeCompute->getResult(0));
52       fakeCompute->getResult(0).replaceAllUsesWith(
53           newLoop.getResults().take_back()[0]);
54       toErase.push_back(loop);
55     });
56     for (auto loop : llvm::reverse(toErase))
57       loop.erase();
58   }
59 };
60 
61 class TestSCFIfUtilsPass
62     : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
63 public:
64   StringRef getArgument() const final { return "test-scf-if-utils"; }
65   StringRef getDescription() const final { return "test scf.if utils"; }
66   explicit TestSCFIfUtilsPass() = default;
67 
68   void runOnFunction() override {
69     int count = 0;
70     FuncOp func = getFunction();
71     func.walk([&](scf::IfOp ifOp) {
72       auto strCount = std::to_string(count++);
73       FuncOp thenFn, elseFn;
74       OpBuilder b(ifOp);
75       outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
76                   &elseFn, std::string("outlined_else") + strCount);
77     });
78   }
79 };
80 
81 static const StringLiteral kTestPipeliningLoopMarker =
82     "__test_pipelining_loop__";
83 static const StringLiteral kTestPipeliningStageMarker =
84     "__test_pipelining_stage__";
85 /// Marker to express the order in which operations should be after pipelining.
86 static const StringLiteral kTestPipeliningOpOrderMarker =
87     "__test_pipelining_op_order__";
88 
89 class TestSCFPipeliningPass
90     : public PassWrapper<TestSCFPipeliningPass, FunctionPass> {
91 public:
92   StringRef getArgument() const final { return "test-scf-pipelining"; }
93   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
94   explicit TestSCFPipeliningPass() = default;
95 
96   static void
97   getSchedule(scf::ForOp forOp,
98               std::vector<std::pair<Operation *, unsigned>> &schedule) {
99     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
100       return;
101     schedule.resize(forOp.getBody()->getOperations().size() - 1);
102     forOp.walk([&schedule](Operation *op) {
103       auto attrStage =
104           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
105       auto attrCycle =
106           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
107       if (attrCycle && attrStage) {
108         schedule[attrCycle.getInt()] =
109             std::make_pair(op, unsigned(attrStage.getInt()));
110       }
111     });
112   }
113 
114   void getDependentDialects(DialectRegistry &registry) const override {
115     registry.insert<arith::ArithmeticDialect, StandardOpsDialect>();
116   }
117 
118   void runOnFunction() override {
119     RewritePatternSet patterns(&getContext());
120     mlir::scf::PipeliningOption options;
121     options.getScheduleFn = getSchedule;
122 
123     scf::populateSCFLoopPipeliningPatterns(patterns, options);
124     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
125     getFunction().walk([](Operation *op) {
126       // Clean up the markers.
127       op->removeAttr(kTestPipeliningStageMarker);
128       op->removeAttr(kTestPipeliningOpOrderMarker);
129     });
130   }
131 };
132 } // namespace
133 
134 namespace mlir {
135 namespace test {
136 void registerTestSCFUtilsPass() {
137   PassRegistration<TestSCFForUtilsPass>();
138   PassRegistration<TestSCFIfUtilsPass>();
139   PassRegistration<TestSCFPipeliningPass>();
140 }
141 } // namespace test
142 } // namespace mlir
143