1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 static std::string getStageDescription(const WalkStage &stage) { 15 if (stage.isBeforeAllRegions()) 16 return "before all regions"; 17 if (stage.isAfterAllRegions()) 18 return "after all regions"; 19 return "before region #" + std::to_string(stage.getNextRegion()); 20 } 21 22 namespace { 23 /// This pass exercises generic visitor with void callbacks and prints the order 24 /// and stage in which operations are visited. 25 struct TestGenericIRVisitorPass 26 : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> { 27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass) 28 29 StringRef getArgument() const final { return "test-generic-ir-visitors"; } 30 StringRef getDescription() const final { return "Test generic IR visitors."; } 31 void runOnOperation() override { 32 Operation *outerOp = getOperation(); 33 int stepNo = 0; 34 outerOp->walk([&](Operation *op, const WalkStage &stage) { 35 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 36 << getStageDescription(stage) << "\n"; 37 }); 38 39 // Exercise static inference of operation type. 40 outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 41 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 42 << getStageDescription(stage) << "\n"; 43 }); 44 } 45 }; 46 47 /// This pass exercises the generic visitor with non-void callbacks and prints 48 /// the order and stage in which operations are visited. It will interrupt the 49 /// walk based on attributes peesent in the IR. 50 struct TestGenericIRVisitorInterruptPass 51 : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> { 52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 53 TestGenericIRVisitorInterruptPass) 54 55 StringRef getArgument() const final { 56 return "test-generic-ir-visitors-interrupt"; 57 } 58 StringRef getDescription() const final { 59 return "Test generic IR visitors with interrupts."; 60 } 61 void runOnOperation() override { 62 Operation *outerOp = getOperation(); 63 int stepNo = 0; 64 65 auto walker = [&](Operation *op, const WalkStage &stage) { 66 if (auto interruptBeforeAall = 67 op->getAttrOfType<BoolAttr>("interrupt_before_all")) 68 if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) 69 return WalkResult::interrupt(); 70 71 if (auto interruptAfterAll = 72 op->getAttrOfType<BoolAttr>("interrupt_after_all")) 73 if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) 74 return WalkResult::interrupt(); 75 76 if (auto interruptAfterRegion = 77 op->getAttrOfType<IntegerAttr>("interrupt_after_region")) 78 if (stage.isAfterRegion( 79 static_cast<int>(interruptAfterRegion.getInt()))) 80 return WalkResult::interrupt(); 81 82 if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all")) 83 if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) 84 return WalkResult::skip(); 85 86 if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all")) 87 if (skipAfterAll.getValue() && stage.isAfterAllRegions()) 88 return WalkResult::skip(); 89 90 if (auto skipAfterRegion = 91 op->getAttrOfType<IntegerAttr>("skip_after_region")) 92 if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt()))) 93 return WalkResult::skip(); 94 95 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 96 << getStageDescription(stage) << "\n"; 97 return WalkResult::advance(); 98 }; 99 100 // Interrupt the walk based on attributes on the operation. 101 auto result = outerOp->walk(walker); 102 103 if (result.wasInterrupted()) 104 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 105 106 // Exercise static inference of operation type. 107 result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 108 return walker(op, stage); 109 }); 110 111 if (result.wasInterrupted()) 112 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 113 } 114 }; 115 116 struct TestGenericIRBlockVisitorInterruptPass 117 : public PassWrapper<TestGenericIRBlockVisitorInterruptPass, 118 OperationPass<ModuleOp>> { 119 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 120 TestGenericIRBlockVisitorInterruptPass) 121 122 StringRef getArgument() const final { 123 return "test-generic-ir-block-visitors-interrupt"; 124 } 125 StringRef getDescription() const final { 126 return "Test generic IR visitors with interrupts, starting with Blocks."; 127 } 128 129 void runOnOperation() override { 130 int stepNo = 0; 131 132 auto walker = [&](Block *block) { 133 for (Operation &op : *block) 134 if (op.getAttrOfType<BoolAttr>("interrupt")) 135 return WalkResult::interrupt(); 136 137 llvm::outs() << "step " << stepNo++ << "\n"; 138 return WalkResult::advance(); 139 }; 140 141 auto result = getOperation()->walk(walker); 142 if (result.wasInterrupted()) 143 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 144 } 145 }; 146 147 struct TestGenericIRRegionVisitorInterruptPass 148 : public PassWrapper<TestGenericIRRegionVisitorInterruptPass, 149 OperationPass<ModuleOp>> { 150 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 151 TestGenericIRRegionVisitorInterruptPass) 152 153 StringRef getArgument() const final { 154 return "test-generic-ir-region-visitors-interrupt"; 155 } 156 StringRef getDescription() const final { 157 return "Test generic IR visitors with interrupts, starting with Regions."; 158 } 159 160 void runOnOperation() override { 161 int stepNo = 0; 162 163 auto walker = [&](Region *region) { 164 for (Operation &op : region->getOps()) 165 if (op.getAttrOfType<BoolAttr>("interrupt")) 166 return WalkResult::interrupt(); 167 168 llvm::outs() << "step " << stepNo++ << "\n"; 169 return WalkResult::advance(); 170 }; 171 172 auto result = getOperation()->walk(walker); 173 if (result.wasInterrupted()) 174 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 175 } 176 }; 177 178 } // namespace 179 180 namespace mlir { 181 namespace test { 182 void registerTestGenericIRVisitorsPass() { 183 PassRegistration<TestGenericIRVisitorPass>(); 184 PassRegistration<TestGenericIRVisitorInterruptPass>(); 185 PassRegistration<TestGenericIRBlockVisitorInterruptPass>(); 186 PassRegistration<TestGenericIRRegionVisitorInterruptPass>(); 187 } 188 189 } // namespace test 190 } // namespace mlir 191