1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/Pass.h"
10 
11 using namespace mlir;
12 
13 static void printRegion(Region *region) {
14   llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
15                << region->getParentOp()->getName() << "'";
16 }
17 
18 static void printBlock(Block *block) {
19   llvm::outs() << "block ";
20   block->printAsOperand(llvm::outs(), /*printType=*/false);
21   llvm::outs() << " from ";
22   printRegion(block->getParent());
23 }
24 
25 static void printOperation(Operation *op) {
26   llvm::outs() << "op '" << op->getName() << "'";
27 }
28 
29 /// Tests pure callbacks.
30 static void testPureCallbacks(Operation *op) {
31   auto opPure = [](Operation *op) {
32     llvm::outs() << "Visiting ";
33     printOperation(op);
34     llvm::outs() << "\n";
35   };
36   auto blockPure = [](Block *block) {
37     llvm::outs() << "Visiting ";
38     printBlock(block);
39     llvm::outs() << "\n";
40   };
41   auto regionPure = [](Region *region) {
42     llvm::outs() << "Visiting ";
43     printRegion(region);
44     llvm::outs() << "\n";
45   };
46 
47   llvm::outs() << "Op pre-order visits"
48                << "\n";
49   op->walk<WalkOrder::PreOrder>(opPure);
50   llvm::outs() << "Block pre-order visits"
51                << "\n";
52   op->walk<WalkOrder::PreOrder>(blockPure);
53   llvm::outs() << "Region pre-order visits"
54                << "\n";
55   op->walk<WalkOrder::PreOrder>(regionPure);
56 
57   llvm::outs() << "Op post-order visits"
58                << "\n";
59   op->walk<WalkOrder::PostOrder>(opPure);
60   llvm::outs() << "Block post-order visits"
61                << "\n";
62   op->walk<WalkOrder::PostOrder>(blockPure);
63   llvm::outs() << "Region post-order visits"
64                << "\n";
65   op->walk<WalkOrder::PostOrder>(regionPure);
66 }
67 
68 /// Tests erasure callbacks that skip the walk.
69 static void testSkipErasureCallbacks(Operation *op) {
70   auto skipOpErasure = [](Operation *op) {
71     // Do not erase module and function op. Otherwise there wouldn't be too
72     // much to test in pre-order.
73     if (isa<ModuleOp>(op) || isa<FuncOp>(op))
74       return WalkResult::advance();
75 
76     llvm::outs() << "Erasing ";
77     printOperation(op);
78     llvm::outs() << "\n";
79     op->dropAllUses();
80     op->erase();
81     return WalkResult::skip();
82   };
83   auto skipBlockErasure = [](Block *block) {
84     // Do not erase module and function blocks. Otherwise there wouldn't be
85     // too much to test in pre-order.
86     Operation *parentOp = block->getParentOp();
87     if (isa<ModuleOp>(parentOp) || isa<FuncOp>(parentOp))
88       return WalkResult::advance();
89 
90     llvm::outs() << "Erasing ";
91     printBlock(block);
92     llvm::outs() << "\n";
93     block->erase();
94     return WalkResult::skip();
95   };
96 
97   llvm::outs() << "Op pre-order erasures (skip)"
98                << "\n";
99   Operation *cloned = op->clone();
100   cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
101   cloned->erase();
102 
103   llvm::outs() << "Block pre-order erasures (skip)"
104                << "\n";
105   cloned = op->clone();
106   cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
107   cloned->erase();
108 
109   llvm::outs() << "Op post-order erasures (skip)"
110                << "\n";
111   cloned = op->clone();
112   cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
113   cloned->erase();
114 
115   llvm::outs() << "Block post-order erasures (skip)"
116                << "\n";
117   cloned = op->clone();
118   cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
119   cloned->erase();
120 }
121 
122 /// Tests callbacks that erase the op or block but don't return 'Skip'. This
123 /// callbacks are only valid in post-order.
124 static void testNoSkipErasureCallbacks(Operation *op) {
125   auto noSkipOpErasure = [](Operation *op) {
126     llvm::outs() << "Erasing ";
127     printOperation(op);
128     llvm::outs() << "\n";
129     op->dropAllUses();
130     op->erase();
131   };
132   auto noSkipBlockErasure = [](Block *block) {
133     llvm::outs() << "Erasing ";
134     printBlock(block);
135     llvm::outs() << "\n";
136     block->erase();
137   };
138 
139   llvm::outs() << "Op post-order erasures (no skip)"
140                << "\n";
141   Operation *cloned = op->clone();
142   cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);
143 
144   llvm::outs() << "Block post-order erasures (no skip)"
145                << "\n";
146   cloned = op->clone();
147   cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
148   cloned->erase();
149 }
150 
151 namespace {
152 /// This pass exercises the different configurations of the IR visitors.
153 struct TestIRVisitorsPass
154     : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
155   StringRef getArgument() const final { return "test-ir-visitors"; }
156   StringRef getDescription() const final { return "Test various visitors."; }
157   void runOnOperation() override {
158     Operation *op = getOperation();
159     testPureCallbacks(op);
160     testSkipErasureCallbacks(op);
161     testNoSkipErasureCallbacks(op);
162   }
163 };
164 } // end anonymous namespace
165 
166 namespace mlir {
167 namespace test {
168 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
169 } // namespace test
170 } // namespace mlir
171