1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "mlir/Transforms/Passes.h"
23 
24 #include "llvm/ADT/SetVector.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 class TestSCFForUtilsPass
30     : public PassWrapper<TestSCFForUtilsPass, OperationPass<FuncOp>> {
31 public:
32   StringRef getArgument() const final { return "test-scf-for-utils"; }
33   StringRef getDescription() const final { return "test scf.for utils"; }
34   explicit TestSCFForUtilsPass() = default;
35 
36   void runOnOperation() override {
37     FuncOp func = getOperation();
38     SmallVector<scf::ForOp, 4> toErase;
39 
40     func.walk([&](Operation *fakeRead) {
41       if (fakeRead->getName().getStringRef() != "fake_read")
42         return;
43       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
44       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
45       auto loop = fakeRead->getParentOfType<scf::ForOp>();
46 
47       OpBuilder b(loop);
48       (void)loop.moveOutOfLoop({fakeRead});
49       fakeWrite->moveAfter(loop);
50       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
51                                         fakeCompute->getResult(0));
52       fakeCompute->getResult(0).replaceAllUsesWith(
53           newLoop.getResults().take_back()[0]);
54       toErase.push_back(loop);
55     });
56     for (auto loop : llvm::reverse(toErase))
57       loop.erase();
58   }
59 };
60 
61 class TestSCFIfUtilsPass
62     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
63 public:
64   StringRef getArgument() const final { return "test-scf-if-utils"; }
65   StringRef getDescription() const final { return "test scf.if utils"; }
66   explicit TestSCFIfUtilsPass() = default;
67 
68   void runOnOperation() override {
69     int count = 0;
70     getOperation().walk([&](scf::IfOp ifOp) {
71       auto strCount = std::to_string(count++);
72       FuncOp thenFn, elseFn;
73       OpBuilder b(ifOp);
74       IRRewriter rewriter(b);
75       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
76                              std::string("outlined_then") + strCount, &elseFn,
77                              std::string("outlined_else") + strCount))) {
78         this->signalPassFailure();
79         return WalkResult::interrupt();
80       }
81       return WalkResult::advance();
82     });
83   }
84 };
85 
86 static const StringLiteral kTestPipeliningLoopMarker =
87     "__test_pipelining_loop__";
88 static const StringLiteral kTestPipeliningStageMarker =
89     "__test_pipelining_stage__";
90 /// Marker to express the order in which operations should be after pipelining.
91 static const StringLiteral kTestPipeliningOpOrderMarker =
92     "__test_pipelining_op_order__";
93 
94 class TestSCFPipeliningPass
95     : public PassWrapper<TestSCFPipeliningPass, OperationPass<FuncOp>> {
96 public:
97   StringRef getArgument() const final { return "test-scf-pipelining"; }
98   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
99   explicit TestSCFPipeliningPass() = default;
100 
101   static void
102   getSchedule(scf::ForOp forOp,
103               std::vector<std::pair<Operation *, unsigned>> &schedule) {
104     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
105       return;
106     schedule.resize(forOp.getBody()->getOperations().size() - 1);
107     forOp.walk([&schedule](Operation *op) {
108       auto attrStage =
109           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
110       auto attrCycle =
111           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
112       if (attrCycle && attrStage) {
113         schedule[attrCycle.getInt()] =
114             std::make_pair(op, unsigned(attrStage.getInt()));
115       }
116     });
117   }
118 
119   void getDependentDialects(DialectRegistry &registry) const override {
120     registry.insert<arith::ArithmeticDialect, StandardOpsDialect>();
121   }
122 
123   void runOnOperation() override {
124     RewritePatternSet patterns(&getContext());
125     mlir::scf::PipeliningOption options;
126     options.getScheduleFn = getSchedule;
127 
128     scf::populateSCFLoopPipeliningPatterns(patterns, options);
129     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
130     getOperation().walk([](Operation *op) {
131       // Clean up the markers.
132       op->removeAttr(kTestPipeliningStageMarker);
133       op->removeAttr(kTestPipeliningOpOrderMarker);
134     });
135   }
136 };
137 } // namespace
138 
139 namespace mlir {
140 namespace test {
141 void registerTestSCFUtilsPass() {
142   PassRegistration<TestSCFForUtilsPass>();
143   PassRegistration<TestSCFIfUtilsPass>();
144   PassRegistration<TestSCFPipeliningPass>();
145 }
146 } // namespace test
147 } // namespace mlir
148