1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 static std::string getStageDescription(const WalkStage &stage) {
15   if (stage.isBeforeAllRegions())
16     return "before all regions";
17   if (stage.isAfterAllRegions())
18     return "after all regions";
19   return "before region #" + std::to_string(stage.getNextRegion());
20 }
21 
22 namespace {
23 /// This pass exercises generic visitor with void callbacks and prints the order
24 /// and stage in which operations are visited.
25 class TestGenericIRVisitorPass
26     : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
27 public:
28   StringRef getArgument() const final { return "test-generic-ir-visitors"; }
29   StringRef getDescription() const final { return "Test generic IR visitors."; }
30   void runOnOperation() override {
31     Operation *outerOp = getOperation();
32     int stepNo = 0;
33     outerOp->walk([&](Operation *op, const WalkStage &stage) {
34       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
35                    << getStageDescription(stage) << "\n";
36     });
37 
38     // Exercise static inference of operation type.
39     outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
40       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
41                    << getStageDescription(stage) << "\n";
42     });
43   }
44 };
45 
46 /// This pass exercises the generic visitor with non-void callbacks and prints
47 /// the order and stage in which operations are visited. It will interrupt the
48 /// walk based on attributes peesent in the IR.
49 class TestGenericIRVisitorInterruptPass
50     : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
51 public:
52   StringRef getArgument() const final {
53     return "test-generic-ir-visitors-interrupt";
54   }
55   StringRef getDescription() const final {
56     return "Test generic IR visitors with interrupts.";
57   }
58   void runOnOperation() override {
59     Operation *outerOp = getOperation();
60     int stepNo = 0;
61 
62     auto walker = [&](Operation *op, const WalkStage &stage) {
63       if (auto interruptBeforeAall =
64               op->getAttrOfType<BoolAttr>("interrupt_before_all"))
65         if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
66           return WalkResult::interrupt();
67 
68       if (auto interruptAfterAll =
69               op->getAttrOfType<BoolAttr>("interrupt_after_all"))
70         if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
71           return WalkResult::interrupt();
72 
73       if (auto interruptAfterRegion =
74               op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
75         if (stage.isAfterRegion(
76                 static_cast<int>(interruptAfterRegion.getInt())))
77           return WalkResult::interrupt();
78 
79       if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
80         if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
81           return WalkResult::skip();
82 
83       if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
84         if (skipAfterAll.getValue() && stage.isAfterAllRegions())
85           return WalkResult::skip();
86 
87       if (auto skipAfterRegion =
88               op->getAttrOfType<IntegerAttr>("skip_after_region"))
89         if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
90           return WalkResult::skip();
91 
92       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
93                    << getStageDescription(stage) << "\n";
94       return WalkResult::advance();
95     };
96 
97     // Interrupt the walk based on attributes on the operation.
98     auto result = outerOp->walk(walker);
99 
100     if (result.wasInterrupted())
101       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
102 
103     // Exercise static inference of operation type.
104     result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
105       return walker(op, stage);
106     });
107 
108     if (result.wasInterrupted())
109       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
110   }
111 };
112 
113 } // namespace
114 
115 namespace mlir {
116 namespace test {
117 void registerTestGenericIRVisitorsPass() {
118   PassRegistration<TestGenericIRVisitorPass>();
119   PassRegistration<TestGenericIRVisitorInterruptPass>();
120 }
121 
122 } // namespace test
123 } // namespace mlir
124