1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 static std::string getStageDescription(const WalkStage &stage) {
15   if (stage.isBeforeAllRegions())
16     return "before all regions";
17   if (stage.isAfterAllRegions())
18     return "after all regions";
19   return "before region #" + std::to_string(stage.getNextRegion());
20 }
21 
22 namespace {
23 /// This pass exercises generic visitor with void callbacks and prints the order
24 /// and stage in which operations are visited.
25 struct TestGenericIRVisitorPass
26     : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
28 
29   StringRef getArgument() const final { return "test-generic-ir-visitors"; }
30   StringRef getDescription() const final { return "Test generic IR visitors."; }
31   void runOnOperation() override {
32     Operation *outerOp = getOperation();
33     int stepNo = 0;
34     outerOp->walk([&](Operation *op, const WalkStage &stage) {
35       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
36                    << getStageDescription(stage) << "\n";
37     });
38 
39     // Exercise static inference of operation type.
40     outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
41       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
42                    << getStageDescription(stage) << "\n";
43     });
44   }
45 };
46 
47 /// This pass exercises the generic visitor with non-void callbacks and prints
48 /// the order and stage in which operations are visited. It will interrupt the
49 /// walk based on attributes peesent in the IR.
50 struct TestGenericIRVisitorInterruptPass
51     : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
52   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
53       TestGenericIRVisitorInterruptPass)
54 
55   StringRef getArgument() const final {
56     return "test-generic-ir-visitors-interrupt";
57   }
58   StringRef getDescription() const final {
59     return "Test generic IR visitors with interrupts.";
60   }
61   void runOnOperation() override {
62     Operation *outerOp = getOperation();
63     int stepNo = 0;
64 
65     auto walker = [&](Operation *op, const WalkStage &stage) {
66       if (auto interruptBeforeAall =
67               op->getAttrOfType<BoolAttr>("interrupt_before_all"))
68         if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
69           return WalkResult::interrupt();
70 
71       if (auto interruptAfterAll =
72               op->getAttrOfType<BoolAttr>("interrupt_after_all"))
73         if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
74           return WalkResult::interrupt();
75 
76       if (auto interruptAfterRegion =
77               op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
78         if (stage.isAfterRegion(
79                 static_cast<int>(interruptAfterRegion.getInt())))
80           return WalkResult::interrupt();
81 
82       if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
83         if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
84           return WalkResult::skip();
85 
86       if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
87         if (skipAfterAll.getValue() && stage.isAfterAllRegions())
88           return WalkResult::skip();
89 
90       if (auto skipAfterRegion =
91               op->getAttrOfType<IntegerAttr>("skip_after_region"))
92         if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
93           return WalkResult::skip();
94 
95       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
96                    << getStageDescription(stage) << "\n";
97       return WalkResult::advance();
98     };
99 
100     // Interrupt the walk based on attributes on the operation.
101     auto result = outerOp->walk(walker);
102 
103     if (result.wasInterrupted())
104       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
105 
106     // Exercise static inference of operation type.
107     result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
108       return walker(op, stage);
109     });
110 
111     if (result.wasInterrupted())
112       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
113   }
114 };
115 
116 } // namespace
117 
118 namespace mlir {
119 namespace test {
120 void registerTestGenericIRVisitorsPass() {
121   PassRegistration<TestGenericIRVisitorPass>();
122   PassRegistration<TestGenericIRVisitorInterruptPass>();
123 }
124 
125 } // namespace test
126 } // namespace mlir
127