1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SCF/Transforms.h"
15 #include "mlir/Dialect/SCF/Utils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
21 
22 #include "llvm/ADT/SetVector.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 class TestSCFForUtilsPass
28     : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
29 public:
30   StringRef getArgument() const final { return "test-scf-for-utils"; }
31   StringRef getDescription() const final { return "test scf.for utils"; }
32   explicit TestSCFForUtilsPass() = default;
33 
34   void runOnFunction() override {
35     FuncOp func = getFunction();
36     SmallVector<scf::ForOp, 4> toErase;
37 
38     func.walk([&](Operation *fakeRead) {
39       if (fakeRead->getName().getStringRef() != "fake_read")
40         return;
41       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
42       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
43       auto loop = fakeRead->getParentOfType<scf::ForOp>();
44 
45       OpBuilder b(loop);
46       (void)loop.moveOutOfLoop({fakeRead});
47       fakeWrite->moveAfter(loop);
48       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
49                                         fakeCompute->getResult(0));
50       fakeCompute->getResult(0).replaceAllUsesWith(
51           newLoop.getResults().take_back()[0]);
52       toErase.push_back(loop);
53     });
54     for (auto loop : llvm::reverse(toErase))
55       loop.erase();
56   }
57 };
58 
59 class TestSCFIfUtilsPass
60     : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
61 public:
62   StringRef getArgument() const final { return "test-scf-if-utils"; }
63   StringRef getDescription() const final { return "test scf.if utils"; }
64   explicit TestSCFIfUtilsPass() = default;
65 
66   void runOnFunction() override {
67     int count = 0;
68     FuncOp func = getFunction();
69     func.walk([&](scf::IfOp ifOp) {
70       auto strCount = std::to_string(count++);
71       FuncOp thenFn, elseFn;
72       OpBuilder b(ifOp);
73       outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
74                   &elseFn, std::string("outlined_else") + strCount);
75     });
76   }
77 };
78 
79 static const StringLiteral kTestPipeliningLoopMarker =
80     "__test_pipelining_loop__";
81 static const StringLiteral kTestPipeliningStageMarker =
82     "__test_pipelining_stage__";
83 /// Marker to express the order in which operations should be after pipelining.
84 static const StringLiteral kTestPipeliningOpOrderMarker =
85     "__test_pipelining_op_order__";
86 
87 class TestSCFPipeliningPass
88     : public PassWrapper<TestSCFPipeliningPass, FunctionPass> {
89 public:
90   StringRef getArgument() const final { return "test-scf-pipelining"; }
91   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
92   explicit TestSCFPipeliningPass() = default;
93 
94   static void
95   getSchedule(scf::ForOp forOp,
96               std::vector<std::pair<Operation *, unsigned>> &schedule) {
97     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
98       return;
99     schedule.resize(forOp.getBody()->getOperations().size() - 1);
100     forOp.walk([&schedule](Operation *op) {
101       auto attrStage =
102           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
103       auto attrCycle =
104           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
105       if (attrCycle && attrStage) {
106         schedule[attrCycle.getInt()] =
107             std::make_pair(op, unsigned(attrStage.getInt()));
108       }
109     });
110   }
111 
112   void runOnFunction() override {
113     RewritePatternSet patterns(&getContext());
114     mlir::scf::PipeliningOption options;
115     options.getScheduleFn = getSchedule;
116 
117     scf::populateSCFLoopPipeliningPatterns(patterns, options);
118     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
119     getFunction().walk([](Operation *op) {
120       // Clean up the markers.
121       op->removeAttr(kTestPipeliningStageMarker);
122       op->removeAttr(kTestPipeliningOpOrderMarker);
123     });
124   }
125 };
126 } // namespace
127 
128 namespace mlir {
129 namespace test {
130 void registerTestSCFUtilsPass() {
131   PassRegistration<TestSCFForUtilsPass>();
132   PassRegistration<TestSCFIfUtilsPass>();
133   PassRegistration<TestSCFPipeliningPass>();
134 }
135 } // namespace test
136 } // namespace mlir
137