1 //===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements sparse data-flow analysis using the data-flow analysis
10 // framework. The analysis is forward and conditional and uses the results of
11 // dead code analysis to prune dead code during the analysis.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
16 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
17 
18 #include "mlir/Analysis/DataFlowFramework.h"
19 #include "mlir/Interfaces/ControlFlowInterfaces.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 
22 namespace mlir {
23 namespace dataflow {
24 
25 //===----------------------------------------------------------------------===//
26 // AbstractSparseLattice
27 //===----------------------------------------------------------------------===//
28 
29 /// This class represents an abstract lattice. A lattice contains information
30 /// about an SSA value and is what's propagated across the IR by sparse
31 /// data-flow analysis.
32 class AbstractSparseLattice : public AnalysisState {
33 public:
34   /// Lattices can only be created for values.
AbstractSparseLattice(Value value)35   AbstractSparseLattice(Value value) : AnalysisState(value) {}
36 
37   /// Join the information contained in 'rhs' into this lattice. Returns
38   /// if the value of the lattice changed.
39   virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
40 
41   /// Returns true if the lattice element is at fixpoint and further calls to
42   /// `join` will not update the value of the element.
43   virtual bool isAtFixpoint() const = 0;
44 
45   /// Mark the lattice element as having reached a pessimistic fixpoint. This
46   /// means that the lattice may potentially have conflicting value states, and
47   /// only the most conservative value should be relied on.
48   virtual ChangeResult markPessimisticFixpoint() = 0;
49 
50   /// When the lattice gets updated, propagate an update to users of the value
51   /// using its use-def chain to subscribed analyses.
52   void onUpdate(DataFlowSolver *solver) const override;
53 
54   /// Subscribe an analysis to updates of the lattice. When the lattice changes,
55   /// subscribed analyses are re-invoked on all users of the value. This is
56   /// more efficient than relying on the dependency map.
useDefSubscribe(DataFlowAnalysis * analysis)57   void useDefSubscribe(DataFlowAnalysis *analysis) {
58     useDefSubscribers.insert(analysis);
59   }
60 
61 private:
62   /// A set of analyses that should be updated when this lattice changes.
63   SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
64             SmallPtrSet<DataFlowAnalysis *, 4>>
65       useDefSubscribers;
66 };
67 
68 //===----------------------------------------------------------------------===//
69 // Lattice
70 //===----------------------------------------------------------------------===//
71 
72 /// This class represents a lattice holding a specific value of type `ValueT`.
73 /// Lattice values (`ValueT`) are required to adhere to the following:
74 ///
75 ///   * static ValueT join(const ValueT &lhs, const ValueT &rhs);
76 ///     - This method conservatively joins the information held by `lhs`
77 ///       and `rhs` into a new value. This method is required to be monotonic.
78 ///   * bool operator==(const ValueT &rhs) const;
79 ///
80 template <typename ValueT>
81 class Lattice : public AbstractSparseLattice {
82 public:
83   /// Construct a lattice with a known value.
Lattice(Value value)84   explicit Lattice(Value value)
85       : AbstractSparseLattice(value),
86         knownValue(ValueT::getPessimisticValueState(value)) {}
87 
88   /// Return the value held by this lattice. This requires that the value is
89   /// initialized.
getValue()90   ValueT &getValue() {
91     assert(!isUninitialized() && "expected known lattice element");
92     return *optimisticValue;
93   }
getValue()94   const ValueT &getValue() const {
95     return const_cast<Lattice<ValueT> *>(this)->getValue();
96   }
97 
98   /// Returns true if the value of this lattice hasn't yet been initialized.
isUninitialized()99   bool isUninitialized() const override { return !optimisticValue.has_value(); }
100   /// Force the initialization of the element by setting it to its pessimistic
101   /// fixpoint.
defaultInitialize()102   ChangeResult defaultInitialize() override {
103     return markPessimisticFixpoint();
104   }
105 
106   /// Returns true if the lattice has reached a fixpoint. A fixpoint is when
107   /// the information optimistically assumed to be true is the same as the
108   /// information known to be true.
isAtFixpoint()109   bool isAtFixpoint() const override { return optimisticValue == knownValue; }
110 
111   /// Join the information contained in the 'rhs' lattice into this
112   /// lattice. Returns if the state of the current lattice changed.
join(const AbstractSparseLattice & rhs)113   ChangeResult join(const AbstractSparseLattice &rhs) override {
114     const Lattice<ValueT> &rhsLattice =
115         static_cast<const Lattice<ValueT> &>(rhs);
116 
117     // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
118     if (isAtFixpoint() || rhsLattice.isUninitialized())
119       return ChangeResult::NoChange;
120 
121     // Join the rhs value into this lattice.
122     return join(rhsLattice.getValue());
123   }
124 
125   /// Join the information contained in the 'rhs' value into this
126   /// lattice. Returns if the state of the current lattice changed.
join(const ValueT & rhs)127   ChangeResult join(const ValueT &rhs) {
128     // If the current lattice is uninitialized, copy the rhs value.
129     if (isUninitialized()) {
130       optimisticValue = rhs;
131       return ChangeResult::Change;
132     }
133 
134     // Otherwise, join rhs with the current optimistic value.
135     ValueT newValue = ValueT::join(*optimisticValue, rhs);
136     assert(ValueT::join(newValue, *optimisticValue) == newValue &&
137            "expected `join` to be monotonic");
138     assert(ValueT::join(newValue, rhs) == newValue &&
139            "expected `join` to be monotonic");
140 
141     // Update the current optimistic value if something changed.
142     if (newValue == optimisticValue)
143       return ChangeResult::NoChange;
144 
145     optimisticValue = newValue;
146     return ChangeResult::Change;
147   }
148 
149   /// Mark the lattice element as having reached a pessimistic fixpoint. This
150   /// means that the lattice may potentially have conflicting value states,
151   /// and only the conservatively known value state should be relied on.
markPessimisticFixpoint()152   ChangeResult markPessimisticFixpoint() override {
153     if (isAtFixpoint())
154       return ChangeResult::NoChange;
155 
156     // For this fixed point, we take whatever we knew to be true and set that
157     // to our optimistic value.
158     optimisticValue = knownValue;
159     return ChangeResult::Change;
160   }
161 
162   /// Print the lattice element.
print(raw_ostream & os)163   void print(raw_ostream &os) const override {
164     os << "[";
165     knownValue.print(os);
166     os << ", ";
167     if (optimisticValue)
168       optimisticValue->print(os);
169     else
170       os << "<NULL>";
171     os << "]";
172   }
173 
174 private:
175   /// The value that is conservatively known to be true.
176   ValueT knownValue;
177   /// The currently computed value that is optimistically assumed to be true,
178   /// or None if the lattice element is uninitialized.
179   Optional<ValueT> optimisticValue;
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // AbstractSparseDataFlowAnalysis
184 //===----------------------------------------------------------------------===//
185 
186 /// Base class for sparse (forward) data-flow analyses. A sparse analysis
187 /// implements a transfer function on operations from the lattices of the
188 /// operands to the lattices of the results. This analysis will propagate
189 /// lattices across control-flow edges and the callgraph using liveness
190 /// information.
191 class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
192 public:
193   /// Initialize the analysis by visiting every owner of an SSA value: all
194   /// operations and blocks.
195   LogicalResult initialize(Operation *top) override;
196 
197   /// Visit a program point. If this is a block and all control-flow
198   /// predecessors or callsites are known, then the arguments lattices are
199   /// propagated from them. If this is a call operation or an operation with
200   /// region control-flow, then its result lattices are set accordingly.
201   /// Otherwise, the operation transfer function is invoked.
202   LogicalResult visit(ProgramPoint point) override;
203 
204 protected:
205   explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver);
206 
207   /// The operation transfer function. Given the operand lattices, this
208   /// function is expected to set the result lattices.
209   virtual void
210   visitOperationImpl(Operation *op,
211                      ArrayRef<const AbstractSparseLattice *> operandLattices,
212                      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
213 
214   /// Given an operation with region control-flow, the lattices of the operands,
215   /// and a region successor, compute the lattice values for block arguments
216   /// that are not accounted for by the branching control flow (ex. the bounds
217   /// of loops).
218   virtual void visitNonControlFlowArgumentsImpl(
219       Operation *op, const RegionSuccessor &successor,
220       ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
221 
222   /// Get the lattice element of a value.
223   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
224 
225   /// Get a read-only lattice element for a value and add it as a dependency to
226   /// a program point.
227   const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
228                                                     Value value);
229 
230   /// Mark the given lattice elements as having reached their pessimistic
231   /// fixpoints and propagate an update if any changed.
232   void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
233 
234   /// Join the lattice element and propagate and update if it changed.
235   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
236 
237 private:
238   /// Recursively initialize the analysis on nested operations and blocks.
239   LogicalResult initializeRecursively(Operation *op);
240 
241   /// Visit an operation. If this is a call operation or an operation with
242   /// region control-flow, then its result lattices are set accordingly.
243   /// Otherwise, the operation transfer function is invoked.
244   void visitOperation(Operation *op);
245 
246   /// Visit a block to compute the lattice values of its arguments. If this is
247   /// an entry block, then the argument values are determined from the block's
248   /// "predecessors" as set by `PredecessorState`. The predecessors can be
249   /// region terminators or callable callsites. Otherwise, the values are
250   /// determined from block predecessors.
251   void visitBlock(Block *block);
252 
253   /// Visit a program point `point` with predecessors within a region branch
254   /// operation `branch`, which can either be the entry block of one of the
255   /// regions or the parent operation itself, and set either the argument or
256   /// parent result lattices.
257   void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
258                              Optional<unsigned> successorIndex,
259                              ArrayRef<AbstractSparseLattice *> lattices);
260 };
261 
262 //===----------------------------------------------------------------------===//
263 // SparseDataFlowAnalysis
264 //===----------------------------------------------------------------------===//
265 
266 /// A sparse (forward) data-flow analysis for propagating SSA value lattices
267 /// across the IR by implementing transfer functions for operations.
268 ///
269 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
270 template <typename StateT>
271 class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
272   static_assert(
273       std::is_base_of<AbstractSparseLattice, StateT>::value,
274       "analysis state class expected to subclass AbstractSparseLattice");
275 
276 public:
SparseDataFlowAnalysis(DataFlowSolver & solver)277   explicit SparseDataFlowAnalysis(DataFlowSolver &solver)
278       : AbstractSparseDataFlowAnalysis(solver) {}
279 
280   /// Visit an operation with the lattices of its operands. This function is
281   /// expected to set the lattices of the operation's results.
282   virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
283                               ArrayRef<StateT *> results) = 0;
284 
285   /// Given an operation with possible region control-flow, the lattices of the
286   /// operands, and a region successor, compute the lattice values for block
287   /// arguments that are not accounted for by the branching control flow (ex.
288   /// the bounds of loops). By default, this method marks all such lattice
289   /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
290   /// index of the first element of `argLattices` that is set by control-flow.
visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<StateT * > argLattices,unsigned firstIndex)291   virtual void visitNonControlFlowArguments(Operation *op,
292                                             const RegionSuccessor &successor,
293                                             ArrayRef<StateT *> argLattices,
294                                             unsigned firstIndex) {
295     markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
296     markAllPessimisticFixpoint(argLattices.drop_front(
297         firstIndex + successor.getSuccessorInputs().size()));
298   }
299 
300 protected:
301   /// Get the lattice element for a value.
getLatticeElement(Value value)302   StateT *getLatticeElement(Value value) override {
303     return getOrCreate<StateT>(value);
304   }
305 
306   /// Get the lattice element for a value and create a dependency on the
307   /// provided program point.
getLatticeElementFor(ProgramPoint point,Value value)308   const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
309     return static_cast<const StateT *>(
310         AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value));
311   }
312 
313   /// Mark the lattice elements of a range of values as having reached their
314   /// pessimistic fixpoint.
markAllPessimisticFixpoint(ArrayRef<StateT * > lattices)315   void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
316     AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
317         {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
318          lattices.size()});
319   }
320 
321 private:
322   /// Type-erased wrappers that convert the abstract lattice operands to derived
323   /// lattices and invoke the virtual hooks operating on the derived lattices.
visitOperationImpl(Operation * op,ArrayRef<const AbstractSparseLattice * > operandLattices,ArrayRef<AbstractSparseLattice * > resultLattices)324   void visitOperationImpl(
325       Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
326       ArrayRef<AbstractSparseLattice *> resultLattices) override {
327     visitOperation(
328         op,
329         {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
330          operandLattices.size()},
331         {reinterpret_cast<StateT *const *>(resultLattices.begin()),
332          resultLattices.size()});
333   }
visitNonControlFlowArgumentsImpl(Operation * op,const RegionSuccessor & successor,ArrayRef<AbstractSparseLattice * > argLattices,unsigned firstIndex)334   void visitNonControlFlowArgumentsImpl(
335       Operation *op, const RegionSuccessor &successor,
336       ArrayRef<AbstractSparseLattice *> argLattices,
337       unsigned firstIndex) override {
338     visitNonControlFlowArguments(
339         op, successor,
340         {reinterpret_cast<StateT *const *>(argLattices.begin()),
341          argLattices.size()},
342         firstIndex);
343   }
344 };
345 
346 } // end namespace dataflow
347 } // end namespace mlir
348 
349 #endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
350