1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinOps.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 static void printRegion(Region *region) { 15 llvm::outs() << "region " << region->getRegionNumber() << " from operation '" 16 << region->getParentOp()->getName() << "'"; 17 } 18 19 static void printBlock(Block *block) { 20 llvm::outs() << "block "; 21 block->printAsOperand(llvm::outs(), /*printType=*/false); 22 llvm::outs() << " from "; 23 printRegion(block->getParent()); 24 } 25 26 static void printOperation(Operation *op) { 27 llvm::outs() << "op '" << op->getName() << "'"; 28 } 29 30 /// Tests pure callbacks. 31 static void testPureCallbacks(Operation *op) { 32 auto opPure = [](Operation *op) { 33 llvm::outs() << "Visiting "; 34 printOperation(op); 35 llvm::outs() << "\n"; 36 }; 37 auto blockPure = [](Block *block) { 38 llvm::outs() << "Visiting "; 39 printBlock(block); 40 llvm::outs() << "\n"; 41 }; 42 auto regionPure = [](Region *region) { 43 llvm::outs() << "Visiting "; 44 printRegion(region); 45 llvm::outs() << "\n"; 46 }; 47 48 llvm::outs() << "Op pre-order visits" 49 << "\n"; 50 op->walk<WalkOrder::PreOrder>(opPure); 51 llvm::outs() << "Block pre-order visits" 52 << "\n"; 53 op->walk<WalkOrder::PreOrder>(blockPure); 54 llvm::outs() << "Region pre-order visits" 55 << "\n"; 56 op->walk<WalkOrder::PreOrder>(regionPure); 57 58 llvm::outs() << "Op post-order visits" 59 << "\n"; 60 op->walk<WalkOrder::PostOrder>(opPure); 61 llvm::outs() << "Block post-order visits" 62 << "\n"; 63 op->walk<WalkOrder::PostOrder>(blockPure); 64 llvm::outs() << "Region post-order visits" 65 << "\n"; 66 op->walk<WalkOrder::PostOrder>(regionPure); 67 } 68 69 /// Tests erasure callbacks that skip the walk. 70 static void testSkipErasureCallbacks(Operation *op) { 71 auto skipOpErasure = [](Operation *op) { 72 // Do not erase module and module children operations. Otherwise, there 73 // wouldn't be too much to test in pre-order. 74 if (isa<ModuleOp>(op) || isa<ModuleOp>(op->getParentOp())) 75 return WalkResult::advance(); 76 77 llvm::outs() << "Erasing "; 78 printOperation(op); 79 llvm::outs() << "\n"; 80 op->dropAllUses(); 81 op->erase(); 82 return WalkResult::skip(); 83 }; 84 auto skipBlockErasure = [](Block *block) { 85 // Do not erase module and module children blocks. Otherwise there wouldn't 86 // be too much to test in pre-order. 87 Operation *parentOp = block->getParentOp(); 88 if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp())) 89 return WalkResult::advance(); 90 91 llvm::outs() << "Erasing "; 92 printBlock(block); 93 llvm::outs() << "\n"; 94 block->erase(); 95 return WalkResult::skip(); 96 }; 97 98 llvm::outs() << "Op pre-order erasures (skip)" 99 << "\n"; 100 Operation *cloned = op->clone(); 101 cloned->walk<WalkOrder::PreOrder>(skipOpErasure); 102 cloned->erase(); 103 104 llvm::outs() << "Block pre-order erasures (skip)" 105 << "\n"; 106 cloned = op->clone(); 107 cloned->walk<WalkOrder::PreOrder>(skipBlockErasure); 108 cloned->erase(); 109 110 llvm::outs() << "Op post-order erasures (skip)" 111 << "\n"; 112 cloned = op->clone(); 113 cloned->walk<WalkOrder::PostOrder>(skipOpErasure); 114 cloned->erase(); 115 116 llvm::outs() << "Block post-order erasures (skip)" 117 << "\n"; 118 cloned = op->clone(); 119 cloned->walk<WalkOrder::PostOrder>(skipBlockErasure); 120 cloned->erase(); 121 } 122 123 /// Tests callbacks that erase the op or block but don't return 'Skip'. This 124 /// callbacks are only valid in post-order. 125 static void testNoSkipErasureCallbacks(Operation *op) { 126 auto noSkipOpErasure = [](Operation *op) { 127 llvm::outs() << "Erasing "; 128 printOperation(op); 129 llvm::outs() << "\n"; 130 op->dropAllUses(); 131 op->erase(); 132 }; 133 auto noSkipBlockErasure = [](Block *block) { 134 llvm::outs() << "Erasing "; 135 printBlock(block); 136 llvm::outs() << "\n"; 137 block->erase(); 138 }; 139 140 llvm::outs() << "Op post-order erasures (no skip)" 141 << "\n"; 142 Operation *cloned = op->clone(); 143 cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure); 144 145 llvm::outs() << "Block post-order erasures (no skip)" 146 << "\n"; 147 cloned = op->clone(); 148 cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure); 149 cloned->erase(); 150 } 151 152 namespace { 153 /// This pass exercises the different configurations of the IR visitors. 154 struct TestIRVisitorsPass 155 : public PassWrapper<TestIRVisitorsPass, OperationPass<>> { 156 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass) 157 158 StringRef getArgument() const final { return "test-ir-visitors"; } 159 StringRef getDescription() const final { return "Test various visitors."; } 160 void runOnOperation() override { 161 Operation *op = getOperation(); 162 testPureCallbacks(op); 163 testSkipErasureCallbacks(op); 164 testNoSkipErasureCallbacks(op); 165 } 166 }; 167 } // namespace 168 169 namespace mlir { 170 namespace test { 171 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); } 172 } // namespace test 173 } // namespace mlir 174