1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/SCF/Utils.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/Passes.h" 18 19 #include "llvm/ADT/SetVector.h" 20 21 using namespace mlir; 22 23 namespace { 24 class TestSCFForUtilsPass 25 : public PassWrapper<TestSCFForUtilsPass, FunctionPass> { 26 public: 27 StringRef getArgument() const final { return "test-scf-for-utils"; } 28 StringRef getDescription() const final { return "test scf.for utils"; } 29 explicit TestSCFForUtilsPass() = default; 30 31 void runOnFunction() override { 32 FuncOp func = getFunction(); 33 SmallVector<scf::ForOp, 4> toErase; 34 35 func.walk([&](Operation *fakeRead) { 36 if (fakeRead->getName().getStringRef() != "fake_read") 37 return; 38 auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); 39 auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); 40 auto loop = fakeRead->getParentOfType<scf::ForOp>(); 41 42 OpBuilder b(loop); 43 (void)loop.moveOutOfLoop({fakeRead}); 44 fakeWrite->moveAfter(loop); 45 auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), 46 fakeCompute->getResult(0)); 47 fakeCompute->getResult(0).replaceAllUsesWith( 48 newLoop.getResults().take_back()[0]); 49 toErase.push_back(loop); 50 }); 51 for (auto loop : llvm::reverse(toErase)) 52 loop.erase(); 53 } 54 }; 55 56 class TestSCFIfUtilsPass 57 : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> { 58 public: 59 StringRef getArgument() const final { return "test-scf-if-utils"; } 60 StringRef getDescription() const final { return "test scf.if utils"; } 61 explicit TestSCFIfUtilsPass() = default; 62 63 void runOnFunction() override { 64 int count = 0; 65 FuncOp func = getFunction(); 66 func.walk([&](scf::IfOp ifOp) { 67 auto strCount = std::to_string(count++); 68 FuncOp thenFn, elseFn; 69 OpBuilder b(ifOp); 70 outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount, 71 &elseFn, std::string("outlined_else") + strCount); 72 }); 73 } 74 }; 75 } // namespace 76 77 namespace mlir { 78 namespace test { 79 void registerTestSCFUtilsPass() { 80 PassRegistration<TestSCFForUtilsPass>(); 81 PassRegistration<TestSCFIfUtilsPass>(); 82 } 83 } // namespace test 84 } // namespace mlir 85