1*ead75d94SMogball //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
2*ead75d94SMogball //
3*ead75d94SMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ead75d94SMogball // See https://llvm.org/LICENSE.txt for license information.
5*ead75d94SMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ead75d94SMogball //
7*ead75d94SMogball //===----------------------------------------------------------------------===//
8*ead75d94SMogball 
9*ead75d94SMogball #include "mlir/Analysis/DataFlowFramework.h"
10*ead75d94SMogball #include "mlir/Dialect/Func/IR/FuncOps.h"
11*ead75d94SMogball #include "mlir/Pass/Pass.h"
12*ead75d94SMogball 
13*ead75d94SMogball using namespace mlir;
14*ead75d94SMogball 
15*ead75d94SMogball namespace {
16*ead75d94SMogball /// This analysis state represents an integer that is XOR'd with other states.
17*ead75d94SMogball class FooState : public AnalysisState {
18*ead75d94SMogball public:
19*ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
20*ead75d94SMogball 
21*ead75d94SMogball   using AnalysisState::AnalysisState;
22*ead75d94SMogball 
23*ead75d94SMogball   /// Default-initialize the state to zero.
defaultInitialize()24*ead75d94SMogball   ChangeResult defaultInitialize() override { return join(0); }
25*ead75d94SMogball 
26*ead75d94SMogball   /// Returns true if the state is uninitialized.
isUninitialized() const27*ead75d94SMogball   bool isUninitialized() const override { return !state; }
28*ead75d94SMogball 
29*ead75d94SMogball   /// Print the integer value or "none" if uninitialized.
print(raw_ostream & os) const30*ead75d94SMogball   void print(raw_ostream &os) const override {
31*ead75d94SMogball     if (state)
32*ead75d94SMogball       os << *state;
33*ead75d94SMogball     else
34*ead75d94SMogball       os << "none";
35*ead75d94SMogball   }
36*ead75d94SMogball 
37*ead75d94SMogball   /// Join the state with another. If either is unintialized, take the
38*ead75d94SMogball   /// initialized value. Otherwise, XOR the integer values.
join(const FooState & rhs)39*ead75d94SMogball   ChangeResult join(const FooState &rhs) {
40*ead75d94SMogball     if (rhs.isUninitialized())
41*ead75d94SMogball       return ChangeResult::NoChange;
42*ead75d94SMogball     return join(*rhs.state);
43*ead75d94SMogball   }
join(uint64_t value)44*ead75d94SMogball   ChangeResult join(uint64_t value) {
45*ead75d94SMogball     if (isUninitialized()) {
46*ead75d94SMogball       state = value;
47*ead75d94SMogball       return ChangeResult::Change;
48*ead75d94SMogball     }
49*ead75d94SMogball     uint64_t before = *state;
50*ead75d94SMogball     state = before ^ value;
51*ead75d94SMogball     return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
52*ead75d94SMogball   }
53*ead75d94SMogball 
54*ead75d94SMogball   /// Set the value of the state directly.
set(const FooState & rhs)55*ead75d94SMogball   ChangeResult set(const FooState &rhs) {
56*ead75d94SMogball     if (state == rhs.state)
57*ead75d94SMogball       return ChangeResult::NoChange;
58*ead75d94SMogball     state = rhs.state;
59*ead75d94SMogball     return ChangeResult::Change;
60*ead75d94SMogball   }
61*ead75d94SMogball 
62*ead75d94SMogball   /// Returns the integer value of the state.
getValue() const63*ead75d94SMogball   uint64_t getValue() const { return *state; }
64*ead75d94SMogball 
65*ead75d94SMogball private:
66*ead75d94SMogball   /// An optional integer value.
67*ead75d94SMogball   Optional<uint64_t> state;
68*ead75d94SMogball };
69*ead75d94SMogball 
70*ead75d94SMogball /// This analysis computes `FooState` across operations and control-flow edges.
71*ead75d94SMogball /// If an op specifies a `foo` integer attribute, the contained value is XOR'd
72*ead75d94SMogball /// with the value before the operation.
73*ead75d94SMogball class FooAnalysis : public DataFlowAnalysis {
74*ead75d94SMogball public:
75*ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
76*ead75d94SMogball 
77*ead75d94SMogball   using DataFlowAnalysis::DataFlowAnalysis;
78*ead75d94SMogball 
79*ead75d94SMogball   LogicalResult initialize(Operation *top) override;
80*ead75d94SMogball   LogicalResult visit(ProgramPoint point) override;
81*ead75d94SMogball 
82*ead75d94SMogball private:
83*ead75d94SMogball   void visitBlock(Block *block);
84*ead75d94SMogball   void visitOperation(Operation *op);
85*ead75d94SMogball };
86*ead75d94SMogball 
87*ead75d94SMogball struct TestFooAnalysisPass
88*ead75d94SMogball     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon09f8d1fb0111::TestFooAnalysisPass89*ead75d94SMogball   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
90*ead75d94SMogball 
91*ead75d94SMogball   StringRef getArgument() const override { return "test-foo-analysis"; }
92*ead75d94SMogball 
93*ead75d94SMogball   void runOnOperation() override;
94*ead75d94SMogball };
95*ead75d94SMogball } // namespace
96*ead75d94SMogball 
initialize(Operation * top)97*ead75d94SMogball LogicalResult FooAnalysis::initialize(Operation *top) {
98*ead75d94SMogball   if (top->getNumRegions() != 1)
99*ead75d94SMogball     return top->emitError("expected a single region top-level op");
100*ead75d94SMogball 
101*ead75d94SMogball   // Initialize the top-level state.
102*ead75d94SMogball   getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
103*ead75d94SMogball 
104*ead75d94SMogball   // Visit all nested blocks and operations.
105*ead75d94SMogball   for (Block &block : top->getRegion(0)) {
106*ead75d94SMogball     visitBlock(&block);
107*ead75d94SMogball     for (Operation &op : block) {
108*ead75d94SMogball       if (op.getNumRegions())
109*ead75d94SMogball         return op.emitError("unexpected op with regions");
110*ead75d94SMogball       visitOperation(&op);
111*ead75d94SMogball     }
112*ead75d94SMogball   }
113*ead75d94SMogball   return success();
114*ead75d94SMogball }
115*ead75d94SMogball 
visit(ProgramPoint point)116*ead75d94SMogball LogicalResult FooAnalysis::visit(ProgramPoint point) {
117*ead75d94SMogball   if (auto *op = point.dyn_cast<Operation *>()) {
118*ead75d94SMogball     visitOperation(op);
119*ead75d94SMogball     return success();
120*ead75d94SMogball   }
121*ead75d94SMogball   if (auto *block = point.dyn_cast<Block *>()) {
122*ead75d94SMogball     visitBlock(block);
123*ead75d94SMogball     return success();
124*ead75d94SMogball   }
125*ead75d94SMogball   return emitError(point.getLoc(), "unknown point kind");
126*ead75d94SMogball }
127*ead75d94SMogball 
visitBlock(Block * block)128*ead75d94SMogball void FooAnalysis::visitBlock(Block *block) {
129*ead75d94SMogball   if (block->isEntryBlock()) {
130*ead75d94SMogball     // This is the initial state. Let the framework default-initialize it.
131*ead75d94SMogball     return;
132*ead75d94SMogball   }
133*ead75d94SMogball   FooState *state = getOrCreate<FooState>(block);
134*ead75d94SMogball   ChangeResult result = ChangeResult::NoChange;
135*ead75d94SMogball   for (Block *pred : block->getPredecessors()) {
136*ead75d94SMogball     // Join the state at the terminators of all predecessors.
137*ead75d94SMogball     const FooState *predState =
138*ead75d94SMogball         getOrCreateFor<FooState>(block, pred->getTerminator());
139*ead75d94SMogball     result |= state->join(*predState);
140*ead75d94SMogball   }
141*ead75d94SMogball   propagateIfChanged(state, result);
142*ead75d94SMogball }
143*ead75d94SMogball 
visitOperation(Operation * op)144*ead75d94SMogball void FooAnalysis::visitOperation(Operation *op) {
145*ead75d94SMogball   FooState *state = getOrCreate<FooState>(op);
146*ead75d94SMogball   ChangeResult result = ChangeResult::NoChange;
147*ead75d94SMogball 
148*ead75d94SMogball   // Copy the state across the operation.
149*ead75d94SMogball   const FooState *prevState;
150*ead75d94SMogball   if (Operation *prev = op->getPrevNode())
151*ead75d94SMogball     prevState = getOrCreateFor<FooState>(op, prev);
152*ead75d94SMogball   else
153*ead75d94SMogball     prevState = getOrCreateFor<FooState>(op, op->getBlock());
154*ead75d94SMogball   result |= state->set(*prevState);
155*ead75d94SMogball 
156*ead75d94SMogball   // Modify the state with the attribute, if specified.
157*ead75d94SMogball   if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
158*ead75d94SMogball     uint64_t value = attr.getUInt();
159*ead75d94SMogball     result |= state->join(value);
160*ead75d94SMogball   }
161*ead75d94SMogball   propagateIfChanged(state, result);
162*ead75d94SMogball }
163*ead75d94SMogball 
runOnOperation()164*ead75d94SMogball void TestFooAnalysisPass::runOnOperation() {
165*ead75d94SMogball   func::FuncOp func = getOperation();
166*ead75d94SMogball   DataFlowSolver solver;
167*ead75d94SMogball   solver.load<FooAnalysis>();
168*ead75d94SMogball   if (failed(solver.initializeAndRun(func)))
169*ead75d94SMogball     return signalPassFailure();
170*ead75d94SMogball 
171*ead75d94SMogball   raw_ostream &os = llvm::errs();
172*ead75d94SMogball   os << "function: @" << func.getSymName() << "\n";
173*ead75d94SMogball 
174*ead75d94SMogball   func.walk([&](Operation *op) {
175*ead75d94SMogball     auto tag = op->getAttrOfType<StringAttr>("tag");
176*ead75d94SMogball     if (!tag)
177*ead75d94SMogball       return;
178*ead75d94SMogball     const FooState *state = solver.lookupState<FooState>(op);
179*ead75d94SMogball     assert(state && !state->isUninitialized());
180*ead75d94SMogball     os << tag.getValue() << " -> " << state->getValue() << "\n";
181*ead75d94SMogball   });
182*ead75d94SMogball }
183*ead75d94SMogball 
184*ead75d94SMogball namespace mlir {
185*ead75d94SMogball namespace test {
registerTestFooAnalysisPass()186*ead75d94SMogball void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
187*ead75d94SMogball } // namespace test
188*ead75d94SMogball } // namespace mlir
189