1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 11 #include "mlir/Interfaces/CallInterfaces.h" 12 13 using namespace mlir; 14 using namespace mlir::dataflow; 15 16 //===----------------------------------------------------------------------===// 17 // AbstractSparseLattice 18 //===----------------------------------------------------------------------===// 19 20 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { 21 // Push all users of the value to the queue. 22 for (Operation *user : point.get<Value>().getUsers()) 23 for (DataFlowAnalysis *analysis : useDefSubscribers) 24 solver->enqueue({user, analysis}); 25 } 26 27 //===----------------------------------------------------------------------===// 28 // AbstractSparseDataFlowAnalysis 29 //===----------------------------------------------------------------------===// 30 31 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis( 32 DataFlowSolver &solver) 33 : DataFlowAnalysis(solver) { 34 registerPointKind<CFGEdge>(); 35 } 36 37 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) { 38 // Mark the entry block arguments as having reached their pessimistic 39 // fixpoints. 40 for (Region ®ion : top->getRegions()) { 41 if (region.empty()) 42 continue; 43 for (Value argument : region.front().getArguments()) 44 markAllPessimisticFixpoint(getLatticeElement(argument)); 45 } 46 47 return initializeRecursively(top); 48 } 49 50 LogicalResult 51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) { 52 // Initialize the analysis by visiting every owner of an SSA value (all 53 // operations and blocks). 54 visitOperation(op); 55 for (Region ®ion : op->getRegions()) { 56 for (Block &block : region) { 57 getOrCreate<Executable>(&block)->blockContentSubscribe(this); 58 visitBlock(&block); 59 for (Operation &op : block) 60 if (failed(initializeRecursively(&op))) 61 return failure(); 62 } 63 } 64 65 return success(); 66 } 67 68 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) { 69 if (Operation *op = point.dyn_cast<Operation *>()) 70 visitOperation(op); 71 else if (Block *block = point.dyn_cast<Block *>()) 72 visitBlock(block); 73 else 74 return failure(); 75 return success(); 76 } 77 78 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) { 79 // Exit early on operations with no results. 80 if (op->getNumResults() == 0) 81 return; 82 83 // If the containing block is not executable, bail out. 84 if (!getOrCreate<Executable>(op->getBlock())->isLive()) 85 return; 86 87 // Get the result lattices. 88 SmallVector<AbstractSparseLattice *> resultLattices; 89 resultLattices.reserve(op->getNumResults()); 90 // Track whether all results have reached their fixpoint. 91 bool allAtFixpoint = true; 92 for (Value result : op->getResults()) { 93 AbstractSparseLattice *resultLattice = getLatticeElement(result); 94 allAtFixpoint &= resultLattice->isAtFixpoint(); 95 resultLattices.push_back(resultLattice); 96 } 97 // If all result lattices have reached a fixpoint, there is nothing to do. 98 if (allAtFixpoint) 99 return; 100 101 // The results of a region branch operation are determined by control-flow. 102 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { 103 return visitRegionSuccessors({branch}, branch, 104 /*successorIndex=*/llvm::None, resultLattices); 105 } 106 107 // The results of a call operation are determined by the callgraph. 108 if (auto call = dyn_cast<CallOpInterface>(op)) { 109 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); 110 // If not all return sites are known, then conservatively assume we can't 111 // reason about the data-flow. 112 if (!predecessors->allPredecessorsKnown()) 113 return markAllPessimisticFixpoint(resultLattices); 114 for (Operation *predecessor : predecessors->getKnownPredecessors()) 115 for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) 116 join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); 117 return; 118 } 119 120 // Grab the lattice elements of the operands. 121 SmallVector<const AbstractSparseLattice *> operandLattices; 122 operandLattices.reserve(op->getNumOperands()); 123 for (Value operand : op->getOperands()) { 124 AbstractSparseLattice *operandLattice = getLatticeElement(operand); 125 operandLattice->useDefSubscribe(this); 126 // If any of the operand states are not initialized, bail out. 127 if (operandLattice->isUninitialized()) 128 return; 129 operandLattices.push_back(operandLattice); 130 } 131 132 // Invoke the operation transfer function. 133 visitOperationImpl(op, operandLattices, resultLattices); 134 } 135 136 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { 137 // Exit early on blocks with no arguments. 138 if (block->getNumArguments() == 0) 139 return; 140 141 // If the block is not executable, bail out. 142 if (!getOrCreate<Executable>(block)->isLive()) 143 return; 144 145 // Get the argument lattices. 146 SmallVector<AbstractSparseLattice *> argLattices; 147 argLattices.reserve(block->getNumArguments()); 148 bool allAtFixpoint = true; 149 for (BlockArgument argument : block->getArguments()) { 150 AbstractSparseLattice *argLattice = getLatticeElement(argument); 151 allAtFixpoint &= argLattice->isAtFixpoint(); 152 argLattices.push_back(argLattice); 153 } 154 // If all argument lattices have reached their fixpoints, then there is 155 // nothing to do. 156 if (allAtFixpoint) 157 return; 158 159 // The argument lattices of entry blocks are set by region control-flow or the 160 // callgraph. 161 if (block->isEntryBlock()) { 162 // Check if this block is the entry block of a callable region. 163 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); 164 if (callable && callable.getCallableRegion() == block->getParent()) { 165 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); 166 // If not all callsites are known, conservatively mark all lattices as 167 // having reached their pessimistic fixpoints. 168 if (!callsites->allPredecessorsKnown()) 169 return markAllPessimisticFixpoint(argLattices); 170 for (Operation *callsite : callsites->getKnownPredecessors()) { 171 auto call = cast<CallOpInterface>(callsite); 172 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) 173 join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); 174 } 175 return; 176 } 177 178 // Check if the lattices can be determined from region control flow. 179 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { 180 return visitRegionSuccessors( 181 block, branch, block->getParent()->getRegionNumber(), argLattices); 182 } 183 184 // Otherwise, we can't reason about the data-flow. 185 return visitNonControlFlowArgumentsImpl(block->getParentOp(), 186 RegionSuccessor(block->getParent()), 187 argLattices, /*firstIndex=*/0); 188 } 189 190 // Iterate over the predecessors of the non-entry block. 191 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); 192 it != e; ++it) { 193 Block *predecessor = *it; 194 195 // If the edge from the predecessor block to the current block is not live, 196 // bail out. 197 auto *edgeExecutable = 198 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block)); 199 edgeExecutable->blockContentSubscribe(this); 200 if (!edgeExecutable->isLive()) 201 continue; 202 203 // Check if we can reason about the data-flow from the predecessor. 204 if (auto branch = 205 dyn_cast<BranchOpInterface>(predecessor->getTerminator())) { 206 SuccessorOperands operands = 207 branch.getSuccessorOperands(it.getSuccessorIndex()); 208 for (auto &it : llvm::enumerate(argLattices)) { 209 if (Value operand = operands[it.index()]) { 210 join(it.value(), *getLatticeElementFor(block, operand)); 211 } else { 212 // Conservatively mark internally produced arguments as having reached 213 // their pessimistic fixpoint. 214 markAllPessimisticFixpoint(it.value()); 215 } 216 } 217 } else { 218 return markAllPessimisticFixpoint(argLattices); 219 } 220 } 221 } 222 223 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( 224 ProgramPoint point, RegionBranchOpInterface branch, 225 Optional<unsigned> successorIndex, 226 ArrayRef<AbstractSparseLattice *> lattices) { 227 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); 228 assert(predecessors->allPredecessorsKnown() && 229 "unexpected unresolved region successors"); 230 231 for (Operation *op : predecessors->getKnownPredecessors()) { 232 // Get the incoming successor operands. 233 Optional<OperandRange> operands; 234 235 // Check if the predecessor is the parent op. 236 if (op == branch) { 237 operands = branch.getSuccessorEntryOperands(successorIndex); 238 // Otherwise, try to deduce the operands from a region return-like op. 239 } else { 240 if (isRegionReturnLike(op)) 241 operands = getRegionBranchSuccessorOperands(op, successorIndex); 242 } 243 244 if (!operands) { 245 // We can't reason about the data-flow. 246 return markAllPessimisticFixpoint(lattices); 247 } 248 249 ValueRange inputs = predecessors->getSuccessorInputs(op); 250 assert(inputs.size() == operands->size() && 251 "expected the same number of successor inputs as operands"); 252 253 unsigned firstIndex = 0; 254 if (inputs.size() != lattices.size()) { 255 if (auto *op = point.dyn_cast<Operation *>()) { 256 if (!inputs.empty()) 257 firstIndex = inputs.front().cast<OpResult>().getResultNumber(); 258 visitNonControlFlowArgumentsImpl( 259 branch, 260 RegionSuccessor( 261 branch->getResults().slice(firstIndex, inputs.size())), 262 lattices, firstIndex); 263 } else { 264 if (!inputs.empty()) 265 firstIndex = inputs.front().cast<BlockArgument>().getArgNumber(); 266 Region *region = point.get<Block *>()->getParent(); 267 visitNonControlFlowArgumentsImpl( 268 branch, 269 RegionSuccessor(region, region->getArguments().slice( 270 firstIndex, inputs.size())), 271 lattices, firstIndex); 272 } 273 } 274 275 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) 276 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); 277 } 278 } 279 280 const AbstractSparseLattice * 281 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, 282 Value value) { 283 AbstractSparseLattice *state = getLatticeElement(value); 284 addDependency(state, point); 285 return state; 286 } 287 288 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( 289 ArrayRef<AbstractSparseLattice *> lattices) { 290 for (AbstractSparseLattice *lattice : lattices) 291 propagateIfChanged(lattice, lattice->markPessimisticFixpoint()); 292 } 293 294 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, 295 const AbstractSparseLattice &rhs) { 296 propagateIfChanged(lhs, lhs->join(rhs)); 297 } 298