1 //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "llvm/Support/Debug.h"
11
12 #define DEBUG_TYPE "dataflow"
13 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
14 #define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
15 #else
16 #define DATAFLOW_DEBUG(X)
17 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
18
19 using namespace mlir;
20
21 //===----------------------------------------------------------------------===//
22 // GenericProgramPoint
23 //===----------------------------------------------------------------------===//
24
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here (rather than inline in the header) so
// the class's vtable/key function is emitted in this translation unit —
// confirm against the declaration in DataFlowFramework.h.
GenericProgramPoint::~GenericProgramPoint() = default;
26
27 //===----------------------------------------------------------------------===//
28 // AnalysisState
29 //===----------------------------------------------------------------------===//
30
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here so the class's vtable/key function is
// emitted in this translation unit — confirm against the header declaration.
AnalysisState::~AnalysisState() = default;
32
33 //===----------------------------------------------------------------------===//
34 // ProgramPoint
35 //===----------------------------------------------------------------------===//
36
print(raw_ostream & os) const37 void ProgramPoint::print(raw_ostream &os) const {
38 if (isNull()) {
39 os << "<NULL POINT>";
40 return;
41 }
42 if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
43 return programPoint->print(os);
44 if (auto *op = dyn_cast<Operation *>())
45 return op->print(os);
46 if (auto value = dyn_cast<Value>())
47 return value.print(os);
48 return get<Block *>()->print(os);
49 }
50
getLoc() const51 Location ProgramPoint::getLoc() const {
52 if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
53 return programPoint->getLoc();
54 if (auto *op = dyn_cast<Operation *>())
55 return op->getLoc();
56 if (auto value = dyn_cast<Value>())
57 return value.getLoc();
58 return get<Block *>()->getParent()->getLoc();
59 }
60
61 //===----------------------------------------------------------------------===//
62 // DataFlowSolver
63 //===----------------------------------------------------------------------===//
64
initializeAndRun(Operation * top)65 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
66 // Initialize the analyses.
67 for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
68 DATAFLOW_DEBUG(llvm::dbgs()
69 << "Priming analysis: " << analysis.debugName << "\n");
70 if (failed(analysis.initialize(top)))
71 return failure();
72 }
73
74 // Run the analysis until fixpoint.
75 ProgramPoint point;
76 DataFlowAnalysis *analysis;
77
78 do {
79 // Exhaust the worklist.
80 while (!worklist.empty()) {
81 std::tie(point, analysis) = worklist.front();
82 worklist.pop();
83
84 DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
85 << "' on: " << point << "\n");
86 if (failed(analysis->visit(point)))
87 return failure();
88 }
89
90 // Iterate until all states are in some initialized state and the worklist
91 // is exhausted.
92 } while (!worklist.empty());
93
94 return success();
95 }
96
propagateIfChanged(AnalysisState * state,ChangeResult changed)97 void DataFlowSolver::propagateIfChanged(AnalysisState *state,
98 ChangeResult changed) {
99 if (changed == ChangeResult::Change) {
100 DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
101 << " of " << state->point << "\n"
102 << "Value: " << *state << "\n");
103 for (const WorkItem &item : state->dependents)
104 enqueue(item);
105 state->onUpdate(this);
106 }
107 }
108
addDependency(AnalysisState * state,DataFlowAnalysis * analysis,ProgramPoint point)109 void DataFlowSolver::addDependency(AnalysisState *state,
110 DataFlowAnalysis *analysis,
111 ProgramPoint point) {
112 auto inserted = state->dependents.insert({point, analysis});
113 (void)inserted;
114 DATAFLOW_DEBUG({
115 if (inserted) {
116 llvm::dbgs() << "Creating dependency between " << state->debugName
117 << " of " << state->point << "\nand " << analysis->debugName
118 << " on " << point << "\n";
119 }
120 });
121 }
122
123 //===----------------------------------------------------------------------===//
124 // DataFlowAnalysis
125 //===----------------------------------------------------------------------===//
126
127 DataFlowAnalysis::~DataFlowAnalysis() = default;
128
DataFlowAnalysis(DataFlowSolver & solver)129 DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
130
addDependency(AnalysisState * state,ProgramPoint point)131 void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
132 solver.addDependency(state, this, point);
133 }
134
propagateIfChanged(AnalysisState * state,ChangeResult changed)135 void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
136 ChangeResult changed) {
137 solver.propagateIfChanged(state, changed);
138 }
139