1 //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/Pass.h" 10 11 using namespace mlir; 12 13 static void printRegion(Region *region) { 14 llvm::outs() << "region " << region->getRegionNumber() << " from operation '" 15 << region->getParentOp()->getName() << "'"; 16 } 17 18 static void printBlock(Block *block) { 19 llvm::outs() << "block "; 20 block->printAsOperand(llvm::outs(), /*printType=*/false); 21 llvm::outs() << " from "; 22 printRegion(block->getParent()); 23 } 24 25 static void printOperation(Operation *op) { 26 llvm::outs() << "op '" << op->getName() << "'"; 27 } 28 29 /// Tests pure callbacks. 30 static void testPureCallbacks(Operation *op) { 31 auto opPure = [](Operation *op) { 32 llvm::outs() << "Visiting "; 33 printOperation(op); 34 llvm::outs() << "\n"; 35 }; 36 auto blockPure = [](Block *block) { 37 llvm::outs() << "Visiting "; 38 printBlock(block); 39 llvm::outs() << "\n"; 40 }; 41 auto regionPure = [](Region *region) { 42 llvm::outs() << "Visiting "; 43 printRegion(region); 44 llvm::outs() << "\n"; 45 }; 46 47 llvm::outs() << "Op pre-order visits" 48 << "\n"; 49 op->walk<WalkOrder::PreOrder>(opPure); 50 llvm::outs() << "Block pre-order visits" 51 << "\n"; 52 op->walk<WalkOrder::PreOrder>(blockPure); 53 llvm::outs() << "Region pre-order visits" 54 << "\n"; 55 op->walk<WalkOrder::PreOrder>(regionPure); 56 57 llvm::outs() << "Op post-order visits" 58 << "\n"; 59 op->walk<WalkOrder::PostOrder>(opPure); 60 llvm::outs() << "Block post-order visits" 61 << "\n"; 62 op->walk<WalkOrder::PostOrder>(blockPure); 63 llvm::outs() << "Region post-order visits" 64 << "\n"; 65 op->walk<WalkOrder::PostOrder>(regionPure); 66 } 67 68 /// Tests erasure callbacks that skip the walk. 69 static void testSkipErasureCallbacks(Operation *op) { 70 auto skipOpErasure = [](Operation *op) { 71 // Do not erase module and function op. Otherwise there wouldn't be too 72 // much to test in pre-order. 73 if (isa<ModuleOp>(op) || isa<FuncOp>(op)) 74 return WalkResult::advance(); 75 76 llvm::outs() << "Erasing "; 77 printOperation(op); 78 llvm::outs() << "\n"; 79 op->dropAllUses(); 80 op->erase(); 81 return WalkResult::skip(); 82 }; 83 auto skipBlockErasure = [](Block *block) { 84 // Do not erase module and function blocks. Otherwise there wouldn't be 85 // too much to test in pre-order. 86 Operation *parentOp = block->getParentOp(); 87 if (isa<ModuleOp>(parentOp) || isa<FuncOp>(parentOp)) 88 return WalkResult::advance(); 89 90 llvm::outs() << "Erasing "; 91 printBlock(block); 92 llvm::outs() << "\n"; 93 block->erase(); 94 return WalkResult::skip(); 95 }; 96 97 llvm::outs() << "Op pre-order erasures (skip)" 98 << "\n"; 99 Operation *cloned = op->clone(); 100 cloned->walk<WalkOrder::PreOrder>(skipOpErasure); 101 cloned->erase(); 102 103 llvm::outs() << "Block pre-order erasures (skip)" 104 << "\n"; 105 cloned = op->clone(); 106 cloned->walk<WalkOrder::PreOrder>(skipBlockErasure); 107 cloned->erase(); 108 109 llvm::outs() << "Op post-order erasures (skip)" 110 << "\n"; 111 cloned = op->clone(); 112 cloned->walk<WalkOrder::PostOrder>(skipOpErasure); 113 cloned->erase(); 114 115 llvm::outs() << "Block post-order erasures (skip)" 116 << "\n"; 117 cloned = op->clone(); 118 cloned->walk<WalkOrder::PostOrder>(skipBlockErasure); 119 cloned->erase(); 120 } 121 122 /// Tests callbacks that erase the op or block but don't return 'Skip'. This 123 /// callbacks are only valid in post-order. 124 static void testNoSkipErasureCallbacks(Operation *op) { 125 auto noSkipOpErasure = [](Operation *op) { 126 llvm::outs() << "Erasing "; 127 printOperation(op); 128 llvm::outs() << "\n"; 129 op->dropAllUses(); 130 op->erase(); 131 }; 132 auto noSkipBlockErasure = [](Block *block) { 133 llvm::outs() << "Erasing "; 134 printBlock(block); 135 llvm::outs() << "\n"; 136 block->erase(); 137 }; 138 139 llvm::outs() << "Op post-order erasures (no skip)" 140 << "\n"; 141 Operation *cloned = op->clone(); 142 cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure); 143 144 llvm::outs() << "Block post-order erasures (no skip)" 145 << "\n"; 146 cloned = op->clone(); 147 cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure); 148 cloned->erase(); 149 } 150 151 namespace { 152 /// This pass exercises the different configurations of the IR visitors. 153 struct TestIRVisitorsPass 154 : public PassWrapper<TestIRVisitorsPass, OperationPass<>> { 155 StringRef getArgument() const final { return "test-ir-visitors"; } 156 StringRef getDescription() const final { return "Test various visitors."; } 157 void runOnOperation() override { 158 Operation *op = getOperation(); 159 testPureCallbacks(op); 160 testSkipErasureCallbacks(op); 161 testNoSkipErasureCallbacks(op); 162 } 163 }; 164 } // end anonymous namespace 165 166 namespace mlir { 167 namespace test { 168 void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); } 169 } // namespace test 170 } // namespace mlir 171