1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12
13 using namespace mlir;
14 using namespace mlir::dataflow;
15
16 //===----------------------------------------------------------------------===//
17 // AbstractSparseLattice
18 //===----------------------------------------------------------------------===//
19
onUpdate(DataFlowSolver * solver) const20 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
21 // Push all users of the value to the queue.
22 for (Operation *user : point.get<Value>().getUsers())
23 for (DataFlowAnalysis *analysis : useDefSubscribers)
24 solver->enqueue({user, analysis});
25 }
26
27 //===----------------------------------------------------------------------===//
28 // AbstractSparseDataFlowAnalysis
29 //===----------------------------------------------------------------------===//
30
AbstractSparseDataFlowAnalysis(DataFlowSolver & solver)31 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
32 DataFlowSolver &solver)
33 : DataFlowAnalysis(solver) {
34 registerPointKind<CFGEdge>();
35 }
36
initialize(Operation * top)37 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
38 // Mark the entry block arguments as having reached their pessimistic
39 // fixpoints.
40 for (Region ®ion : top->getRegions()) {
41 if (region.empty())
42 continue;
43 for (Value argument : region.front().getArguments())
44 markAllPessimisticFixpoint(getLatticeElement(argument));
45 }
46
47 return initializeRecursively(top);
48 }
49
50 LogicalResult
initializeRecursively(Operation * op)51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
52 // Initialize the analysis by visiting every owner of an SSA value (all
53 // operations and blocks).
54 visitOperation(op);
55 for (Region ®ion : op->getRegions()) {
56 for (Block &block : region) {
57 getOrCreate<Executable>(&block)->blockContentSubscribe(this);
58 visitBlock(&block);
59 for (Operation &op : block)
60 if (failed(initializeRecursively(&op)))
61 return failure();
62 }
63 }
64
65 return success();
66 }
67
visit(ProgramPoint point)68 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
69 if (Operation *op = point.dyn_cast<Operation *>())
70 visitOperation(op);
71 else if (Block *block = point.dyn_cast<Block *>())
72 visitBlock(block);
73 else
74 return failure();
75 return success();
76 }
77
visitOperation(Operation * op)78 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
79 // Exit early on operations with no results.
80 if (op->getNumResults() == 0)
81 return;
82
83 // If the containing block is not executable, bail out.
84 if (!getOrCreate<Executable>(op->getBlock())->isLive())
85 return;
86
87 // Get the result lattices.
88 SmallVector<AbstractSparseLattice *> resultLattices;
89 resultLattices.reserve(op->getNumResults());
90 // Track whether all results have reached their fixpoint.
91 bool allAtFixpoint = true;
92 for (Value result : op->getResults()) {
93 AbstractSparseLattice *resultLattice = getLatticeElement(result);
94 allAtFixpoint &= resultLattice->isAtFixpoint();
95 resultLattices.push_back(resultLattice);
96 }
97 // If all result lattices have reached a fixpoint, there is nothing to do.
98 if (allAtFixpoint)
99 return;
100
101 // The results of a region branch operation are determined by control-flow.
102 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
103 return visitRegionSuccessors({branch}, branch,
104 /*successorIndex=*/llvm::None, resultLattices);
105 }
106
107 // The results of a call operation are determined by the callgraph.
108 if (auto call = dyn_cast<CallOpInterface>(op)) {
109 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
110 // If not all return sites are known, then conservatively assume we can't
111 // reason about the data-flow.
112 if (!predecessors->allPredecessorsKnown())
113 return markAllPessimisticFixpoint(resultLattices);
114 for (Operation *predecessor : predecessors->getKnownPredecessors())
115 for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
116 join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
117 return;
118 }
119
120 // Grab the lattice elements of the operands.
121 SmallVector<const AbstractSparseLattice *> operandLattices;
122 operandLattices.reserve(op->getNumOperands());
123 for (Value operand : op->getOperands()) {
124 AbstractSparseLattice *operandLattice = getLatticeElement(operand);
125 operandLattice->useDefSubscribe(this);
126 // If any of the operand states are not initialized, bail out.
127 if (operandLattice->isUninitialized())
128 return;
129 operandLattices.push_back(operandLattice);
130 }
131
132 // Invoke the operation transfer function.
133 visitOperationImpl(op, operandLattices, resultLattices);
134 }
135
visitBlock(Block * block)136 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
137 // Exit early on blocks with no arguments.
138 if (block->getNumArguments() == 0)
139 return;
140
141 // If the block is not executable, bail out.
142 if (!getOrCreate<Executable>(block)->isLive())
143 return;
144
145 // Get the argument lattices.
146 SmallVector<AbstractSparseLattice *> argLattices;
147 argLattices.reserve(block->getNumArguments());
148 bool allAtFixpoint = true;
149 for (BlockArgument argument : block->getArguments()) {
150 AbstractSparseLattice *argLattice = getLatticeElement(argument);
151 allAtFixpoint &= argLattice->isAtFixpoint();
152 argLattices.push_back(argLattice);
153 }
154 // If all argument lattices have reached their fixpoints, then there is
155 // nothing to do.
156 if (allAtFixpoint)
157 return;
158
159 // The argument lattices of entry blocks are set by region control-flow or the
160 // callgraph.
161 if (block->isEntryBlock()) {
162 // Check if this block is the entry block of a callable region.
163 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
164 if (callable && callable.getCallableRegion() == block->getParent()) {
165 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
166 // If not all callsites are known, conservatively mark all lattices as
167 // having reached their pessimistic fixpoints.
168 if (!callsites->allPredecessorsKnown())
169 return markAllPessimisticFixpoint(argLattices);
170 for (Operation *callsite : callsites->getKnownPredecessors()) {
171 auto call = cast<CallOpInterface>(callsite);
172 for (auto it : llvm::zip(call.getArgOperands(), argLattices))
173 join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
174 }
175 return;
176 }
177
178 // Check if the lattices can be determined from region control flow.
179 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
180 return visitRegionSuccessors(
181 block, branch, block->getParent()->getRegionNumber(), argLattices);
182 }
183
184 // Otherwise, we can't reason about the data-flow.
185 return visitNonControlFlowArgumentsImpl(block->getParentOp(),
186 RegionSuccessor(block->getParent()),
187 argLattices, /*firstIndex=*/0);
188 }
189
190 // Iterate over the predecessors of the non-entry block.
191 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
192 it != e; ++it) {
193 Block *predecessor = *it;
194
195 // If the edge from the predecessor block to the current block is not live,
196 // bail out.
197 auto *edgeExecutable =
198 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
199 edgeExecutable->blockContentSubscribe(this);
200 if (!edgeExecutable->isLive())
201 continue;
202
203 // Check if we can reason about the data-flow from the predecessor.
204 if (auto branch =
205 dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
206 SuccessorOperands operands =
207 branch.getSuccessorOperands(it.getSuccessorIndex());
208 for (auto &it : llvm::enumerate(argLattices)) {
209 if (Value operand = operands[it.index()]) {
210 join(it.value(), *getLatticeElementFor(block, operand));
211 } else {
212 // Conservatively mark internally produced arguments as having reached
213 // their pessimistic fixpoint.
214 markAllPessimisticFixpoint(it.value());
215 }
216 }
217 } else {
218 return markAllPessimisticFixpoint(argLattices);
219 }
220 }
221 }
222
visitRegionSuccessors(ProgramPoint point,RegionBranchOpInterface branch,Optional<unsigned> successorIndex,ArrayRef<AbstractSparseLattice * > lattices)223 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
224 ProgramPoint point, RegionBranchOpInterface branch,
225 Optional<unsigned> successorIndex,
226 ArrayRef<AbstractSparseLattice *> lattices) {
227 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
228 assert(predecessors->allPredecessorsKnown() &&
229 "unexpected unresolved region successors");
230
231 for (Operation *op : predecessors->getKnownPredecessors()) {
232 // Get the incoming successor operands.
233 Optional<OperandRange> operands;
234
235 // Check if the predecessor is the parent op.
236 if (op == branch) {
237 operands = branch.getSuccessorEntryOperands(successorIndex);
238 // Otherwise, try to deduce the operands from a region return-like op.
239 } else {
240 if (isRegionReturnLike(op))
241 operands = getRegionBranchSuccessorOperands(op, successorIndex);
242 }
243
244 if (!operands) {
245 // We can't reason about the data-flow.
246 return markAllPessimisticFixpoint(lattices);
247 }
248
249 ValueRange inputs = predecessors->getSuccessorInputs(op);
250 assert(inputs.size() == operands->size() &&
251 "expected the same number of successor inputs as operands");
252
253 unsigned firstIndex = 0;
254 if (inputs.size() != lattices.size()) {
255 if (auto *op = point.dyn_cast<Operation *>()) {
256 if (!inputs.empty())
257 firstIndex = inputs.front().cast<OpResult>().getResultNumber();
258 visitNonControlFlowArgumentsImpl(
259 branch,
260 RegionSuccessor(
261 branch->getResults().slice(firstIndex, inputs.size())),
262 lattices, firstIndex);
263 } else {
264 if (!inputs.empty())
265 firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
266 Region *region = point.get<Block *>()->getParent();
267 visitNonControlFlowArgumentsImpl(
268 branch,
269 RegionSuccessor(region, region->getArguments().slice(
270 firstIndex, inputs.size())),
271 lattices, firstIndex);
272 }
273 }
274
275 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
276 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
277 }
278 }
279
280 const AbstractSparseLattice *
getLatticeElementFor(ProgramPoint point,Value value)281 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
282 Value value) {
283 AbstractSparseLattice *state = getLatticeElement(value);
284 addDependency(state, point);
285 return state;
286 }
287
markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice * > lattices)288 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
289 ArrayRef<AbstractSparseLattice *> lattices) {
290 for (AbstractSparseLattice *lattice : lattices)
291 propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
292 }
293
join(AbstractSparseLattice * lhs,const AbstractSparseLattice & rhs)294 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
295 const AbstractSparseLattice &rhs) {
296 propagateIfChanged(lhs, lhs->join(rhs));
297 }
298