1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements passes to test the SCF dialect utilities (one for
// scf.for and one for scf.if).
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SCF/Utils.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/Passes.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 class TestSCFForUtilsPass
25     : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
26 public:
27   StringRef getArgument() const final { return "test-scf-for-utils"; }
28   StringRef getDescription() const final { return "test scf.for utils"; }
29   explicit TestSCFForUtilsPass() = default;
30 
31   void runOnFunction() override {
32     FuncOp func = getFunction();
33     SmallVector<scf::ForOp, 4> toErase;
34 
35     func.walk([&](Operation *fakeRead) {
36       if (fakeRead->getName().getStringRef() != "fake_read")
37         return;
38       auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
39       auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
40       auto loop = fakeRead->getParentOfType<scf::ForOp>();
41 
42       OpBuilder b(loop);
43       (void)loop.moveOutOfLoop({fakeRead});
44       fakeWrite->moveAfter(loop);
45       auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
46                                         fakeCompute->getResult(0));
47       fakeCompute->getResult(0).replaceAllUsesWith(
48           newLoop.getResults().take_back()[0]);
49       toErase.push_back(loop);
50     });
51     for (auto loop : llvm::reverse(toErase))
52       loop.erase();
53   }
54 };
55 
56 class TestSCFIfUtilsPass
57     : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
58 public:
59   StringRef getArgument() const final { return "test-scf-if-utils"; }
60   StringRef getDescription() const final { return "test scf.if utils"; }
61   explicit TestSCFIfUtilsPass() = default;
62 
63   void runOnFunction() override {
64     int count = 0;
65     FuncOp func = getFunction();
66     func.walk([&](scf::IfOp ifOp) {
67       auto strCount = std::to_string(count++);
68       FuncOp thenFn, elseFn;
69       OpBuilder b(ifOp);
70       outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
71                   &elseFn, std::string("outlined_else") + strCount);
72     });
73   }
74 };
75 } // namespace
76 
77 namespace mlir {
78 namespace test {
79 void registerTestSCFUtilsPass() {
80   PassRegistration<TestSCFForUtilsPass>();
81   PassRegistration<TestSCFIfUtilsPass>();
82 }
83 } // namespace test
84 } // namespace mlir
85