//===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tests the control-flow sink utilities by implementing an example
// control-flow sink pass.
//
//===----------------------------------------------------------------------===//
13b73f1d2cSMogball 
14b73f1d2cSMogball #include "mlir/Dialect/Func/IR/FuncOps.h"
15b73f1d2cSMogball #include "mlir/IR/Dominance.h"
16b73f1d2cSMogball #include "mlir/Pass/Pass.h"
17b73f1d2cSMogball #include "mlir/Transforms/ControlFlowSinkUtils.h"
18b73f1d2cSMogball 
19b73f1d2cSMogball using namespace mlir;
20b73f1d2cSMogball 
21b73f1d2cSMogball namespace {
22b73f1d2cSMogball /// An example control-flow sink pass to test the control-flow sink utilites.
23b73f1d2cSMogball /// This pass will sink ops named `test.sink_me` and tag them with an attribute
24b73f1d2cSMogball /// `was_sunk` into the first region of `test.sink_target` ops.
25b73f1d2cSMogball struct TestControlFlowSinkPass
26*58ceae95SRiver Riddle     : public PassWrapper<TestControlFlowSinkPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon15860e140111::TestControlFlowSinkPass275e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowSinkPass)
285e50dd04SRiver Riddle 
29b73f1d2cSMogball   /// Get the command-line argument of the test pass.
30b73f1d2cSMogball   StringRef getArgument() const final { return "test-control-flow-sink"; }
31b73f1d2cSMogball   /// Get the description of the test pass.
getDescription__anon15860e140111::TestControlFlowSinkPass32b73f1d2cSMogball   StringRef getDescription() const final {
33b73f1d2cSMogball     return "Test control-flow sink pass";
34b73f1d2cSMogball   }
35b73f1d2cSMogball 
36b73f1d2cSMogball   /// Runs the pass on the function.
runOnOperation__anon15860e140111::TestControlFlowSinkPass37b73f1d2cSMogball   void runOnOperation() override {
38b73f1d2cSMogball     auto &domInfo = getAnalysis<DominanceInfo>();
39b73f1d2cSMogball     auto shouldMoveIntoRegion = [](Operation *op, Region *region) {
40b73f1d2cSMogball       return region->getRegionNumber() == 0 &&
41b73f1d2cSMogball              op->getName().getStringRef() == "test.sink_me";
42b73f1d2cSMogball     };
43b73f1d2cSMogball     auto moveIntoRegion = [](Operation *op, Region *region) {
44b73f1d2cSMogball       Block &entry = region->front();
45b73f1d2cSMogball       op->moveBefore(&entry, entry.begin());
46b73f1d2cSMogball       op->setAttr("was_sunk",
47b73f1d2cSMogball                   Builder(op).getI32IntegerAttr(region->getRegionNumber()));
48b73f1d2cSMogball     };
49b73f1d2cSMogball 
50b73f1d2cSMogball     getOperation()->walk([&](Operation *op) {
51b73f1d2cSMogball       if (op->getName().getStringRef() != "test.sink_target")
52b73f1d2cSMogball         return;
53b73f1d2cSMogball       SmallVector<Region *> regions =
54b73f1d2cSMogball           llvm::to_vector(RegionRange(op->getRegions()));
55b73f1d2cSMogball       controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion);
56b73f1d2cSMogball     });
57b73f1d2cSMogball   }
58b73f1d2cSMogball };
59b73f1d2cSMogball } // end anonymous namespace
60b73f1d2cSMogball 
61b73f1d2cSMogball namespace mlir {
62b73f1d2cSMogball namespace test {
registerTestControlFlowSink()63b73f1d2cSMogball void registerTestControlFlowSink() {
64b73f1d2cSMogball   PassRegistration<TestControlFlowSinkPass>();
65b73f1d2cSMogball }
66b73f1d2cSMogball } // end namespace test
67b73f1d2cSMogball } // end namespace mlir
68