1 //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 static std::string getStageDescription(const WalkStage &stage) { 15 if (stage.isBeforeAllRegions()) 16 return "before all regions"; 17 if (stage.isAfterAllRegions()) 18 return "after all regions"; 19 return "before region #" + std::to_string(stage.getNextRegion()); 20 } 21 22 namespace { 23 /// This pass exercises generic visitor with void callbacks and prints the order 24 /// and stage in which operations are visited. 25 class TestGenericIRVisitorPass 26 : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> { 27 public: 28 StringRef getArgument() const final { return "test-generic-ir-visitors"; } 29 StringRef getDescription() const final { return "Test generic IR visitors."; } 30 void runOnOperation() override { 31 Operation *outerOp = getOperation(); 32 int stepNo = 0; 33 outerOp->walk([&](Operation *op, const WalkStage &stage) { 34 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 35 << getStageDescription(stage) << "\n"; 36 }); 37 38 // Exercise static inference of operation type. 39 outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 40 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 41 << getStageDescription(stage) << "\n"; 42 }); 43 } 44 }; 45 46 /// This pass exercises the generic visitor with non-void callbacks and prints 47 /// the order and stage in which operations are visited. It will interrupt the 48 /// walk based on attributes peesent in the IR. 49 class TestGenericIRVisitorInterruptPass 50 : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> { 51 public: 52 StringRef getArgument() const final { 53 return "test-generic-ir-visitors-interrupt"; 54 } 55 StringRef getDescription() const final { 56 return "Test generic IR visitors with interrupts."; 57 } 58 void runOnOperation() override { 59 Operation *outerOp = getOperation(); 60 int stepNo = 0; 61 62 auto walker = [&](Operation *op, const WalkStage &stage) { 63 if (auto interruptBeforeAall = 64 op->getAttrOfType<BoolAttr>("interrupt_before_all")) 65 if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) 66 return WalkResult::interrupt(); 67 68 if (auto interruptAfterAll = 69 op->getAttrOfType<BoolAttr>("interrupt_after_all")) 70 if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) 71 return WalkResult::interrupt(); 72 73 if (auto interruptAfterRegion = 74 op->getAttrOfType<IntegerAttr>("interrupt_after_region")) 75 if (stage.isAfterRegion( 76 static_cast<int>(interruptAfterRegion.getInt()))) 77 return WalkResult::interrupt(); 78 79 if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all")) 80 if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) 81 return WalkResult::skip(); 82 83 if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all")) 84 if (skipAfterAll.getValue() && stage.isAfterAllRegions()) 85 return WalkResult::skip(); 86 87 if (auto skipAfterRegion = 88 op->getAttrOfType<IntegerAttr>("skip_after_region")) 89 if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt()))) 90 return WalkResult::skip(); 91 92 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " 93 << getStageDescription(stage) << "\n"; 94 return WalkResult::advance(); 95 }; 96 97 // Interrupt the walk based on attributes on the operation. 98 auto result = outerOp->walk(walker); 99 100 if (result.wasInterrupted()) 101 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 102 103 // Exercise static inference of operation type. 104 result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { 105 return walker(op, stage); 106 }); 107 108 if (result.wasInterrupted()) 109 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; 110 } 111 }; 112 113 } // namespace 114 115 namespace mlir { 116 namespace test { 117 void registerTestGenericIRVisitorsPass() { 118 PassRegistration<TestGenericIRVisitorPass>(); 119 PassRegistration<TestGenericIRVisitorInterruptPass>(); 120 } 121 122 } // namespace test 123 } // namespace mlir 124