1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 class TestSCFForUtilsPass
29     : public PassWrapper<TestSCFForUtilsPass, OperationPass<FuncOp>> {
30 public:
31   StringRef getArgument() const final { return "test-scf-for-utils"; }
32   StringRef getDescription() const final { return "test scf.for utils"; }
33   explicit TestSCFForUtilsPass() = default;
34 
35   void runOnOperation() override {
36     FuncOp func = getOperation();
37     SmallVector<scf::ForOp, 4> toErase;
38 
39     func.walk([&](Operation *fakeRead) {
40       if (fakeRead->getName().getStringRef() != "fake_read")
41         return;
42       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
43       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
44       auto loop = fakeRead->getParentOfType<scf::ForOp>();
45 
46       OpBuilder b(loop);
47       (void)loop.moveOutOfLoop({fakeRead});
48       fakeWrite->moveAfter(loop);
49       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
50                                         fakeCompute->getResult(0));
51       fakeCompute->getResult(0).replaceAllUsesWith(
52           newLoop.getResults().take_back()[0]);
53       toErase.push_back(loop);
54     });
55     for (auto loop : llvm::reverse(toErase))
56       loop.erase();
57   }
58 };
59 
60 class TestSCFIfUtilsPass
61     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
62 public:
63   StringRef getArgument() const final { return "test-scf-if-utils"; }
64   StringRef getDescription() const final { return "test scf.if utils"; }
65   explicit TestSCFIfUtilsPass() = default;
66 
67   void runOnOperation() override {
68     int count = 0;
69     getOperation().walk([&](scf::IfOp ifOp) {
70       auto strCount = std::to_string(count++);
71       FuncOp thenFn, elseFn;
72       OpBuilder b(ifOp);
73       IRRewriter rewriter(b);
74       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
75                              std::string("outlined_then") + strCount, &elseFn,
76                              std::string("outlined_else") + strCount))) {
77         this->signalPassFailure();
78         return WalkResult::interrupt();
79       }
80       return WalkResult::advance();
81     });
82   }
83 };
84 
85 static const StringLiteral kTestPipeliningLoopMarker =
86     "__test_pipelining_loop__";
87 static const StringLiteral kTestPipeliningStageMarker =
88     "__test_pipelining_stage__";
89 /// Marker to express the order in which operations should be after pipelining.
90 static const StringLiteral kTestPipeliningOpOrderMarker =
91     "__test_pipelining_op_order__";
92 
93 class TestSCFPipeliningPass
94     : public PassWrapper<TestSCFPipeliningPass, OperationPass<FuncOp>> {
95 public:
96   StringRef getArgument() const final { return "test-scf-pipelining"; }
97   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
98   explicit TestSCFPipeliningPass() = default;
99 
100   static void
101   getSchedule(scf::ForOp forOp,
102               std::vector<std::pair<Operation *, unsigned>> &schedule) {
103     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
104       return;
105     schedule.resize(forOp.getBody()->getOperations().size() - 1);
106     forOp.walk([&schedule](Operation *op) {
107       auto attrStage =
108           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
109       auto attrCycle =
110           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
111       if (attrCycle && attrStage) {
112         schedule[attrCycle.getInt()] =
113             std::make_pair(op, unsigned(attrStage.getInt()));
114       }
115     });
116   }
117 
118   void getDependentDialects(DialectRegistry &registry) const override {
119     registry.insert<arith::ArithmeticDialect, StandardOpsDialect>();
120   }
121 
122   void runOnOperation() override {
123     RewritePatternSet patterns(&getContext());
124     mlir::scf::PipeliningOption options;
125     options.getScheduleFn = getSchedule;
126 
127     scf::populateSCFLoopPipeliningPatterns(patterns, options);
128     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
129     getOperation().walk([](Operation *op) {
130       // Clean up the markers.
131       op->removeAttr(kTestPipeliningStageMarker);
132       op->removeAttr(kTestPipeliningOpOrderMarker);
133     });
134   }
135 };
136 } // namespace
137 
138 namespace mlir {
139 namespace test {
140 void registerTestSCFUtilsPass() {
141   PassRegistration<TestSCFForUtilsPass>();
142   PassRegistration<TestSCFIfUtilsPass>();
143   PassRegistration<TestSCFPipeliningPass>();
144 }
145 } // namespace test
146 } // namespace mlir
147