1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 #include "mlir/Interfaces/ControlFlowInterfaces.h" 13 14 using namespace mlir; 15 using namespace mlir::dataflow; 16 17 //===----------------------------------------------------------------------===// 18 // AbstractSparseLattice 19 //===----------------------------------------------------------------------===// 20 21 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { 22 // Push all users of the value to the queue. 23 for (Operation *user : point.get<Value>().getUsers()) 24 for (DataFlowAnalysis *analysis : useDefSubscribers) 25 solver->enqueue({user, analysis}); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // AbstractSparseDataFlowAnalysis 30 //===----------------------------------------------------------------------===// 31 32 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis( 33 DataFlowSolver &solver) 34 : DataFlowAnalysis(solver) { 35 registerPointKind<CFGEdge>(); 36 } 37 38 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) { 39 // Mark the entry block arguments as having reached their pessimistic 40 // fixpoints. 41 for (Region ®ion : top->getRegions()) { 42 if (region.empty()) 43 continue; 44 for (Value argument : region.front().getArguments()) 45 markAllPessimisticFixpoint(getLatticeElement(argument)); 46 } 47 48 return initializeRecursively(top); 49 } 50 51 LogicalResult 52 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) { 53 // Initialize the analysis by visiting every owner of an SSA value (all 54 // operations and blocks). 55 visitOperation(op); 56 for (Region ®ion : op->getRegions()) { 57 for (Block &block : region) { 58 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 59 visitBlock(&block); 60 for (Operation &op : block) 61 if (failed(initializeRecursively(&op))) 62 return failure(); 63 } 64 } 65 66 return success(); 67 } 68 69 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) { 70 if (Operation *op = point.dyn_cast<Operation *>()) 71 visitOperation(op); 72 else if (Block *block = point.dyn_cast<Block *>()) 73 visitBlock(block); 74 else 75 return failure(); 76 return success(); 77 } 78 79 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) { 80 // Exit early on operations with no results. 81 if (op->getNumResults() == 0) 82 return; 83 84 // If the containing block is not executable, bail out. 85 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 86 return; 87 88 // Get the result lattices. 89 SmallVector<AbstractSparseLattice *> resultLattices; 90 resultLattices.reserve(op->getNumResults()); 91 // Track whether all results have reached their fixpoint. 92 bool allAtFixpoint = true; 93 for (Value result : op->getResults()) { 94 AbstractSparseLattice *resultLattice = getLatticeElement(result); 95 allAtFixpoint &= resultLattice->isAtFixpoint(); 96 resultLattices.push_back(resultLattice); 97 } 98 // If all result lattices have reached a fixpoint, there is nothing to do. 99 if (allAtFixpoint) 100 return; 101 102 // The results of a region branch operation are determined by control-flow. 103 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 104 return visitRegionSuccessors({branch}, branch, 105 /*successorIndex=*/llvm::None, resultLattices); 106 } 107 108 // The results of a call operation are determined by the callgraph. 109 if (auto call = dyn_cast<CallOpInterface>(op)) { 110 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); 111 // If not all return sites are known, then conservatively assume we can't 112 // reason about the data-flow. 113 if (!predecessors->allPredecessorsKnown()) 114 return markAllPessimisticFixpoint(resultLattices); 115 for (Operation *predecessor : predecessors->getKnownPredecessors()) 116 for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) 117 join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); 118 return; 119 } 120 121 // Grab the lattice elements of the operands. 122 SmallVector<const AbstractSparseLattice *> operandLattices; 123 operandLattices.reserve(op->getNumOperands()); 124 for (Value operand : op->getOperands()) { 125 AbstractSparseLattice *operandLattice = getLatticeElement(operand); 126 operandLattice->useDefSubscribe(this); 127 // If any of the operand states are not initialized, bail out. 128 if (operandLattice->isUninitialized()) 129 return; 130 operandLattices.push_back(operandLattice); 131 } 132 133 // Invoke the operation transfer function. 134 visitOperationImpl(op, operandLattices, resultLattices); 135 } 136 137 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { 138 // Exit early on blocks with no arguments. 139 if (block->getNumArguments() == 0) 140 return; 141 142 // If the block is not executable, bail out. 143 if (!getOrCreate<Executable>(block)->isLive()) 144 return; 145 146 // Get the argument lattices. 147 SmallVector<AbstractSparseLattice *> argLattices; 148 argLattices.reserve(block->getNumArguments()); 149 bool allAtFixpoint = true; 150 for (BlockArgument argument : block->getArguments()) { 151 AbstractSparseLattice *argLattice = getLatticeElement(argument); 152 allAtFixpoint &= argLattice->isAtFixpoint(); 153 argLattices.push_back(argLattice); 154 } 155 // If all argument lattices have reached their fixpoints, then there is 156 // nothing to do. 157 if (allAtFixpoint) 158 return; 159 160 // The argument lattices of entry blocks are set by region control-flow or the 161 // callgraph. 162 if (block->isEntryBlock()) { 163 // Check if this block is the entry block of a callable region. 164 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); 165 if (callable && callable.getCallableRegion() == block->getParent()) { 166 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); 167 // If not all callsites are known, conservatively mark all lattices as 168 // having reached their pessimistic fixpoints. 169 if (!callsites->allPredecessorsKnown()) 170 return markAllPessimisticFixpoint(argLattices); 171 for (Operation *callsite : callsites->getKnownPredecessors()) { 172 auto call = cast<CallOpInterface>(callsite); 173 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) 174 join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); 175 } 176 return; 177 } 178 179 // Check if the lattices can be determined from region control flow. 180 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { 181 return visitRegionSuccessors( 182 block, branch, block->getParent()->getRegionNumber(), argLattices); 183 } 184 185 // Otherwise, we can't reason about the data-flow. 186 return markAllPessimisticFixpoint(argLattices); 187 } 188 189 // Iterate over the predecessors of the non-entry block. 190 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); 191 it != e; ++it) { 192 Block *predecessor = *it; 193 194 // If the edge from the predecessor block to the current block is not live, 195 // bail out. 196 auto *edgeExecutable = 197 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block)); 198 edgeExecutable->blockContentSubscribe(this); 199 if (!edgeExecutable->isLive()) 200 continue; 201 202 // Check if we can reason about the data-flow from the predecessor. 203 if (auto branch = 204 dyn_cast<BranchOpInterface>(predecessor->getTerminator())) { 205 SuccessorOperands operands = 206 branch.getSuccessorOperands(it.getSuccessorIndex()); 207 for (auto &it : llvm::enumerate(argLattices)) { 208 if (Value operand = operands[it.index()]) { 209 join(it.value(), *getLatticeElementFor(block, operand)); 210 } else { 211 // Conservatively mark internally produced arguments as having reached 212 // their pessimistic fixpoint. 213 markAllPessimisticFixpoint(it.value()); 214 } 215 } 216 } else { 217 return markAllPessimisticFixpoint(argLattices); 218 } 219 } 220 } 221 222 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( 223 ProgramPoint point, RegionBranchOpInterface branch, 224 Optional<unsigned> successorIndex, 225 ArrayRef<AbstractSparseLattice *> lattices) { 226 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); 227 assert(predecessors->allPredecessorsKnown() && 228 "unexpected unresolved region successors"); 229 230 for (Operation *op : predecessors->getKnownPredecessors()) { 231 // Get the incoming successor operands. 232 Optional<OperandRange> operands; 233 234 // Check if the predecessor is the parent op. 235 if (op == branch) { 236 operands = branch.getSuccessorEntryOperands(successorIndex); 237 // Otherwise, try to deduce the operands from a region return-like op. 238 } else { 239 assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator"); 240 if (isRegionReturnLike(op)) 241 operands = getRegionBranchSuccessorOperands(op, successorIndex); 242 } 243 244 if (!operands) { 245 // We can't reason about the data-flow. 246 return markAllPessimisticFixpoint(lattices); 247 } 248 249 ValueRange inputs = predecessors->getSuccessorInputs(op); 250 assert(inputs.size() == operands->size() && 251 "expected the same number of successor inputs as operands"); 252 253 // TODO: This was updated to be exposed upstream. 254 unsigned firstIndex = 0; 255 if (inputs.size() != lattices.size()) { 256 if (inputs.empty()) { 257 markAllPessimisticFixpoint(lattices); 258 return; 259 } 260 firstIndex = inputs.front().cast<BlockArgument>().getArgNumber(); 261 markAllPessimisticFixpoint(lattices.take_front(firstIndex)); 262 markAllPessimisticFixpoint( 263 lattices.drop_front(firstIndex + inputs.size())); 264 } 265 266 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) 267 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); 268 } 269 } 270 271 const AbstractSparseLattice * 272 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 273 Value value) { 274 AbstractSparseLattice *state = getLatticeElement(value); 275 addDependency(state, point); 276 return state; 277 } 278 279 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( 280 ArrayRef<AbstractSparseLattice *> lattices) { 281 for (AbstractSparseLattice *lattice : lattices) 282 propagateIfChanged(lattice, lattice->markPessimisticFixpoint()); 283 } 284 285 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, 286 const AbstractSparseLattice &rhs) { 287 propagateIfChanged(lhs, lhs->join(rhs)); 288 } 289