1 //===- DataFlowFramework.h - A generic framework for data-flow analysis ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a generic framework for writing data-flow analysis in MLIR.
10 // The framework consists of a solver, which runs the fixed-point iteration and
11 // manages analysis dependencies, and a data-flow analysis class used to
12 // implement specific analyses.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
17 #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
18 
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/StorageUniquer.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/TypeName.h"
23 #include <queue>
24 
25 namespace mlir {
26 
27 //===----------------------------------------------------------------------===//
28 // ChangeResult
29 //===----------------------------------------------------------------------===//
30 
/// Outcome of an update that may or may not have modified some state. The
/// logical operators below treat `Change` as the "true" value and `NoChange`
/// as the "false" value.
enum class ChangeResult {
  NoChange,
  Change,
};

/// Logical OR: yields `Change` when either operand is `Change`.
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
  if (lhs != ChangeResult::Change)
    return rhs;
  return lhs;
}

/// In-place logical OR: folds `rhs` into `lhs` and returns `lhs`.
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  return lhs = lhs | rhs;
}

/// Logical AND: yields `NoChange` when either operand is `NoChange`.
inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
  if (lhs != ChangeResult::NoChange)
    return rhs;
  return lhs;
}
47 
/// Forward declare the analysis state class. `DataFlowSolver` refers to it
/// (by pointer and through `std::unique_ptr`) before its definition below.
class AnalysisState;
50 
51 //===----------------------------------------------------------------------===//
52 // GenericProgramPoint
53 //===----------------------------------------------------------------------===//
54 
55 /// Abstract class for generic program points. In classical data-flow analysis,
56 /// programs points represent positions in a program to which lattice elements
57 /// are attached. In sparse data-flow analysis, these can be SSA values, and in
58 /// dense data-flow analysis, these are the program points before and after
59 /// every operation.
60 ///
61 /// In the general MLIR data-flow analysis framework, program points are an
62 /// extensible concept. Program points are uniquely identifiable objects to
63 /// which analysis states can be attached. The semantics of program points are
64 /// defined by the analyses that specify their transfer functions.
65 ///
66 /// Program points are implemented using MLIR's storage uniquer framework and
67 /// type ID system to provide RTTI.
68 class GenericProgramPoint : public StorageUniquer::BaseStorage {
69 public:
70   virtual ~GenericProgramPoint();
71 
72   /// Get the abstract program point's type identifier.
getTypeID()73   TypeID getTypeID() const { return typeID; }
74 
75   /// Get a derived source location for the program point.
76   virtual Location getLoc() const = 0;
77 
78   /// Print the program point.
79   virtual void print(raw_ostream &os) const = 0;
80 
81 protected:
82   /// Create an abstract program point with type identifier.
GenericProgramPoint(TypeID typeID)83   explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
84 
85 private:
86   /// The type identifier of the program point.
87   TypeID typeID;
88 };
89 
90 //===----------------------------------------------------------------------===//
91 // GenericProgramPointBase
92 //===----------------------------------------------------------------------===//
93 
94 /// Base class for generic program points based on a concrete program point
95 /// type and a content key. This class defines the common methods required for
96 /// operability with the storage uniquer framework.
97 ///
98 /// The provided key type uniquely identifies the concrete program point
99 /// instance and are the data members of the class.
100 template <typename ConcreteT, typename Value>
101 class GenericProgramPointBase : public GenericProgramPoint {
102 public:
103   /// The concrete key type used by the storage uniquer. This class is uniqued
104   /// by its contents.
105   using KeyTy = Value;
106   /// Alias for the base class.
107   using Base = GenericProgramPointBase<ConcreteT, Value>;
108 
109   /// Construct an instance of the program point using the provided value and
110   /// the type ID of the concrete type.
111   template <typename ValueT>
GenericProgramPointBase(ValueT && value)112   explicit GenericProgramPointBase(ValueT &&value)
113       : GenericProgramPoint(TypeID::get<ConcreteT>()),
114         value(std::forward<ValueT>(value)) {}
115 
116   /// Get a uniqued instance of this program point class with the given
117   /// arguments.
118   template <typename... Args>
get(StorageUniquer & uniquer,Args &&...args)119   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
120     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
121   }
122 
123   /// Allocate space for a program point and construct it in-place.
124   template <typename ValueT>
construct(StorageUniquer::StorageAllocator & alloc,ValueT && value)125   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
126                               ValueT &&value) {
127     return new (alloc.allocate<ConcreteT>())
128         ConcreteT(std::forward<ValueT>(value));
129   }
130 
131   /// Two program points are equal if their values are equal.
132   bool operator==(const Value &value) const { return this->value == value; }
133 
134   /// Provide LLVM-style RTTI using type IDs.
classof(const GenericProgramPoint * point)135   static bool classof(const GenericProgramPoint *point) {
136     return point->getTypeID() == TypeID::get<ConcreteT>();
137   }
138 
139   /// Get the contents of the program point.
getValue()140   const Value &getValue() const { return value; }
141 
142 private:
143   /// The program point value.
144   Value value;
145 };
146 
147 //===----------------------------------------------------------------------===//
148 // ProgramPoint
149 //===----------------------------------------------------------------------===//
150 
151 /// Fundamental IR components are supported as first-class program points.
152 struct ProgramPoint
153     : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
154   using ParentTy =
155       PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
156   /// Inherit constructors.
157   using ParentTy::PointerUnion;
158   /// Allow implicit conversion from the parent type.
ParentTyProgramPoint159   ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
160   /// Allow implicit conversions from operation wrappers.
161   /// TODO: For Windows only. Find a better solution.
162   template <typename OpT, typename = typename std::enable_if_t<
163                               std::is_convertible<OpT, Operation *>::value &&
164                               !std::is_same<OpT, Operation *>::value>>
ProgramPointProgramPoint165   ProgramPoint(OpT op) : ParentTy(op) {}
166 
167   /// Print the program point.
168   void print(raw_ostream &os) const;
169 
170   /// Get the source location of the program point.
171   Location getLoc() const;
172 };
173 
/// Forward declaration of the data-flow analysis class. `DataFlowSolver`
/// refers to it by pointer (e.g. in `WorkItem`) and owns instances through
/// `std::unique_ptr` before the class is defined below.
class DataFlowAnalysis;
176 
177 //===----------------------------------------------------------------------===//
178 // DataFlowSolver
179 //===----------------------------------------------------------------------===//
180 
181 /// The general data-flow analysis solver. This class is responsible for
182 /// orchestrating child data-flow analyses, running the fixed-point iteration
183 /// algorithm, managing analysis state and program point memory, and tracking
184 /// dependencies beteen analyses, program points, and analysis states.
185 ///
186 /// Steps to run a data-flow analysis:
187 ///
188 /// 1. Load and initialize children analyses. Children analyses are instantiated
189 ///    in the solver and initialized, building their dependency relations.
190 /// 2. Configure and run the analysis. The solver invokes the children analyses
191 ///    according to their dependency relations until a fixed point is reached.
192 /// 3. Query analysis state results from the solver.
193 ///
194 /// TODO: Optimize the internal implementation of the solver.
195 class DataFlowSolver {
196 public:
197   /// Load an analysis into the solver. Return the analysis instance.
198   template <typename AnalysisT, typename... Args>
199   AnalysisT *load(Args &&...args);
200 
201   /// Initialize the children analyses starting from the provided top-level
202   /// operation and run the analysis until fixpoint.
203   LogicalResult initializeAndRun(Operation *top);
204 
205   /// Lookup an analysis state for the given program point. Returns null if one
206   /// does not exist.
207   template <typename StateT, typename PointT>
lookupState(PointT point)208   const StateT *lookupState(PointT point) const {
209     auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
210     if (it == analysisStates.end())
211       return nullptr;
212     return static_cast<const StateT *>(it->second.get());
213   }
214 
215   /// Get a uniqued program point instance. If one is not present, it is
216   /// created with the provided arguments.
217   template <typename PointT, typename... Args>
getProgramPoint(Args &&...args)218   PointT *getProgramPoint(Args &&...args) {
219     return PointT::get(uniquer, std::forward<Args>(args)...);
220   }
221 
222   /// A work item on the solver queue is a program point, child analysis pair.
223   /// Each item is processed by invoking the child analysis at the program
224   /// point.
225   using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>;
226   /// Push a work item onto the worklist.
enqueue(WorkItem item)227   void enqueue(WorkItem item) { worklist.push(std::move(item)); }
228 
229   /// Get the state associated with the given program point. If it does not
230   /// exist, create an uninitialized state.
231   template <typename StateT, typename PointT>
232   StateT *getOrCreateState(PointT point);
233 
234   /// Propagate an update to an analysis state if it changed by pushing
235   /// dependent work items to the back of the queue.
236   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
237 
238   /// Add a dependency to an analysis state on a child analysis and program
239   /// point. If the state is updated, the child analysis must be invoked on the
240   /// given program point again.
241   void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
242                      ProgramPoint point);
243 
244 private:
245   /// The solver's work queue. Work items can be inserted to the front of the
246   /// queue to be processed greedily, speeding up computations that otherwise
247   /// quickly degenerate to quadratic due to propagation of state updates.
248   std::queue<WorkItem> worklist;
249 
250   /// Type-erased instances of the children analyses.
251   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
252 
253   /// The storage uniquer instance that owns the memory of the allocated program
254   /// points.
255   StorageUniquer uniquer;
256 
257   /// A type-erased map of program points to associated analysis states for
258   /// first-class program points.
259   DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
260       analysisStates;
261 
262   /// Allow the base child analysis class to access the internals of the solver.
263   friend class DataFlowAnalysis;
264 };
265 
266 //===----------------------------------------------------------------------===//
267 // AnalysisState
268 //===----------------------------------------------------------------------===//
269 
270 /// Base class for generic analysis states. Analysis states contain data-flow
271 /// information that are attached to program points and which evolve as the
272 /// analysis iterates.
273 ///
274 /// This class places no restrictions on the semantics of analysis states beyond
275 /// these requirements.
276 ///
277 /// 1. Querying the state of a program point prior to visiting that point
278 ///    results in uninitialized state. Analyses must be aware of unintialized
279 ///    states.
280 /// 2. Analysis states can reach fixpoints, where subsequent updates will never
281 ///    trigger a change in the state.
282 /// 3. Analysis states that are uninitialized can be forcefully initialized to a
283 ///    default value.
284 class AnalysisState {
285 public:
286   virtual ~AnalysisState();
287 
288   /// Create the analysis state at the given program point.
AnalysisState(ProgramPoint point)289   AnalysisState(ProgramPoint point) : point(point) {}
290 
291   /// Returns true if the analysis state is uninitialized.
292   virtual bool isUninitialized() const = 0;
293 
294   /// Force an uninitialized analysis state to initialize itself with a default
295   /// value.
296   virtual ChangeResult defaultInitialize() = 0;
297 
298   /// Print the contents of the analysis state.
299   virtual void print(raw_ostream &os) const = 0;
300 
301 protected:
302   /// This function is called by the solver when the analysis state is updated
303   /// to optionally enqueue more work items. For example, if a state tracks
304   /// dependents through the IR (e.g. use-def chains), this function can be
305   /// implemented to push those dependents on the worklist.
onUpdate(DataFlowSolver * solver)306   virtual void onUpdate(DataFlowSolver *solver) const {}
307 
308   /// The dependency relations originating from this analysis state. An entry
309   /// `state -> (analysis, point)` is created when `analysis` queries `state`
310   /// when updating `point`.
311   ///
312   /// When this state is updated, all dependent child analysis invocations are
313   /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
314   /// deterministic.
315   ///
316   /// Store the dependents on the analysis state for efficiency.
317   SetVector<DataFlowSolver::WorkItem> dependents;
318 
319   /// The program point to which the state belongs.
320   ProgramPoint point;
321 
322 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
323   /// When compiling with debugging, keep a name for the analysis state.
324   StringRef debugName;
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
326 
327   /// Allow the framework to access the dependents.
328   friend class DataFlowSolver;
329 };
330 
331 //===----------------------------------------------------------------------===//
332 // DataFlowAnalysis
333 //===----------------------------------------------------------------------===//
334 
335 /// Base class for all data-flow analyses. A child analysis is expected to build
336 /// an initial dependency graph (and optionally provide an initial state) when
337 /// initialized and define transfer functions when visiting program points.
338 ///
339 /// In classical data-flow analysis, the dependency graph is fixed and analyses
340 /// define explicit transfer functions between input states and output states.
341 /// In this framework, however, the dependency graph can change during the
342 /// analysis, and transfer functions are opaque such that the solver doesn't
343 /// know what states calling `visit` on an analysis will be updated. This allows
344 /// multiple analyses to plug in and provide values for the same state.
345 ///
346 /// Generally, when an analysis queries an uninitialized state, it is expected
347 /// to "bail out", i.e., not provide any updates. When the value is initialized,
348 /// the solver will re-invoke the analysis. If the solver exhausts its worklist,
349 /// however, and there are still uninitialized states, the solver "nudges" the
350 /// analyses by default-initializing those states.
351 class DataFlowAnalysis {
352 public:
353   virtual ~DataFlowAnalysis();
354 
355   /// Create an analysis with a reference to the parent solver.
356   explicit DataFlowAnalysis(DataFlowSolver &solver);
357 
358   /// Initialize the analysis from the provided top-level operation by building
359   /// an initial dependency graph between all program points of interest. This
360   /// can be implemented by calling `visit` on all program points of interest
361   /// below the top-level operation.
362   ///
363   /// An analysis can optionally provide initial values to certain analysis
364   /// states to influence the evolution of the analysis.
365   virtual LogicalResult initialize(Operation *top) = 0;
366 
367   /// Visit the given program point. This function is invoked by the solver on
368   /// this analysis with a given program point when a dependent analysis state
369   /// is updated. The function is similar to a transfer function; it queries
370   /// certain analysis states and sets other states.
371   ///
372   /// The function is expected to create dependencies on queried states and
373   /// propagate updates on changed states. A dependency can be created by
374   /// calling `addDependency` between the input state and a program point,
375   /// indicating that, if the state is updated, the solver should invoke `solve`
376   /// on the program point. The dependent point does not have to be the same as
377   /// the provided point. An update to a state is propagated by calling
378   /// `propagateIfChange` on the state. If the state has changed, then all its
379   /// dependents are placed on the worklist.
380   ///
381   /// The dependency graph does not need to be static. Each invocation of
382   /// `visit` can add new dependencies, but these dependecies will not be
383   /// dynamically added to the worklist because the solver doesn't know what
384   /// will provide a value for then.
385   virtual LogicalResult visit(ProgramPoint point) = 0;
386 
387 protected:
388   /// Create a dependency between the given analysis state and program point
389   /// on this analysis.
390   void addDependency(AnalysisState *state, ProgramPoint point);
391 
392   /// Propagate an update to a state if it changed.
393   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
394 
395   /// Register a custom program point class.
396   template <typename PointT>
registerPointKind()397   void registerPointKind() {
398     solver.uniquer.registerParametricStorageType<PointT>();
399   }
400 
401   /// Get or create a custom program point.
402   template <typename PointT, typename... Args>
getProgramPoint(Args &&...args)403   PointT *getProgramPoint(Args &&...args) {
404     return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
405   }
406 
407   /// Get the analysis state assiocated with the program point. The returned
408   /// state is expected to be "write-only", and any updates need to be
409   /// propagated by `propagateIfChanged`.
410   template <typename StateT, typename PointT>
getOrCreate(PointT point)411   StateT *getOrCreate(PointT point) {
412     return solver.getOrCreateState<StateT>(point);
413   }
414 
415   /// Get a read-only analysis state for the given point and create a dependency
416   /// on `dependent`. If the return state is updated elsewhere, this analysis is
417   /// re-invoked on the dependent.
418   template <typename StateT, typename PointT>
getOrCreateFor(ProgramPoint dependent,PointT point)419   const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
420     StateT *state = getOrCreate<StateT>(point);
421     addDependency(state, dependent);
422     return state;
423   }
424 
425 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
426   /// When compiling with debugging, keep a name for the analyis.
427   StringRef debugName;
428 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
429 
430 private:
431   /// The parent data-flow solver.
432   DataFlowSolver &solver;
433 
434   /// Allow the data-flow solver to access the internals of this class.
435   friend class DataFlowSolver;
436 };
437 
438 template <typename AnalysisT, typename... Args>
load(Args &&...args)439 AnalysisT *DataFlowSolver::load(Args &&...args) {
440   childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
441 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
442   childAnalyses.back().get()->debugName = llvm::getTypeName<AnalysisT>();
443 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
444   return static_cast<AnalysisT *>(childAnalyses.back().get());
445 }
446 
447 template <typename StateT, typename PointT>
getOrCreateState(PointT point)448 StateT *DataFlowSolver::getOrCreateState(PointT point) {
449   std::unique_ptr<AnalysisState> &state =
450       analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
451   if (!state) {
452     state = std::unique_ptr<StateT>(new StateT(point));
453 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
454     state->debugName = llvm::getTypeName<StateT>();
455 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
456   }
457   return static_cast<StateT *>(state.get());
458 }
459 
460 inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
461   state.print(os);
462   return os;
463 }
464 
465 inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
466   point.print(os);
467   return os;
468 }
469 
470 } // end namespace mlir
471 
472 namespace llvm {
473 /// Allow hashing of program points.
474 template <>
475 struct DenseMapInfo<mlir::ProgramPoint>
476     : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
477 } // end namespace llvm
478 
479 #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
480