1 //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "llvm/Support/Debug.h"
11 
// Debug category name consumed by LLVM_DEBUG for output from this file.
#define DEBUG_TYPE "dataflow"

// DATAFLOW_DEBUG(X) forwards X to LLVM_DEBUG only when
// LLVM_ENABLE_ABI_BREAKING_CHECKS is enabled; otherwise it expands to nothing,
// so X is never evaluated.
// NOTE(review): presumably gated this way because the `debugName` fields
// printed by the users of this macro only exist under ABI-breaking checks --
// confirm against DataFlowFramework.h.
#if !LLVM_ENABLE_ABI_BREAKING_CHECKS
#define DATAFLOW_DEBUG(X)
#else
#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
#endif // !LLVM_ENABLE_ABI_BREAKING_CHECKS
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // GenericProgramPoint
23 //===----------------------------------------------------------------------===//
24 
// Out-of-line definition of the defaulted GenericProgramPoint destructor.
GenericProgramPoint::~GenericProgramPoint() = default;
26 
27 //===----------------------------------------------------------------------===//
28 // AnalysisState
29 //===----------------------------------------------------------------------===//
30 
// Out-of-line definition of the defaulted AnalysisState destructor.
AnalysisState::~AnalysisState() = default;
32 
33 //===----------------------------------------------------------------------===//
34 // ProgramPoint
35 //===----------------------------------------------------------------------===//
36 
37 void ProgramPoint::print(raw_ostream &os) const {
38   if (isNull()) {
39     os << "<NULL POINT>";
40     return;
41   }
42   if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
43     return programPoint->print(os);
44   if (auto *op = dyn_cast<Operation *>())
45     return op->print(os);
46   if (auto value = dyn_cast<Value>())
47     return value.print(os);
48   return get<Block *>()->print(os);
49 }
50 
51 Location ProgramPoint::getLoc() const {
52   if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
53     return programPoint->getLoc();
54   if (auto *op = dyn_cast<Operation *>())
55     return op->getLoc();
56   if (auto value = dyn_cast<Value>())
57     return value.getLoc();
58   return get<Block *>()->getParent()->getLoc();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // DataFlowSolver
63 //===----------------------------------------------------------------------===//
64 
65 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
66   // Initialize the analyses.
67   for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
68     DATAFLOW_DEBUG(llvm::dbgs()
69                    << "Priming analysis: " << analysis.debugName << "\n");
70     if (failed(analysis.initialize(top)))
71       return failure();
72   }
73 
74   // Run the analysis until fixpoint.
75   ProgramPoint point;
76   DataFlowAnalysis *analysis;
77 
78   do {
79     // Exhaust the worklist.
80     while (!worklist.empty()) {
81       std::tie(point, analysis) = worklist.front();
82       worklist.pop();
83 
84       DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
85                                   << "' on: " << point << "\n");
86       if (failed(analysis->visit(point)))
87         return failure();
88     }
89 
90     // "Nudge" the state of the analysis by forcefully initializing states that
91     // are still uninitialized. All uninitialized states in the graph can be
92     // initialized in any order because the analysis reached fixpoint, meaning
93     // that there are no work items that would have further nudged the analysis.
94     for (AnalysisState &state :
95          llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
96       if (!state.isUninitialized())
97         continue;
98       DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
99                                   << " of " << state.point << "\n");
100       propagateIfChanged(&state, state.defaultInitialize());
101     }
102 
103     // Iterate until all states are in some initialized state and the worklist
104     // is exhausted.
105   } while (!worklist.empty());
106 
107   return success();
108 }
109 
110 void DataFlowSolver::propagateIfChanged(AnalysisState *state,
111                                         ChangeResult changed) {
112   if (changed == ChangeResult::Change) {
113     DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
114                                 << " of " << state->point << "\n"
115                                 << "Value: " << *state << "\n");
116     for (const WorkItem &item : state->dependents)
117       enqueue(item);
118     state->onUpdate(this);
119   }
120 }
121 
122 void DataFlowSolver::addDependency(AnalysisState *state,
123                                    DataFlowAnalysis *analysis,
124                                    ProgramPoint point) {
125   auto inserted = state->dependents.insert({point, analysis});
126   (void)inserted;
127   DATAFLOW_DEBUG({
128     if (inserted) {
129       llvm::dbgs() << "Creating dependency between " << state->debugName
130                    << " of " << state->point << "\nand " << analysis->debugName
131                    << " on " << point << "\n";
132     }
133   });
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // DataFlowAnalysis
138 //===----------------------------------------------------------------------===//
139 
// Out-of-line definition of the defaulted DataFlowAnalysis destructor.
DataFlowAnalysis::~DataFlowAnalysis() = default;
141 
/// Construct an analysis bound to `solver`. The reference is stored in the
/// `solver` member, which the forwarding helpers in this file call into.
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
143 
/// Register a dependency on `state` for this analysis at `point`. Forwards to
/// the solver's `addDependency`, passing `this` as the dependent analysis.
void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
  solver.addDependency(state, this, point);
}
147 
/// Forward `state` and its `changed` result to the solver's
/// `propagateIfChanged`, which performs the propagation.
void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
                                          ChangeResult changed) {
  solver.propagateIfChanged(state, changed);
}
152