1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 static std::string getStageDescription(const WalkStage &stage) { 15 if (stage.isBeforeAllRegions()) 16 return "before all regions"; 17 if (stage.isAfterAllRegions()) 18 return "after all regions"; 19 return "before region #" + std::to_string(stage.getNextRegion()); 20 } 21 22 namespace { 23 /// This pass exercises generic visitor with void callbacks and prints the order 24 /// and stage in which operations are visited. 25 struct TestGenericIRVisitorPass 26 : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> { 27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass) 28 29 StringRef getArgument() const final { return "test-generic-ir-visitors"; } 30 StringRef getDescription() const final { return "Test generic IR visitors."; } 31 void runOnOperation() override { 32 Operation *outerOp = getOperation(); 33 int stepNo = 0; 34 outerOp->walk([&](Operation *op, const WalkStage &stage) { 35 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 36 << getStageDescription(stage) << "\n"; 37 }); 38 39 // Exercise static inference of operation type. 40 outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 41 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 42 << getStageDescription(stage) << "\n"; 43 }); 44 } 45 }; 46 47 /// This pass exercises the generic visitor with non-void callbacks and prints 48 /// the order and stage in which operations are visited. It will interrupt the 49 /// walk based on attributes peesent in the IR. 50 struct TestGenericIRVisitorInterruptPass 51 : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> { 52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 53 TestGenericIRVisitorInterruptPass) 54 55 StringRef getArgument() const final { 56 return "test-generic-ir-visitors-interrupt"; 57 } 58 StringRef getDescription() const final { 59 return "Test generic IR visitors with interrupts."; 60 } 61 void runOnOperation() override { 62 Operation *outerOp = getOperation(); 63 int stepNo = 0; 64 65 auto walker = [&](Operation *op, const WalkStage &stage) { 66 if (auto interruptBeforeAall = 67 op->getAttrOfType<BoolAttr>("interrupt_before_all")) 68 if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) 69 return WalkResult::interrupt(); 70 71 if (auto interruptAfterAll = 72 op->getAttrOfType<BoolAttr>("interrupt_after_all")) 73 if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) 74 return WalkResult::interrupt(); 75 76 if (auto interruptAfterRegion = 77 op->getAttrOfType<IntegerAttr>("interrupt_after_region")) 78 if (stage.isAfterRegion( 79 static_cast<int>(interruptAfterRegion.getInt()))) 80 return WalkResult::interrupt(); 81 82 if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all")) 83 if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) 84 return WalkResult::skip(); 85 86 if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all")) 87 if (skipAfterAll.getValue() && stage.isAfterAllRegions()) 88 return WalkResult::skip(); 89 90 if (auto skipAfterRegion = 91 op->getAttrOfType<IntegerAttr>("skip_after_region")) 92 if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt()))) 93 return WalkResult::skip(); 94 95 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 96 << getStageDescription(stage) << "\n"; 97 return WalkResult::advance(); 98 }; 99 100 // Interrupt the walk based on attributes on the operation. 101 auto result = outerOp->walk(walker); 102 103 if (result.wasInterrupted()) 104 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 105 106 // Exercise static inference of operation type. 107 result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 108 return walker(op, stage); 109 }); 110 111 if (result.wasInterrupted()) 112 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 113 } 114 }; 115 116 } // namespace 117 118 namespace mlir { 119 namespace test { 120 void registerTestGenericIRVisitorsPass() { 121 PassRegistration<TestGenericIRVisitorPass>(); 122 PassRegistration<TestGenericIRVisitorInterruptPass>(); 123 } 124 125 } // namespace test 126 } // namespace mlir 127