1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SCF/Utils.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/Passes.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 class TestSCFForUtilsPass
25     : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
26 public:
27   explicit TestSCFForUtilsPass() {}
28 
29   void runOnFunction() override {
30     FuncOp func = getFunction();
31     SmallVector<scf::ForOp, 4> toErase;
32 
33     func.walk([&](Operation *fakeRead) {
34       if (fakeRead->getName().getStringRef() != "fake_read")
35         return;
36       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
37       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
38       auto loop = fakeRead->getParentOfType<scf::ForOp>();
39 
40       OpBuilder b(loop);
41       (void)loop.moveOutOfLoop({fakeRead});
42       fakeWrite->moveAfter(loop);
43       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
44                                         fakeCompute->getResult(0));
45       fakeCompute->getResult(0).replaceAllUsesWith(
46           newLoop.getResults().take_back()[0]);
47       toErase.push_back(loop);
48     });
49     for (auto loop : llvm::reverse(toErase))
50       loop.erase();
51   }
52 };
53 
54 class TestSCFIfUtilsPass
55     : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
56 public:
57   explicit TestSCFIfUtilsPass() {}
58 
59   void runOnFunction() override {
60     int count = 0;
61     FuncOp func = getFunction();
62     func.walk([&](scf::IfOp ifOp) {
63       auto strCount = std::to_string(count++);
64       FuncOp thenFn, elseFn;
65       OpBuilder b(ifOp);
66       outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
67                   &elseFn, std::string("outlined_else") + strCount);
68     });
69   }
70 };
71 } // namespace
72 
73 namespace mlir {
74 namespace test {
75 void registerTestSCFUtilsPass() {
76   PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
77                                         "test scf.for utils");
78   PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
79                                        "test scf.if utils");
80 }
81 } // namespace test
82 } // namespace mlir
83