1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 static void printRegion(Region *region) {
15   llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
16                << region->getParentOp()->getName() << "'";
17 }
18 
19 static void printBlock(Block *block) {
20   llvm::outs() << "block ";
21   block->printAsOperand(llvm::outs(), /*printType=*/false);
22   llvm::outs() << " from ";
23   printRegion(block->getParent());
24 }
25 
26 static void printOperation(Operation *op) {
27   llvm::outs() << "op '" << op->getName() << "'";
28 }
29 
30 /// Tests pure callbacks.
31 static void testPureCallbacks(Operation *op) {
32   auto opPure = [](Operation *op) {
33     llvm::outs() << "Visiting ";
34     printOperation(op);
35     llvm::outs() << "\n";
36   };
37   auto blockPure = [](Block *block) {
38     llvm::outs() << "Visiting ";
39     printBlock(block);
40     llvm::outs() << "\n";
41   };
42   auto regionPure = [](Region *region) {
43     llvm::outs() << "Visiting ";
44     printRegion(region);
45     llvm::outs() << "\n";
46   };
47 
48   llvm::outs() << "Op pre-order visits"
49                << "\n";
50   op->walk<WalkOrder::PreOrder>(opPure);
51   llvm::outs() << "Block pre-order visits"
52                << "\n";
53   op->walk<WalkOrder::PreOrder>(blockPure);
54   llvm::outs() << "Region pre-order visits"
55                << "\n";
56   op->walk<WalkOrder::PreOrder>(regionPure);
57 
58   llvm::outs() << "Op post-order visits"
59                << "\n";
60   op->walk<WalkOrder::PostOrder>(opPure);
61   llvm::outs() << "Block post-order visits"
62                << "\n";
63   op->walk<WalkOrder::PostOrder>(blockPure);
64   llvm::outs() << "Region post-order visits"
65                << "\n";
66   op->walk<WalkOrder::PostOrder>(regionPure);
67 }
68 
69 /// Tests erasure callbacks that skip the walk.
70 static void testSkipErasureCallbacks(Operation *op) {
71   auto skipOpErasure = [](Operation *op) {
72     // Do not erase module and module children operations. Otherwise, there
73     // wouldn't be too much to test in pre-order.
74     if (isa<ModuleOp>(op) || isa<ModuleOp>(op->getParentOp()))
75       return WalkResult::advance();
76 
77     llvm::outs() << "Erasing ";
78     printOperation(op);
79     llvm::outs() << "\n";
80     op->dropAllUses();
81     op->erase();
82     return WalkResult::skip();
83   };
84   auto skipBlockErasure = [](Block *block) {
85     // Do not erase module and module children blocks. Otherwise there wouldn't
86     // be too much to test in pre-order.
87     Operation *parentOp = block->getParentOp();
88     if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
89       return WalkResult::advance();
90 
91     llvm::outs() << "Erasing ";
92     printBlock(block);
93     llvm::outs() << "\n";
94     block->erase();
95     return WalkResult::skip();
96   };
97 
98   llvm::outs() << "Op pre-order erasures (skip)"
99                << "\n";
100   Operation *cloned = op->clone();
101   cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
102   cloned->erase();
103 
104   llvm::outs() << "Block pre-order erasures (skip)"
105                << "\n";
106   cloned = op->clone();
107   cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
108   cloned->erase();
109 
110   llvm::outs() << "Op post-order erasures (skip)"
111                << "\n";
112   cloned = op->clone();
113   cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
114   cloned->erase();
115 
116   llvm::outs() << "Block post-order erasures (skip)"
117                << "\n";
118   cloned = op->clone();
119   cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
120   cloned->erase();
121 }
122 
123 /// Tests callbacks that erase the op or block but don't return 'Skip'. This
124 /// callbacks are only valid in post-order.
125 static void testNoSkipErasureCallbacks(Operation *op) {
126   auto noSkipOpErasure = [](Operation *op) {
127     llvm::outs() << "Erasing ";
128     printOperation(op);
129     llvm::outs() << "\n";
130     op->dropAllUses();
131     op->erase();
132   };
133   auto noSkipBlockErasure = [](Block *block) {
134     llvm::outs() << "Erasing ";
135     printBlock(block);
136     llvm::outs() << "\n";
137     block->erase();
138   };
139 
140   llvm::outs() << "Op post-order erasures (no skip)"
141                << "\n";
142   Operation *cloned = op->clone();
143   cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);
144 
145   llvm::outs() << "Block post-order erasures (no skip)"
146                << "\n";
147   cloned = op->clone();
148   cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
149   cloned->erase();
150 }
151 
152 namespace {
153 /// This pass exercises the different configurations of the IR visitors.
154 struct TestIRVisitorsPass
155     : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
156   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass)
157 
158   StringRef getArgument() const final { return "test-ir-visitors"; }
159   StringRef getDescription() const final { return "Test various visitors."; }
160   void runOnOperation() override {
161     Operation *op = getOperation();
162     testPureCallbacks(op);
163     testSkipErasureCallbacks(op);
164     testNoSkipErasureCallbacks(op);
165   }
166 };
167 } // namespace
168 
169 namespace mlir {
170 namespace test {
171 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
172 } // namespace test
173 } // namespace mlir
174