1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test SCF dialect utils. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/SCF/Utils.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/Passes.h" 18 19 #include "llvm/ADT/SetVector.h" 20 21 using namespace mlir; 22 23 namespace { 24 class TestSCFForUtilsPass 25 : public PassWrapper<TestSCFForUtilsPass, FunctionPass> { 26 public: 27 explicit TestSCFForUtilsPass() {} 28 29 void runOnFunction() override { 30 FuncOp func = getFunction(); 31 SmallVector<scf::ForOp, 4> toErase; 32 33 func.walk([&](Operation *fakeRead) { 34 if (fakeRead->getName().getStringRef() != "fake_read") 35 return; 36 auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); 37 auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); 38 auto loop = fakeRead->getParentOfType<scf::ForOp>(); 39 40 OpBuilder b(loop); 41 (void)loop.moveOutOfLoop({fakeRead}); 42 fakeWrite->moveAfter(loop); 43 auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), 44 fakeCompute->getResult(0)); 45 fakeCompute->getResult(0).replaceAllUsesWith( 46 newLoop.getResults().take_back()[0]); 47 toErase.push_back(loop); 48 }); 49 for (auto loop : llvm::reverse(toErase)) 50 loop.erase(); 51 } 52 }; 53 54 class TestSCFIfUtilsPass 55 : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> { 56 public: 57 explicit TestSCFIfUtilsPass() {} 58 59 void runOnFunction() override { 60 int count = 0; 61 FuncOp func = getFunction(); 62 func.walk([&](scf::IfOp ifOp) { 63 auto strCount = std::to_string(count++); 64 FuncOp thenFn, elseFn; 65 OpBuilder b(ifOp); 66 outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount, 67 &elseFn, std::string("outlined_else") + strCount); 68 }); 69 } 70 }; 71 } // namespace 72 73 namespace mlir { 74 namespace test { 75 void registerTestSCFUtilsPass() { 76 PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils", 77 "test scf.for utils"); 78 PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils", 79 "test scf.if utils"); 80 } 81 } // namespace test 82 } // namespace mlir 83