1 //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlowFramework.h" 10 #include "mlir/Dialect/Func/IR/FuncOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This analysis state represents an integer that is XOR'd with other states. 17 class FooState : public AnalysisState { 18 public: 19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) 20 21 using AnalysisState::AnalysisState; 22 23 /// Default-initialize the state to zero. 24 ChangeResult defaultInitialize() override { return join(0); } 25 26 /// Returns true if the state is uninitialized. 27 bool isUninitialized() const override { return !state; } 28 29 /// Print the integer value or "none" if uninitialized. 30 void print(raw_ostream &os) const override { 31 if (state) 32 os << *state; 33 else 34 os << "none"; 35 } 36 37 /// Join the state with another. If either is unintialized, take the 38 /// initialized value. Otherwise, XOR the integer values. 39 ChangeResult join(const FooState &rhs) { 40 if (rhs.isUninitialized()) 41 return ChangeResult::NoChange; 42 return join(*rhs.state); 43 } 44 ChangeResult join(uint64_t value) { 45 if (isUninitialized()) { 46 state = value; 47 return ChangeResult::Change; 48 } 49 uint64_t before = *state; 50 state = before ^ value; 51 return before == *state ? ChangeResult::NoChange : ChangeResult::Change; 52 } 53 54 /// Set the value of the state directly. 55 ChangeResult set(const FooState &rhs) { 56 if (state == rhs.state) 57 return ChangeResult::NoChange; 58 state = rhs.state; 59 return ChangeResult::Change; 60 } 61 62 /// Returns the integer value of the state. 63 uint64_t getValue() const { return *state; } 64 65 private: 66 /// An optional integer value. 67 Optional<uint64_t> state; 68 }; 69 70 /// This analysis computes `FooState` across operations and control-flow edges. 71 /// If an op specifies a `foo` integer attribute, the contained value is XOR'd 72 /// with the value before the operation. 73 class FooAnalysis : public DataFlowAnalysis { 74 public: 75 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) 76 77 using DataFlowAnalysis::DataFlowAnalysis; 78 79 LogicalResult initialize(Operation *top) override; 80 LogicalResult visit(ProgramPoint point) override; 81 82 private: 83 void visitBlock(Block *block); 84 void visitOperation(Operation *op); 85 }; 86 87 struct TestFooAnalysisPass 88 : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> { 89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) 90 91 StringRef getArgument() const override { return "test-foo-analysis"; } 92 93 void runOnOperation() override; 94 }; 95 } // namespace 96 97 LogicalResult FooAnalysis::initialize(Operation *top) { 98 if (top->getNumRegions() != 1) 99 return top->emitError("expected a single region top-level op"); 100 101 // Initialize the top-level state. 102 getOrCreate<FooState>(&top->getRegion(0).front())->join(0); 103 104 // Visit all nested blocks and operations. 105 for (Block &block : top->getRegion(0)) { 106 visitBlock(&block); 107 for (Operation &op : block) { 108 if (op.getNumRegions()) 109 return op.emitError("unexpected op with regions"); 110 visitOperation(&op); 111 } 112 } 113 return success(); 114 } 115 116 LogicalResult FooAnalysis::visit(ProgramPoint point) { 117 if (auto *op = point.dyn_cast<Operation *>()) { 118 visitOperation(op); 119 return success(); 120 } 121 if (auto *block = point.dyn_cast<Block *>()) { 122 visitBlock(block); 123 return success(); 124 } 125 return emitError(point.getLoc(), "unknown point kind"); 126 } 127 128 void FooAnalysis::visitBlock(Block *block) { 129 if (block->isEntryBlock()) { 130 // This is the initial state. Let the framework default-initialize it. 131 return; 132 } 133 FooState *state = getOrCreate<FooState>(block); 134 ChangeResult result = ChangeResult::NoChange; 135 for (Block *pred : block->getPredecessors()) { 136 // Join the state at the terminators of all predecessors. 137 const FooState *predState = 138 getOrCreateFor<FooState>(block, pred->getTerminator()); 139 result |= state->join(*predState); 140 } 141 propagateIfChanged(state, result); 142 } 143 144 void FooAnalysis::visitOperation(Operation *op) { 145 FooState *state = getOrCreate<FooState>(op); 146 ChangeResult result = ChangeResult::NoChange; 147 148 // Copy the state across the operation. 149 const FooState *prevState; 150 if (Operation *prev = op->getPrevNode()) 151 prevState = getOrCreateFor<FooState>(op, prev); 152 else 153 prevState = getOrCreateFor<FooState>(op, op->getBlock()); 154 result |= state->set(*prevState); 155 156 // Modify the state with the attribute, if specified. 157 if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) { 158 uint64_t value = attr.getUInt(); 159 result |= state->join(value); 160 } 161 propagateIfChanged(state, result); 162 } 163 164 void TestFooAnalysisPass::runOnOperation() { 165 func::FuncOp func = getOperation(); 166 DataFlowSolver solver; 167 solver.load<FooAnalysis>(); 168 if (failed(solver.initializeAndRun(func))) 169 return signalPassFailure(); 170 171 raw_ostream &os = llvm::errs(); 172 os << "function: @" << func.getSymName() << "\n"; 173 174 func.walk([&](Operation *op) { 175 auto tag = op->getAttrOfType<StringAttr>("tag"); 176 if (!tag) 177 return; 178 const FooState *state = solver.lookupState<FooState>(op); 179 assert(state && !state->isUninitialized()); 180 os << tag.getValue() << " -> " << state->getValue() << "\n"; 181 }); 182 } 183 184 namespace mlir { 185 namespace test { 186 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } 187 } // namespace test 188 } // namespace mlir 189