1 //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowFramework.h"
10 #include "llvm/Support/Debug.h"
11 
12 #define DEBUG_TYPE "dataflow"
13 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
14 #define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
15 #else
16 #define DATAFLOW_DEBUG(X)
17 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // GenericProgramPoint
23 //===----------------------------------------------------------------------===//
24 
/// Out-of-line definition of the defaulted destructor.
/// NOTE(review): presumably defined here rather than in the header so the
/// class has a single home translation unit -- confirm against the header.
GenericProgramPoint::~GenericProgramPoint() = default;
26 
27 //===----------------------------------------------------------------------===//
28 // AnalysisState
29 //===----------------------------------------------------------------------===//
30 
/// Out-of-line definition of the defaulted destructor.
/// NOTE(review): presumably defined here rather than in the header so the
/// class has a single home translation unit -- confirm against the header.
AnalysisState::~AnalysisState() = default;
32 
33 //===----------------------------------------------------------------------===//
34 // ProgramPoint
35 //===----------------------------------------------------------------------===//
36 
37 void ProgramPoint::print(raw_ostream &os) const {
38   if (isNull()) {
39     os << "<NULL POINT>";
40     return;
41   }
42   if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
43     return programPoint->print(os);
44   if (auto *op = dyn_cast<Operation *>())
45     return op->print(os);
46   if (auto value = dyn_cast<Value>())
47     return value.print(os);
48   return get<Block *>()->print(os);
49 }
50 
51 Location ProgramPoint::getLoc() const {
52   if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
53     return programPoint->getLoc();
54   if (auto *op = dyn_cast<Operation *>())
55     return op->getLoc();
56   if (auto value = dyn_cast<Value>())
57     return value.getLoc();
58   return get<Block *>()->getParent()->getLoc();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // DataFlowSolver
63 //===----------------------------------------------------------------------===//
64 
65 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
66   // Initialize the analyses.
67   for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
68     DATAFLOW_DEBUG(llvm::dbgs()
69                    << "Priming analysis: " << analysis.debugName << "\n");
70     if (failed(analysis.initialize(top)))
71       return failure();
72   }
73 
74   // Run the analysis until fixpoint.
75   ProgramPoint point;
76   DataFlowAnalysis *analysis;
77 
78   do {
79     // Exhaust the worklist.
80     while (!worklist.empty()) {
81       std::tie(point, analysis) = worklist.front();
82       worklist.pop();
83 
84       DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
85                                   << "' on: " << point << "\n");
86       if (failed(analysis->visit(point)))
87         return failure();
88     }
89 
90     // Iterate until all states are in some initialized state and the worklist
91     // is exhausted.
92   } while (!worklist.empty());
93 
94   return success();
95 }
96 
97 void DataFlowSolver::propagateIfChanged(AnalysisState *state,
98                                         ChangeResult changed) {
99   if (changed == ChangeResult::Change) {
100     DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
101                                 << " of " << state->point << "\n"
102                                 << "Value: " << *state << "\n");
103     for (const WorkItem &item : state->dependents)
104       enqueue(item);
105     state->onUpdate(this);
106   }
107 }
108 
109 void DataFlowSolver::addDependency(AnalysisState *state,
110                                    DataFlowAnalysis *analysis,
111                                    ProgramPoint point) {
112   auto inserted = state->dependents.insert({point, analysis});
113   (void)inserted;
114   DATAFLOW_DEBUG({
115     if (inserted) {
116       llvm::dbgs() << "Creating dependency between " << state->debugName
117                    << " of " << state->point << "\nand " << analysis->debugName
118                    << " on " << point << "\n";
119     }
120   });
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // DataFlowAnalysis
125 //===----------------------------------------------------------------------===//
126 
/// Out-of-line definition of the defaulted destructor.
/// NOTE(review): presumably defined here rather than in the header so the
/// class has a single home translation unit -- confirm against the header.
DataFlowAnalysis::~DataFlowAnalysis() = default;
128 
/// Construct an analysis bound to `solver`; the reference is stored in the
/// `solver` member and used by the forwarding helpers in this file.
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
130 
/// Convenience wrapper: register this analysis at `point` as a dependent of
/// `state` by forwarding to the owning solver with `this` as the analysis.
void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
  solver.addDependency(state, this, point);
}
134 
/// Convenience wrapper: forward `state` and `changed` unchanged to the owning
/// solver's propagateIfChanged.
void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
                                          ChangeResult changed) {
  solver.propagateIfChanged(state, changed);
}
139