1 //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Pass/Pass.h"
12
13 using namespace mlir;
14
15 namespace {
16 /// This analysis state represents an integer that is XOR'd with other states.
17 class FooState : public AnalysisState {
18 public:
19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
20
21 using AnalysisState::AnalysisState;
22
23 /// Default-initialize the state to zero.
defaultInitialize()24 ChangeResult defaultInitialize() override { return join(0); }
25
26 /// Returns true if the state is uninitialized.
isUninitialized() const27 bool isUninitialized() const override { return !state; }
28
29 /// Print the integer value or "none" if uninitialized.
print(raw_ostream & os) const30 void print(raw_ostream &os) const override {
31 if (state)
32 os << *state;
33 else
34 os << "none";
35 }
36
37 /// Join the state with another. If either is unintialized, take the
38 /// initialized value. Otherwise, XOR the integer values.
join(const FooState & rhs)39 ChangeResult join(const FooState &rhs) {
40 if (rhs.isUninitialized())
41 return ChangeResult::NoChange;
42 return join(*rhs.state);
43 }
join(uint64_t value)44 ChangeResult join(uint64_t value) {
45 if (isUninitialized()) {
46 state = value;
47 return ChangeResult::Change;
48 }
49 uint64_t before = *state;
50 state = before ^ value;
51 return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
52 }
53
54 /// Set the value of the state directly.
set(const FooState & rhs)55 ChangeResult set(const FooState &rhs) {
56 if (state == rhs.state)
57 return ChangeResult::NoChange;
58 state = rhs.state;
59 return ChangeResult::Change;
60 }
61
62 /// Returns the integer value of the state.
getValue() const63 uint64_t getValue() const { return *state; }
64
65 private:
66 /// An optional integer value.
67 Optional<uint64_t> state;
68 };
69
70 /// This analysis computes `FooState` across operations and control-flow edges.
71 /// If an op specifies a `foo` integer attribute, the contained value is XOR'd
72 /// with the value before the operation.
73 class FooAnalysis : public DataFlowAnalysis {
74 public:
75 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
76
77 using DataFlowAnalysis::DataFlowAnalysis;
78
79 LogicalResult initialize(Operation *top) override;
80 LogicalResult visit(ProgramPoint point) override;
81
82 private:
83 void visitBlock(Block *block);
84 void visitOperation(Operation *op);
85 };
86
87 struct TestFooAnalysisPass
88 : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon09f8d1fb0111::TestFooAnalysisPass89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
90
91 StringRef getArgument() const override { return "test-foo-analysis"; }
92
93 void runOnOperation() override;
94 };
95 } // namespace
96
initialize(Operation * top)97 LogicalResult FooAnalysis::initialize(Operation *top) {
98 if (top->getNumRegions() != 1)
99 return top->emitError("expected a single region top-level op");
100
101 // Initialize the top-level state.
102 getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
103
104 // Visit all nested blocks and operations.
105 for (Block &block : top->getRegion(0)) {
106 visitBlock(&block);
107 for (Operation &op : block) {
108 if (op.getNumRegions())
109 return op.emitError("unexpected op with regions");
110 visitOperation(&op);
111 }
112 }
113 return success();
114 }
115
visit(ProgramPoint point)116 LogicalResult FooAnalysis::visit(ProgramPoint point) {
117 if (auto *op = point.dyn_cast<Operation *>()) {
118 visitOperation(op);
119 return success();
120 }
121 if (auto *block = point.dyn_cast<Block *>()) {
122 visitBlock(block);
123 return success();
124 }
125 return emitError(point.getLoc(), "unknown point kind");
126 }
127
visitBlock(Block * block)128 void FooAnalysis::visitBlock(Block *block) {
129 if (block->isEntryBlock()) {
130 // This is the initial state. Let the framework default-initialize it.
131 return;
132 }
133 FooState *state = getOrCreate<FooState>(block);
134 ChangeResult result = ChangeResult::NoChange;
135 for (Block *pred : block->getPredecessors()) {
136 // Join the state at the terminators of all predecessors.
137 const FooState *predState =
138 getOrCreateFor<FooState>(block, pred->getTerminator());
139 result |= state->join(*predState);
140 }
141 propagateIfChanged(state, result);
142 }
143
visitOperation(Operation * op)144 void FooAnalysis::visitOperation(Operation *op) {
145 FooState *state = getOrCreate<FooState>(op);
146 ChangeResult result = ChangeResult::NoChange;
147
148 // Copy the state across the operation.
149 const FooState *prevState;
150 if (Operation *prev = op->getPrevNode())
151 prevState = getOrCreateFor<FooState>(op, prev);
152 else
153 prevState = getOrCreateFor<FooState>(op, op->getBlock());
154 result |= state->set(*prevState);
155
156 // Modify the state with the attribute, if specified.
157 if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
158 uint64_t value = attr.getUInt();
159 result |= state->join(value);
160 }
161 propagateIfChanged(state, result);
162 }
163
runOnOperation()164 void TestFooAnalysisPass::runOnOperation() {
165 func::FuncOp func = getOperation();
166 DataFlowSolver solver;
167 solver.load<FooAnalysis>();
168 if (failed(solver.initializeAndRun(func)))
169 return signalPassFailure();
170
171 raw_ostream &os = llvm::errs();
172 os << "function: @" << func.getSymName() << "\n";
173
174 func.walk([&](Operation *op) {
175 auto tag = op->getAttrOfType<StringAttr>("tag");
176 if (!tag)
177 return;
178 const FooState *state = solver.lookupState<FooState>(op);
179 assert(state && !state->isUninitialized());
180 os << tag.getValue() << " -> " << state->getValue() << "\n";
181 });
182 }
183
184 namespace mlir {
185 namespace test {
registerTestFooAnalysisPass()186 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
187 } // namespace test
188 } // namespace mlir
189