1b73f1d2cSMogball //===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===// 2b73f1d2cSMogball // 3b73f1d2cSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b73f1d2cSMogball // See https://llvm.org/LICENSE.txt for license information. 5b73f1d2cSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b73f1d2cSMogball // 7b73f1d2cSMogball //===----------------------------------------------------------------------===// 8b73f1d2cSMogball // 9b73f1d2cSMogball // This pass tests the control-flow sink utilities by implementing an example 10b73f1d2cSMogball // control-flow sink pass. 11b73f1d2cSMogball // 12b73f1d2cSMogball //===----------------------------------------------------------------------===// 13b73f1d2cSMogball 14b73f1d2cSMogball #include "mlir/Dialect/Func/IR/FuncOps.h" 15b73f1d2cSMogball #include "mlir/IR/Dominance.h" 16b73f1d2cSMogball #include "mlir/Pass/Pass.h" 17b73f1d2cSMogball #include "mlir/Transforms/ControlFlowSinkUtils.h" 18b73f1d2cSMogball 19b73f1d2cSMogball using namespace mlir; 20b73f1d2cSMogball 21b73f1d2cSMogball namespace { 22b73f1d2cSMogball /// An example control-flow sink pass to test the control-flow sink utilites. 23b73f1d2cSMogball /// This pass will sink ops named `test.sink_me` and tag them with an attribute 24b73f1d2cSMogball /// `was_sunk` into the first region of `test.sink_target` ops. 25b73f1d2cSMogball struct TestControlFlowSinkPass 26*58ceae95SRiver Riddle : public PassWrapper<TestControlFlowSinkPass, OperationPass<func::FuncOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon15860e140111::TestControlFlowSinkPass275e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowSinkPass) 285e50dd04SRiver Riddle 29b73f1d2cSMogball /// Get the command-line argument of the test pass. 30b73f1d2cSMogball StringRef getArgument() const final { return "test-control-flow-sink"; } 31b73f1d2cSMogball /// Get the description of the test pass. getDescription__anon15860e140111::TestControlFlowSinkPass32b73f1d2cSMogball StringRef getDescription() const final { 33b73f1d2cSMogball return "Test control-flow sink pass"; 34b73f1d2cSMogball } 35b73f1d2cSMogball 36b73f1d2cSMogball /// Runs the pass on the function. runOnOperation__anon15860e140111::TestControlFlowSinkPass37b73f1d2cSMogball void runOnOperation() override { 38b73f1d2cSMogball auto &domInfo = getAnalysis<DominanceInfo>(); 39b73f1d2cSMogball auto shouldMoveIntoRegion = [](Operation *op, Region *region) { 40b73f1d2cSMogball return region->getRegionNumber() == 0 && 41b73f1d2cSMogball op->getName().getStringRef() == "test.sink_me"; 42b73f1d2cSMogball }; 43b73f1d2cSMogball auto moveIntoRegion = [](Operation *op, Region *region) { 44b73f1d2cSMogball Block &entry = region->front(); 45b73f1d2cSMogball op->moveBefore(&entry, entry.begin()); 46b73f1d2cSMogball op->setAttr("was_sunk", 47b73f1d2cSMogball Builder(op).getI32IntegerAttr(region->getRegionNumber())); 48b73f1d2cSMogball }; 49b73f1d2cSMogball 50b73f1d2cSMogball getOperation()->walk([&](Operation *op) { 51b73f1d2cSMogball if (op->getName().getStringRef() != "test.sink_target") 52b73f1d2cSMogball return; 53b73f1d2cSMogball SmallVector<Region *> regions = 54b73f1d2cSMogball llvm::to_vector(RegionRange(op->getRegions())); 55b73f1d2cSMogball controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion); 56b73f1d2cSMogball }); 57b73f1d2cSMogball } 58b73f1d2cSMogball }; 59b73f1d2cSMogball } // end anonymous namespace 60b73f1d2cSMogball 61b73f1d2cSMogball namespace mlir { 62b73f1d2cSMogball namespace test { registerTestControlFlowSink()63b73f1d2cSMogballvoid registerTestControlFlowSink() { 64b73f1d2cSMogball PassRegistration<TestControlFlowSinkPass>(); 65b73f1d2cSMogball } 66b73f1d2cSMogball } // end namespace test 67b73f1d2cSMogball } // end namespace mlir 68