1 //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This analysis state represents an integer that is XOR'd with other states.
17 class FooState : public AnalysisState {
18 public:
19   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
20 
21   using AnalysisState::AnalysisState;
22 
23   /// Default-initialize the state to zero.
defaultInitialize()24   ChangeResult defaultInitialize() override { return join(0); }
25 
26   /// Returns true if the state is uninitialized.
isUninitialized() const27   bool isUninitialized() const override { return !state; }
28 
29   /// Print the integer value or "none" if uninitialized.
print(raw_ostream & os) const30   void print(raw_ostream &os) const override {
31     if (state)
32       os << *state;
33     else
34       os << "none";
35   }
36 
37   /// Join the state with another. If either is unintialized, take the
38   /// initialized value. Otherwise, XOR the integer values.
join(const FooState & rhs)39   ChangeResult join(const FooState &rhs) {
40     if (rhs.isUninitialized())
41       return ChangeResult::NoChange;
42     return join(*rhs.state);
43   }
join(uint64_t value)44   ChangeResult join(uint64_t value) {
45     if (isUninitialized()) {
46       state = value;
47       return ChangeResult::Change;
48     }
49     uint64_t before = *state;
50     state = before ^ value;
51     return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
52   }
53 
54   /// Set the value of the state directly.
set(const FooState & rhs)55   ChangeResult set(const FooState &rhs) {
56     if (state == rhs.state)
57       return ChangeResult::NoChange;
58     state = rhs.state;
59     return ChangeResult::Change;
60   }
61 
62   /// Returns the integer value of the state.
getValue() const63   uint64_t getValue() const { return *state; }
64 
65 private:
66   /// An optional integer value.
67   Optional<uint64_t> state;
68 };
69 
70 /// This analysis computes `FooState` across operations and control-flow edges.
71 /// If an op specifies a `foo` integer attribute, the contained value is XOR'd
72 /// with the value before the operation.
73 class FooAnalysis : public DataFlowAnalysis {
74 public:
75   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
76 
77   using DataFlowAnalysis::DataFlowAnalysis;
78 
79   LogicalResult initialize(Operation *top) override;
80   LogicalResult visit(ProgramPoint point) override;
81 
82 private:
83   void visitBlock(Block *block);
84   void visitOperation(Operation *op);
85 };
86 
87 struct TestFooAnalysisPass
88     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon09f8d1fb0111::TestFooAnalysisPass89   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
90 
91   StringRef getArgument() const override { return "test-foo-analysis"; }
92 
93   void runOnOperation() override;
94 };
95 } // namespace
96 
initialize(Operation * top)97 LogicalResult FooAnalysis::initialize(Operation *top) {
98   if (top->getNumRegions() != 1)
99     return top->emitError("expected a single region top-level op");
100 
101   // Initialize the top-level state.
102   getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
103 
104   // Visit all nested blocks and operations.
105   for (Block &block : top->getRegion(0)) {
106     visitBlock(&block);
107     for (Operation &op : block) {
108       if (op.getNumRegions())
109         return op.emitError("unexpected op with regions");
110       visitOperation(&op);
111     }
112   }
113   return success();
114 }
115 
visit(ProgramPoint point)116 LogicalResult FooAnalysis::visit(ProgramPoint point) {
117   if (auto *op = point.dyn_cast<Operation *>()) {
118     visitOperation(op);
119     return success();
120   }
121   if (auto *block = point.dyn_cast<Block *>()) {
122     visitBlock(block);
123     return success();
124   }
125   return emitError(point.getLoc(), "unknown point kind");
126 }
127 
visitBlock(Block * block)128 void FooAnalysis::visitBlock(Block *block) {
129   if (block->isEntryBlock()) {
130     // This is the initial state. Let the framework default-initialize it.
131     return;
132   }
133   FooState *state = getOrCreate<FooState>(block);
134   ChangeResult result = ChangeResult::NoChange;
135   for (Block *pred : block->getPredecessors()) {
136     // Join the state at the terminators of all predecessors.
137     const FooState *predState =
138         getOrCreateFor<FooState>(block, pred->getTerminator());
139     result |= state->join(*predState);
140   }
141   propagateIfChanged(state, result);
142 }
143 
visitOperation(Operation * op)144 void FooAnalysis::visitOperation(Operation *op) {
145   FooState *state = getOrCreate<FooState>(op);
146   ChangeResult result = ChangeResult::NoChange;
147 
148   // Copy the state across the operation.
149   const FooState *prevState;
150   if (Operation *prev = op->getPrevNode())
151     prevState = getOrCreateFor<FooState>(op, prev);
152   else
153     prevState = getOrCreateFor<FooState>(op, op->getBlock());
154   result |= state->set(*prevState);
155 
156   // Modify the state with the attribute, if specified.
157   if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
158     uint64_t value = attr.getUInt();
159     result |= state->join(value);
160   }
161   propagateIfChanged(state, result);
162 }
163 
runOnOperation()164 void TestFooAnalysisPass::runOnOperation() {
165   func::FuncOp func = getOperation();
166   DataFlowSolver solver;
167   solver.load<FooAnalysis>();
168   if (failed(solver.initializeAndRun(func)))
169     return signalPassFailure();
170 
171   raw_ostream &os = llvm::errs();
172   os << "function: @" << func.getSymName() << "\n";
173 
174   func.walk([&](Operation *op) {
175     auto tag = op->getAttrOfType<StringAttr>("tag");
176     if (!tag)
177       return;
178     const FooState *state = solver.lookupState<FooState>(op);
179     assert(state && !state->isUninitialized());
180     os << tag.getValue() << " -> " << state->getValue() << "\n";
181   });
182 }
183 
184 namespace mlir {
185 namespace test {
registerTestFooAnalysisPass()186 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
187 } // namespace test
188 } // namespace mlir
189