1 //===- DataFlowFramework.h - A generic framework for data-flow analysis ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a generic framework for writing data-flow analysis in MLIR.
10 // The framework consists of a solver, which runs the fixed-point iteration and
11 // manages analysis dependencies, and a data-flow analysis class used to
12 // implement specific analyses.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
17 #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
18
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/StorageUniquer.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/TypeName.h"
23 #include <queue>
24
25 namespace mlir {
26
27 //===----------------------------------------------------------------------===//
28 // ChangeResult
29 //===----------------------------------------------------------------------===//
30
/// A result type used to indicate if a change happened. Boolean operations on
/// ChangeResult behave as though `Change` is truthy.
enum class ChangeResult {
  NoChange,
  Change,
};

/// Logical "or" of two results: yields `Change` if either operand is `Change`,
/// and `NoChange` otherwise. Usable in constant expressions.
constexpr ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
  return lhs == ChangeResult::Change ? lhs : rhs;
}

/// Accumulate `rhs` into `lhs` using "or" semantics and return a reference to
/// `lhs`, so that a sequence of updates can be folded into one result.
constexpr ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  lhs = lhs | rhs;
  return lhs;
}

/// Logical "and" of two results: yields `Change` only if both operands are
/// `Change`, and `NoChange` otherwise. Usable in constant expressions.
constexpr ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
  return lhs == ChangeResult::NoChange ? lhs : rhs;
}
47
48 /// Forward declare the analysis state class.
49 class AnalysisState;
50
51 //===----------------------------------------------------------------------===//
52 // GenericProgramPoint
53 //===----------------------------------------------------------------------===//
54
55 /// Abstract class for generic program points. In classical data-flow analysis,
56 /// programs points represent positions in a program to which lattice elements
57 /// are attached. In sparse data-flow analysis, these can be SSA values, and in
58 /// dense data-flow analysis, these are the program points before and after
59 /// every operation.
60 ///
61 /// In the general MLIR data-flow analysis framework, program points are an
62 /// extensible concept. Program points are uniquely identifiable objects to
63 /// which analysis states can be attached. The semantics of program points are
64 /// defined by the analyses that specify their transfer functions.
65 ///
66 /// Program points are implemented using MLIR's storage uniquer framework and
67 /// type ID system to provide RTTI.
68 class GenericProgramPoint : public StorageUniquer::BaseStorage {
69 public:
70 virtual ~GenericProgramPoint();
71
72 /// Get the abstract program point's type identifier.
getTypeID()73 TypeID getTypeID() const { return typeID; }
74
75 /// Get a derived source location for the program point.
76 virtual Location getLoc() const = 0;
77
78 /// Print the program point.
79 virtual void print(raw_ostream &os) const = 0;
80
81 protected:
82 /// Create an abstract program point with type identifier.
GenericProgramPoint(TypeID typeID)83 explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
84
85 private:
86 /// The type identifier of the program point.
87 TypeID typeID;
88 };
89
90 //===----------------------------------------------------------------------===//
91 // GenericProgramPointBase
92 //===----------------------------------------------------------------------===//
93
94 /// Base class for generic program points based on a concrete program point
95 /// type and a content key. This class defines the common methods required for
96 /// operability with the storage uniquer framework.
97 ///
98 /// The provided key type uniquely identifies the concrete program point
99 /// instance and are the data members of the class.
100 template <typename ConcreteT, typename Value>
101 class GenericProgramPointBase : public GenericProgramPoint {
102 public:
103 /// The concrete key type used by the storage uniquer. This class is uniqued
104 /// by its contents.
105 using KeyTy = Value;
106 /// Alias for the base class.
107 using Base = GenericProgramPointBase<ConcreteT, Value>;
108
109 /// Construct an instance of the program point using the provided value and
110 /// the type ID of the concrete type.
111 template <typename ValueT>
GenericProgramPointBase(ValueT && value)112 explicit GenericProgramPointBase(ValueT &&value)
113 : GenericProgramPoint(TypeID::get<ConcreteT>()),
114 value(std::forward<ValueT>(value)) {}
115
116 /// Get a uniqued instance of this program point class with the given
117 /// arguments.
118 template <typename... Args>
get(StorageUniquer & uniquer,Args &&...args)119 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
120 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
121 }
122
123 /// Allocate space for a program point and construct it in-place.
124 template <typename ValueT>
construct(StorageUniquer::StorageAllocator & alloc,ValueT && value)125 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
126 ValueT &&value) {
127 return new (alloc.allocate<ConcreteT>())
128 ConcreteT(std::forward<ValueT>(value));
129 }
130
131 /// Two program points are equal if their values are equal.
132 bool operator==(const Value &value) const { return this->value == value; }
133
134 /// Provide LLVM-style RTTI using type IDs.
classof(const GenericProgramPoint * point)135 static bool classof(const GenericProgramPoint *point) {
136 return point->getTypeID() == TypeID::get<ConcreteT>();
137 }
138
139 /// Get the contents of the program point.
getValue()140 const Value &getValue() const { return value; }
141
142 private:
143 /// The program point value.
144 Value value;
145 };
146
147 //===----------------------------------------------------------------------===//
148 // ProgramPoint
149 //===----------------------------------------------------------------------===//
150
151 /// Fundamental IR components are supported as first-class program points.
152 struct ProgramPoint
153 : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
154 using ParentTy =
155 PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
156 /// Inherit constructors.
157 using ParentTy::PointerUnion;
158 /// Allow implicit conversion from the parent type.
ParentTyProgramPoint159 ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
160 /// Allow implicit conversions from operation wrappers.
161 /// TODO: For Windows only. Find a better solution.
162 template <typename OpT, typename = typename std::enable_if_t<
163 std::is_convertible<OpT, Operation *>::value &&
164 !std::is_same<OpT, Operation *>::value>>
ProgramPointProgramPoint165 ProgramPoint(OpT op) : ParentTy(op) {}
166
167 /// Print the program point.
168 void print(raw_ostream &os) const;
169
170 /// Get the source location of the program point.
171 Location getLoc() const;
172 };
173
174 /// Forward declaration of the data-flow analysis class.
175 class DataFlowAnalysis;
176
177 //===----------------------------------------------------------------------===//
178 // DataFlowSolver
179 //===----------------------------------------------------------------------===//
180
181 /// The general data-flow analysis solver. This class is responsible for
182 /// orchestrating child data-flow analyses, running the fixed-point iteration
183 /// algorithm, managing analysis state and program point memory, and tracking
184 /// dependencies beteen analyses, program points, and analysis states.
185 ///
186 /// Steps to run a data-flow analysis:
187 ///
188 /// 1. Load and initialize children analyses. Children analyses are instantiated
189 /// in the solver and initialized, building their dependency relations.
190 /// 2. Configure and run the analysis. The solver invokes the children analyses
191 /// according to their dependency relations until a fixed point is reached.
192 /// 3. Query analysis state results from the solver.
193 ///
194 /// TODO: Optimize the internal implementation of the solver.
195 class DataFlowSolver {
196 public:
197 /// Load an analysis into the solver. Return the analysis instance.
198 template <typename AnalysisT, typename... Args>
199 AnalysisT *load(Args &&...args);
200
201 /// Initialize the children analyses starting from the provided top-level
202 /// operation and run the analysis until fixpoint.
203 LogicalResult initializeAndRun(Operation *top);
204
205 /// Lookup an analysis state for the given program point. Returns null if one
206 /// does not exist.
207 template <typename StateT, typename PointT>
lookupState(PointT point)208 const StateT *lookupState(PointT point) const {
209 auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
210 if (it == analysisStates.end())
211 return nullptr;
212 return static_cast<const StateT *>(it->second.get());
213 }
214
215 /// Get a uniqued program point instance. If one is not present, it is
216 /// created with the provided arguments.
217 template <typename PointT, typename... Args>
getProgramPoint(Args &&...args)218 PointT *getProgramPoint(Args &&...args) {
219 return PointT::get(uniquer, std::forward<Args>(args)...);
220 }
221
222 /// A work item on the solver queue is a program point, child analysis pair.
223 /// Each item is processed by invoking the child analysis at the program
224 /// point.
225 using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>;
226 /// Push a work item onto the worklist.
enqueue(WorkItem item)227 void enqueue(WorkItem item) { worklist.push(std::move(item)); }
228
229 /// Get the state associated with the given program point. If it does not
230 /// exist, create an uninitialized state.
231 template <typename StateT, typename PointT>
232 StateT *getOrCreateState(PointT point);
233
234 /// Propagate an update to an analysis state if it changed by pushing
235 /// dependent work items to the back of the queue.
236 void propagateIfChanged(AnalysisState *state, ChangeResult changed);
237
238 /// Add a dependency to an analysis state on a child analysis and program
239 /// point. If the state is updated, the child analysis must be invoked on the
240 /// given program point again.
241 void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
242 ProgramPoint point);
243
244 private:
245 /// The solver's work queue. Work items can be inserted to the front of the
246 /// queue to be processed greedily, speeding up computations that otherwise
247 /// quickly degenerate to quadratic due to propagation of state updates.
248 std::queue<WorkItem> worklist;
249
250 /// Type-erased instances of the children analyses.
251 SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
252
253 /// The storage uniquer instance that owns the memory of the allocated program
254 /// points.
255 StorageUniquer uniquer;
256
257 /// A type-erased map of program points to associated analysis states for
258 /// first-class program points.
259 DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
260 analysisStates;
261
262 /// Allow the base child analysis class to access the internals of the solver.
263 friend class DataFlowAnalysis;
264 };
265
266 //===----------------------------------------------------------------------===//
267 // AnalysisState
268 //===----------------------------------------------------------------------===//
269
270 /// Base class for generic analysis states. Analysis states contain data-flow
271 /// information that are attached to program points and which evolve as the
272 /// analysis iterates.
273 ///
274 /// This class places no restrictions on the semantics of analysis states beyond
275 /// these requirements.
276 ///
277 /// 1. Querying the state of a program point prior to visiting that point
278 /// results in uninitialized state. Analyses must be aware of unintialized
279 /// states.
280 /// 2. Analysis states can reach fixpoints, where subsequent updates will never
281 /// trigger a change in the state.
282 /// 3. Analysis states that are uninitialized can be forcefully initialized to a
283 /// default value.
284 class AnalysisState {
285 public:
286 virtual ~AnalysisState();
287
288 /// Create the analysis state at the given program point.
AnalysisState(ProgramPoint point)289 AnalysisState(ProgramPoint point) : point(point) {}
290
291 /// Returns true if the analysis state is uninitialized.
292 virtual bool isUninitialized() const = 0;
293
294 /// Force an uninitialized analysis state to initialize itself with a default
295 /// value.
296 virtual ChangeResult defaultInitialize() = 0;
297
298 /// Print the contents of the analysis state.
299 virtual void print(raw_ostream &os) const = 0;
300
301 protected:
302 /// This function is called by the solver when the analysis state is updated
303 /// to optionally enqueue more work items. For example, if a state tracks
304 /// dependents through the IR (e.g. use-def chains), this function can be
305 /// implemented to push those dependents on the worklist.
onUpdate(DataFlowSolver * solver)306 virtual void onUpdate(DataFlowSolver *solver) const {}
307
308 /// The dependency relations originating from this analysis state. An entry
309 /// `state -> (analysis, point)` is created when `analysis` queries `state`
310 /// when updating `point`.
311 ///
312 /// When this state is updated, all dependent child analysis invocations are
313 /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
314 /// deterministic.
315 ///
316 /// Store the dependents on the analysis state for efficiency.
317 SetVector<DataFlowSolver::WorkItem> dependents;
318
319 /// The program point to which the state belongs.
320 ProgramPoint point;
321
322 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
323 /// When compiling with debugging, keep a name for the analysis state.
324 StringRef debugName;
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
326
327 /// Allow the framework to access the dependents.
328 friend class DataFlowSolver;
329 };
330
331 //===----------------------------------------------------------------------===//
332 // DataFlowAnalysis
333 //===----------------------------------------------------------------------===//
334
335 /// Base class for all data-flow analyses. A child analysis is expected to build
336 /// an initial dependency graph (and optionally provide an initial state) when
337 /// initialized and define transfer functions when visiting program points.
338 ///
339 /// In classical data-flow analysis, the dependency graph is fixed and analyses
340 /// define explicit transfer functions between input states and output states.
341 /// In this framework, however, the dependency graph can change during the
342 /// analysis, and transfer functions are opaque such that the solver doesn't
343 /// know what states calling `visit` on an analysis will be updated. This allows
344 /// multiple analyses to plug in and provide values for the same state.
345 ///
346 /// Generally, when an analysis queries an uninitialized state, it is expected
347 /// to "bail out", i.e., not provide any updates. When the value is initialized,
348 /// the solver will re-invoke the analysis. If the solver exhausts its worklist,
349 /// however, and there are still uninitialized states, the solver "nudges" the
350 /// analyses by default-initializing those states.
351 class DataFlowAnalysis {
352 public:
353 virtual ~DataFlowAnalysis();
354
355 /// Create an analysis with a reference to the parent solver.
356 explicit DataFlowAnalysis(DataFlowSolver &solver);
357
358 /// Initialize the analysis from the provided top-level operation by building
359 /// an initial dependency graph between all program points of interest. This
360 /// can be implemented by calling `visit` on all program points of interest
361 /// below the top-level operation.
362 ///
363 /// An analysis can optionally provide initial values to certain analysis
364 /// states to influence the evolution of the analysis.
365 virtual LogicalResult initialize(Operation *top) = 0;
366
367 /// Visit the given program point. This function is invoked by the solver on
368 /// this analysis with a given program point when a dependent analysis state
369 /// is updated. The function is similar to a transfer function; it queries
370 /// certain analysis states and sets other states.
371 ///
372 /// The function is expected to create dependencies on queried states and
373 /// propagate updates on changed states. A dependency can be created by
374 /// calling `addDependency` between the input state and a program point,
375 /// indicating that, if the state is updated, the solver should invoke `solve`
376 /// on the program point. The dependent point does not have to be the same as
377 /// the provided point. An update to a state is propagated by calling
378 /// `propagateIfChange` on the state. If the state has changed, then all its
379 /// dependents are placed on the worklist.
380 ///
381 /// The dependency graph does not need to be static. Each invocation of
382 /// `visit` can add new dependencies, but these dependecies will not be
383 /// dynamically added to the worklist because the solver doesn't know what
384 /// will provide a value for then.
385 virtual LogicalResult visit(ProgramPoint point) = 0;
386
387 protected:
388 /// Create a dependency between the given analysis state and program point
389 /// on this analysis.
390 void addDependency(AnalysisState *state, ProgramPoint point);
391
392 /// Propagate an update to a state if it changed.
393 void propagateIfChanged(AnalysisState *state, ChangeResult changed);
394
395 /// Register a custom program point class.
396 template <typename PointT>
registerPointKind()397 void registerPointKind() {
398 solver.uniquer.registerParametricStorageType<PointT>();
399 }
400
401 /// Get or create a custom program point.
402 template <typename PointT, typename... Args>
getProgramPoint(Args &&...args)403 PointT *getProgramPoint(Args &&...args) {
404 return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
405 }
406
407 /// Get the analysis state assiocated with the program point. The returned
408 /// state is expected to be "write-only", and any updates need to be
409 /// propagated by `propagateIfChanged`.
410 template <typename StateT, typename PointT>
getOrCreate(PointT point)411 StateT *getOrCreate(PointT point) {
412 return solver.getOrCreateState<StateT>(point);
413 }
414
415 /// Get a read-only analysis state for the given point and create a dependency
416 /// on `dependent`. If the return state is updated elsewhere, this analysis is
417 /// re-invoked on the dependent.
418 template <typename StateT, typename PointT>
getOrCreateFor(ProgramPoint dependent,PointT point)419 const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
420 StateT *state = getOrCreate<StateT>(point);
421 addDependency(state, dependent);
422 return state;
423 }
424
425 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
426 /// When compiling with debugging, keep a name for the analyis.
427 StringRef debugName;
428 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
429
430 private:
431 /// The parent data-flow solver.
432 DataFlowSolver &solver;
433
434 /// Allow the data-flow solver to access the internals of this class.
435 friend class DataFlowSolver;
436 };
437
438 template <typename AnalysisT, typename... Args>
load(Args &&...args)439 AnalysisT *DataFlowSolver::load(Args &&...args) {
440 childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
441 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
442 childAnalyses.back().get()->debugName = llvm::getTypeName<AnalysisT>();
443 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
444 return static_cast<AnalysisT *>(childAnalyses.back().get());
445 }
446
447 template <typename StateT, typename PointT>
getOrCreateState(PointT point)448 StateT *DataFlowSolver::getOrCreateState(PointT point) {
449 std::unique_ptr<AnalysisState> &state =
450 analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
451 if (!state) {
452 state = std::unique_ptr<StateT>(new StateT(point));
453 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
454 state->debugName = llvm::getTypeName<StateT>();
455 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
456 }
457 return static_cast<StateT *>(state.get());
458 }
459
460 inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
461 state.print(os);
462 return os;
463 }
464
465 inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
466 point.print(os);
467 return os;
468 }
469
470 } // end namespace mlir
471
472 namespace llvm {
473 /// Allow hashing of program points.
474 template <>
475 struct DenseMapInfo<mlir::ProgramPoint>
476 : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
477 } // end namespace llvm
478
479 #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
480