1 //===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements sparse data-flow analysis using the data-flow analysis 10 // framework. The analysis is forward and conditional and uses the results of 11 // dead code analysis to prune dead code during the analysis. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 16 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 17 18 #include "mlir/Analysis/DataFlowFramework.h" 19 #include "mlir/Interfaces/ControlFlowInterfaces.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 22 namespace mlir { 23 namespace dataflow { 24 25 //===----------------------------------------------------------------------===// 26 // AbstractSparseLattice 27 //===----------------------------------------------------------------------===// 28 29 /// This class represents an abstract lattice. A lattice contains information 30 /// about an SSA value and is what's propagated across the IR by sparse 31 /// data-flow analysis. 32 class AbstractSparseLattice : public AnalysisState { 33 public: 34 /// Lattices can only be created for values. AbstractSparseLattice(Value value)35 AbstractSparseLattice(Value value) : AnalysisState(value) {} 36 37 /// Join the information contained in 'rhs' into this lattice. Returns 38 /// if the value of the lattice changed. 39 virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0; 40 41 /// Returns true if the lattice element is at fixpoint and further calls to 42 /// `join` will not update the value of the element. 43 virtual bool isAtFixpoint() const = 0; 44 45 /// Mark the lattice element as having reached a pessimistic fixpoint. This 46 /// means that the lattice may potentially have conflicting value states, and 47 /// only the most conservative value should be relied on. 48 virtual ChangeResult markPessimisticFixpoint() = 0; 49 50 /// When the lattice gets updated, propagate an update to users of the value 51 /// using its use-def chain to subscribed analyses. 52 void onUpdate(DataFlowSolver *solver) const override; 53 54 /// Subscribe an analysis to updates of the lattice. When the lattice changes, 55 /// subscribed analyses are re-invoked on all users of the value. This is 56 /// more efficient than relying on the dependency map. useDefSubscribe(DataFlowAnalysis * analysis)57 void useDefSubscribe(DataFlowAnalysis *analysis) { 58 useDefSubscribers.insert(analysis); 59 } 60 61 private: 62 /// A set of analyses that should be updated when this lattice changes. 63 SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>, 64 SmallPtrSet<DataFlowAnalysis *, 4>> 65 useDefSubscribers; 66 }; 67 68 //===----------------------------------------------------------------------===// 69 // Lattice 70 //===----------------------------------------------------------------------===// 71 72 /// This class represents a lattice holding a specific value of type `ValueT`. 73 /// Lattice values (`ValueT`) are required to adhere to the following: 74 /// 75 /// * static ValueT join(const ValueT &lhs, const ValueT &rhs); 76 /// - This method conservatively joins the information held by `lhs` 77 /// and `rhs` into a new value. This method is required to be monotonic. 78 /// * bool operator==(const ValueT &rhs) const; 79 /// 80 template <typename ValueT> 81 class Lattice : public AbstractSparseLattice { 82 public: 83 /// Construct a lattice with a known value. Lattice(Value value)84 explicit Lattice(Value value) 85 : AbstractSparseLattice(value), 86 knownValue(ValueT::getPessimisticValueState(value)) {} 87 88 /// Return the value held by this lattice. This requires that the value is 89 /// initialized. getValue()90 ValueT &getValue() { 91 assert(!isUninitialized() && "expected known lattice element"); 92 return *optimisticValue; 93 } getValue()94 const ValueT &getValue() const { 95 return const_cast<Lattice<ValueT> *>(this)->getValue(); 96 } 97 98 /// Returns true if the value of this lattice hasn't yet been initialized. isUninitialized()99 bool isUninitialized() const override { return !optimisticValue.has_value(); } 100 /// Force the initialization of the element by setting it to its pessimistic 101 /// fixpoint. defaultInitialize()102 ChangeResult defaultInitialize() override { 103 return markPessimisticFixpoint(); 104 } 105 106 /// Returns true if the lattice has reached a fixpoint. A fixpoint is when 107 /// the information optimistically assumed to be true is the same as the 108 /// information known to be true. isAtFixpoint()109 bool isAtFixpoint() const override { return optimisticValue == knownValue; } 110 111 /// Join the information contained in the 'rhs' lattice into this 112 /// lattice. Returns if the state of the current lattice changed. join(const AbstractSparseLattice & rhs)113 ChangeResult join(const AbstractSparseLattice &rhs) override { 114 const Lattice<ValueT> &rhsLattice = 115 static_cast<const Lattice<ValueT> &>(rhs); 116 117 // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do. 118 if (isAtFixpoint() || rhsLattice.isUninitialized()) 119 return ChangeResult::NoChange; 120 121 // Join the rhs value into this lattice. 122 return join(rhsLattice.getValue()); 123 } 124 125 /// Join the information contained in the 'rhs' value into this 126 /// lattice. Returns if the state of the current lattice changed. join(const ValueT & rhs)127 ChangeResult join(const ValueT &rhs) { 128 // If the current lattice is uninitialized, copy the rhs value. 129 if (isUninitialized()) { 130 optimisticValue = rhs; 131 return ChangeResult::Change; 132 } 133 134 // Otherwise, join rhs with the current optimistic value. 135 ValueT newValue = ValueT::join(*optimisticValue, rhs); 136 assert(ValueT::join(newValue, *optimisticValue) == newValue && 137 "expected `join` to be monotonic"); 138 assert(ValueT::join(newValue, rhs) == newValue && 139 "expected `join` to be monotonic"); 140 141 // Update the current optimistic value if something changed. 142 if (newValue == optimisticValue) 143 return ChangeResult::NoChange; 144 145 optimisticValue = newValue; 146 return ChangeResult::Change; 147 } 148 149 /// Mark the lattice element as having reached a pessimistic fixpoint. This 150 /// means that the lattice may potentially have conflicting value states, 151 /// and only the conservatively known value state should be relied on. markPessimisticFixpoint()152 ChangeResult markPessimisticFixpoint() override { 153 if (isAtFixpoint()) 154 return ChangeResult::NoChange; 155 156 // For this fixed point, we take whatever we knew to be true and set that 157 // to our optimistic value. 158 optimisticValue = knownValue; 159 return ChangeResult::Change; 160 } 161 162 /// Print the lattice element. print(raw_ostream & os)163 void print(raw_ostream &os) const override { 164 os << "["; 165 knownValue.print(os); 166 os << ", "; 167 if (optimisticValue) 168 optimisticValue->print(os); 169 else 170 os << "<NULL>"; 171 os << "]"; 172 } 173 174 private: 175 /// The value that is conservatively known to be true. 176 ValueT knownValue; 177 /// The currently computed value that is optimistically assumed to be true, 178 /// or None if the lattice element is uninitialized. 179 Optional<ValueT> optimisticValue; 180 }; 181 182 //===----------------------------------------------------------------------===// 183 // AbstractSparseDataFlowAnalysis 184 //===----------------------------------------------------------------------===// 185 186 /// Base class for sparse (forward) data-flow analyses. A sparse analysis 187 /// implements a transfer function on operations from the lattices of the 188 /// operands to the lattices of the results. This analysis will propagate 189 /// lattices across control-flow edges and the callgraph using liveness 190 /// information. 191 class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis { 192 public: 193 /// Initialize the analysis by visiting every owner of an SSA value: all 194 /// operations and blocks. 195 LogicalResult initialize(Operation *top) override; 196 197 /// Visit a program point. If this is a block and all control-flow 198 /// predecessors or callsites are known, then the arguments lattices are 199 /// propagated from them. If this is a call operation or an operation with 200 /// region control-flow, then its result lattices are set accordingly. 201 /// Otherwise, the operation transfer function is invoked. 202 LogicalResult visit(ProgramPoint point) override; 203 204 protected: 205 explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver); 206 207 /// The operation transfer function. Given the operand lattices, this 208 /// function is expected to set the result lattices. 209 virtual void 210 visitOperationImpl(Operation *op, 211 ArrayRef<const AbstractSparseLattice *> operandLattices, 212 ArrayRef<AbstractSparseLattice *> resultLattices) = 0; 213 214 /// Given an operation with region control-flow, the lattices of the operands, 215 /// and a region successor, compute the lattice values for block arguments 216 /// that are not accounted for by the branching control flow (ex. the bounds 217 /// of loops). 218 virtual void visitNonControlFlowArgumentsImpl( 219 Operation *op, const RegionSuccessor &successor, 220 ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0; 221 222 /// Get the lattice element of a value. 223 virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; 224 225 /// Get a read-only lattice element for a value and add it as a dependency to 226 /// a program point. 227 const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, 228 Value value); 229 230 /// Mark the given lattice elements as having reached their pessimistic 231 /// fixpoints and propagate an update if any changed. 232 void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices); 233 234 /// Join the lattice element and propagate and update if it changed. 235 void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); 236 237 private: 238 /// Recursively initialize the analysis on nested operations and blocks. 239 LogicalResult initializeRecursively(Operation *op); 240 241 /// Visit an operation. If this is a call operation or an operation with 242 /// region control-flow, then its result lattices are set accordingly. 243 /// Otherwise, the operation transfer function is invoked. 244 void visitOperation(Operation *op); 245 246 /// Visit a block to compute the lattice values of its arguments. If this is 247 /// an entry block, then the argument values are determined from the block's 248 /// "predecessors" as set by `PredecessorState`. The predecessors can be 249 /// region terminators or callable callsites. Otherwise, the values are 250 /// determined from block predecessors. 251 void visitBlock(Block *block); 252 253 /// Visit a program point `point` with predecessors within a region branch 254 /// operation `branch`, which can either be the entry block of one of the 255 /// regions or the parent operation itself, and set either the argument or 256 /// parent result lattices. 257 void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, 258 Optional<unsigned> successorIndex, 259 ArrayRef<AbstractSparseLattice *> lattices); 260 }; 261 262 //===----------------------------------------------------------------------===// 263 // SparseDataFlowAnalysis 264 //===----------------------------------------------------------------------===// 265 266 /// A sparse (forward) data-flow analysis for propagating SSA value lattices 267 /// across the IR by implementing transfer functions for operations. 268 /// 269 /// `StateT` is expected to be a subclass of `AbstractSparseLattice`. 270 template <typename StateT> 271 class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis { 272 static_assert( 273 std::is_base_of<AbstractSparseLattice, StateT>::value, 274 "analysis state class expected to subclass AbstractSparseLattice"); 275 276 public: SparseDataFlowAnalysis(DataFlowSolver & solver)277 explicit SparseDataFlowAnalysis(DataFlowSolver &solver) 278 : AbstractSparseDataFlowAnalysis(solver) {} 279 280 /// Visit an operation with the lattices of its operands. This function is 281 /// expected to set the lattices of the operation's results. 282 virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands, 283 ArrayRef<StateT *> results) = 0; 284 285 /// Given an operation with possible region control-flow, the lattices of the 286 /// operands, and a region successor, compute the lattice values for block 287 /// arguments that are not accounted for by the branching control flow (ex. 288 /// the bounds of loops). By default, this method marks all such lattice 289 /// elements as having reached a pessimistic fixpoint. `firstIndex` is the 290 /// index of the first element of `argLattices` that is set by control-flow. visitNonControlFlowArguments(Operation * op,const RegionSuccessor & successor,ArrayRef<StateT * > argLattices,unsigned firstIndex)291 virtual void visitNonControlFlowArguments(Operation *op, 292 const RegionSuccessor &successor, 293 ArrayRef<StateT *> argLattices, 294 unsigned firstIndex) { 295 markAllPessimisticFixpoint(argLattices.take_front(firstIndex)); 296 markAllPessimisticFixpoint(argLattices.drop_front( 297 firstIndex + successor.getSuccessorInputs().size())); 298 } 299 300 protected: 301 /// Get the lattice element for a value. getLatticeElement(Value value)302 StateT *getLatticeElement(Value value) override { 303 return getOrCreate<StateT>(value); 304 } 305 306 /// Get the lattice element for a value and create a dependency on the 307 /// provided program point. getLatticeElementFor(ProgramPoint point,Value value)308 const StateT *getLatticeElementFor(ProgramPoint point, Value value) { 309 return static_cast<const StateT *>( 310 AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value)); 311 } 312 313 /// Mark the lattice elements of a range of values as having reached their 314 /// pessimistic fixpoint. markAllPessimisticFixpoint(ArrayRef<StateT * > lattices)315 void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) { 316 AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( 317 {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()), 318 lattices.size()}); 319 } 320 321 private: 322 /// Type-erased wrappers that convert the abstract lattice operands to derived 323 /// lattices and invoke the virtual hooks operating on the derived lattices. visitOperationImpl(Operation * op,ArrayRef<const AbstractSparseLattice * > operandLattices,ArrayRef<AbstractSparseLattice * > resultLattices)324 void visitOperationImpl( 325 Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices, 326 ArrayRef<AbstractSparseLattice *> resultLattices) override { 327 visitOperation( 328 op, 329 {reinterpret_cast<const StateT *const *>(operandLattices.begin()), 330 operandLattices.size()}, 331 {reinterpret_cast<StateT *const *>(resultLattices.begin()), 332 resultLattices.size()}); 333 } visitNonControlFlowArgumentsImpl(Operation * op,const RegionSuccessor & successor,ArrayRef<AbstractSparseLattice * > argLattices,unsigned firstIndex)334 void visitNonControlFlowArgumentsImpl( 335 Operation *op, const RegionSuccessor &successor, 336 ArrayRef<AbstractSparseLattice *> argLattices, 337 unsigned firstIndex) override { 338 visitNonControlFlowArguments( 339 op, successor, 340 {reinterpret_cast<StateT *const *>(argLattices.begin()), 341 argLattices.size()}, 342 firstIndex); 343 } 344 }; 345 346 } // end namespace dataflow 347 } // end namespace mlir 348 349 #endif // MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H 350