1 //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13
14 using namespace mlir;
15 using namespace mlir::dataflow;
16
17 //===----------------------------------------------------------------------===//
18 // Executable
19 //===----------------------------------------------------------------------===//
20
setToLive()21 ChangeResult Executable::setToLive() {
22 if (live)
23 return ChangeResult::NoChange;
24 live = true;
25 return ChangeResult::Change;
26 }
27
print(raw_ostream & os) const28 void Executable::print(raw_ostream &os) const {
29 os << (live ? "live" : "dead");
30 }
31
onUpdate(DataFlowSolver * solver) const32 void Executable::onUpdate(DataFlowSolver *solver) const {
33 if (auto *block = point.dyn_cast<Block *>()) {
34 // Re-invoke the analyses on the block itself.
35 for (DataFlowAnalysis *analysis : subscribers)
36 solver->enqueue({block, analysis});
37 // Re-invoke the analyses on all operations in the block.
38 for (DataFlowAnalysis *analysis : subscribers)
39 for (Operation &op : *block)
40 solver->enqueue({&op, analysis});
41 } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
42 // Re-invoke the analysis on the successor block.
43 if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
44 for (DataFlowAnalysis *analysis : subscribers)
45 solver->enqueue({edge->getTo(), analysis});
46 }
47 }
48 }
49
50 //===----------------------------------------------------------------------===//
51 // PredecessorState
52 //===----------------------------------------------------------------------===//
53
print(raw_ostream & os) const54 void PredecessorState::print(raw_ostream &os) const {
55 if (allPredecessorsKnown())
56 os << "(all) ";
57 os << "predecessors:\n";
58 for (Operation *op : getKnownPredecessors())
59 os << " " << *op << "\n";
60 }
61
join(Operation * predecessor)62 ChangeResult PredecessorState::join(Operation *predecessor) {
63 return knownPredecessors.insert(predecessor) ? ChangeResult::Change
64 : ChangeResult::NoChange;
65 }
66
join(Operation * predecessor,ValueRange inputs)67 ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
68 ChangeResult result = join(predecessor);
69 if (!inputs.empty()) {
70 ValueRange &curInputs = successorInputs[predecessor];
71 if (curInputs != inputs) {
72 curInputs = inputs;
73 result |= ChangeResult::Change;
74 }
75 }
76 return result;
77 }
78
79 //===----------------------------------------------------------------------===//
80 // CFGEdge
81 //===----------------------------------------------------------------------===//
82
getLoc() const83 Location CFGEdge::getLoc() const {
84 return FusedLoc::get(
85 getFrom()->getParent()->getContext(),
86 {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
87 }
88
print(raw_ostream & os) const89 void CFGEdge::print(raw_ostream &os) const {
90 getFrom()->print(os);
91 os << "\n -> \n";
92 getTo()->print(os);
93 }
94
95 //===----------------------------------------------------------------------===//
96 // DeadCodeAnalysis
97 //===----------------------------------------------------------------------===//
98
DeadCodeAnalysis(DataFlowSolver & solver)99 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
100 : DataFlowAnalysis(solver) {
101 registerPointKind<CFGEdge>();
102 }
103
initialize(Operation * top)104 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
105 // Mark the top-level blocks as executable.
106 for (Region ®ion : top->getRegions()) {
107 if (region.empty())
108 continue;
109 auto *state = getOrCreate<Executable>(®ion.front());
110 propagateIfChanged(state, state->setToLive());
111 }
112
113 // Mark as overdefined the predecessors of symbol callables with potentially
114 // unknown predecessors.
115 initializeSymbolCallables(top);
116
117 return initializeRecursively(top);
118 }
119
initializeSymbolCallables(Operation * top)120 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
121 auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
122 Region &symbolTableRegion = symTable->getRegion(0);
123 Block *symbolTableBlock = &symbolTableRegion.front();
124
125 bool foundSymbolCallable = false;
126 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
127 Region *callableRegion = callable.getCallableRegion();
128 if (!callableRegion)
129 continue;
130 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
131 if (!symbol)
132 continue;
133
134 // Public symbol callables or those for which we can't see all uses have
135 // potentially unknown callsites.
136 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
137 auto *state = getOrCreate<PredecessorState>(callable);
138 propagateIfChanged(state, state->setHasUnknownPredecessors());
139 }
140 foundSymbolCallable = true;
141 }
142
143 // Exit early if no eligible symbol callables were found in the table.
144 if (!foundSymbolCallable)
145 return;
146
147 // Walk the symbol table to check for non-call uses of symbols.
148 Optional<SymbolTable::UseRange> uses =
149 SymbolTable::getSymbolUses(&symbolTableRegion);
150 if (!uses) {
151 // If we couldn't gather the symbol uses, conservatively assume that
152 // we can't track information for any nested symbols.
153 return top->walk([&](CallableOpInterface callable) {
154 auto *state = getOrCreate<PredecessorState>(callable);
155 propagateIfChanged(state, state->setHasUnknownPredecessors());
156 });
157 }
158
159 for (const SymbolTable::SymbolUse &use : *uses) {
160 if (isa<CallOpInterface>(use.getUser()))
161 continue;
162 // If a callable symbol has a non-call use, then we can't be guaranteed to
163 // know all callsites.
164 Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
165 auto *state = getOrCreate<PredecessorState>(symbol);
166 propagateIfChanged(state, state->setHasUnknownPredecessors());
167 }
168 };
169 SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
170 walkFn);
171 }
172
173 /// Returns true if the operation is a returning terminator in region
174 /// control-flow or the terminator of a callable region.
isRegionOrCallableReturn(Operation * op)175 static bool isRegionOrCallableReturn(Operation *op) {
176 return !op->getNumSuccessors() &&
177 isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
178 op->getBlock()->getTerminator() == op;
179 }
180
initializeRecursively(Operation * op)181 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
182 // Initialize the analysis by visiting every op with control-flow semantics.
183 if (op->getNumRegions() || op->getNumSuccessors() ||
184 isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
185 // When the liveness of the parent block changes, make sure to re-invoke the
186 // analysis on the op.
187 if (op->getBlock())
188 getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
189 // Visit the op.
190 if (failed(visit(op)))
191 return failure();
192 }
193 // Recurse on nested operations.
194 for (Region ®ion : op->getRegions())
195 for (Operation &op : region.getOps())
196 if (failed(initializeRecursively(&op)))
197 return failure();
198 return success();
199 }
200
markEdgeLive(Block * from,Block * to)201 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
202 auto *state = getOrCreate<Executable>(to);
203 propagateIfChanged(state, state->setToLive());
204 auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
205 propagateIfChanged(edgeState, edgeState->setToLive());
206 }
207
markEntryBlocksLive(Operation * op)208 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
209 for (Region ®ion : op->getRegions()) {
210 if (region.empty())
211 continue;
212 auto *state = getOrCreate<Executable>(®ion.front());
213 propagateIfChanged(state, state->setToLive());
214 }
215 }
216
visit(ProgramPoint point)217 LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
218 if (point.is<Block *>())
219 return success();
220 auto *op = point.dyn_cast<Operation *>();
221 if (!op)
222 return emitError(point.getLoc(), "unknown program point kind");
223
224 // If the parent block is not executable, there is nothing to do.
225 if (!getOrCreate<Executable>(op->getBlock())->isLive())
226 return success();
227
228 // We have a live call op. Add this as a live predecessor of the callee.
229 if (auto call = dyn_cast<CallOpInterface>(op))
230 visitCallOperation(call);
231
232 // Visit the regions.
233 if (op->getNumRegions()) {
234 // Check if we can reason about the region control-flow.
235 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
236 visitRegionBranchOperation(branch);
237
238 // Check if this is a callable operation.
239 } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
240 const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
241
242 // If the callsites could not be resolved or are known to be non-empty,
243 // mark the callable as executable.
244 if (!callsites->allPredecessorsKnown() ||
245 !callsites->getKnownPredecessors().empty())
246 markEntryBlocksLive(callable);
247
248 // Otherwise, conservatively mark all entry blocks as executable.
249 } else {
250 markEntryBlocksLive(op);
251 }
252 }
253
254 if (isRegionOrCallableReturn(op)) {
255 if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
256 // Visit the exiting terminator of a region.
257 visitRegionTerminator(op, branch);
258 } else if (auto callable =
259 dyn_cast<CallableOpInterface>(op->getParentOp())) {
260 // Visit the exiting terminator of a callable.
261 visitCallableTerminator(op, callable);
262 }
263 }
264 // Visit the successors.
265 if (op->getNumSuccessors()) {
266 // Check if we can reason about the control-flow.
267 if (auto branch = dyn_cast<BranchOpInterface>(op)) {
268 visitBranchOperation(branch);
269
270 // Otherwise, conservatively mark all successors as exectuable.
271 } else {
272 for (Block *successor : op->getSuccessors())
273 markEdgeLive(op->getBlock(), successor);
274 }
275 }
276
277 return success();
278 }
279
visitCallOperation(CallOpInterface call)280 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
281 Operation *callableOp = nullptr;
282 if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
283 callableOp = callableValue.getDefiningOp();
284 else
285 callableOp = call.resolveCallable(&symbolTable);
286
287 // A call to a externally-defined callable has unknown predecessors.
288 const auto isExternalCallable = [](Operation *op) {
289 if (auto callable = dyn_cast<CallableOpInterface>(op))
290 return !callable.getCallableRegion();
291 return false;
292 };
293
294 // TODO: Add support for non-symbol callables when necessary. If the
295 // callable has non-call uses we would mark as having reached pessimistic
296 // fixpoint, otherwise allow for propagating the return values out.
297 if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
298 !isExternalCallable(callableOp)) {
299 // Add the live callsite.
300 auto *callsites = getOrCreate<PredecessorState>(callableOp);
301 propagateIfChanged(callsites, callsites->join(call));
302 } else {
303 // Mark this call op's predecessors as overdefined.
304 auto *predecessors = getOrCreate<PredecessorState>(call);
305 propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
306 }
307 }
308
309 /// Get the constant values of the operands of an operation. If any of the
310 /// constant value lattices are uninitialized, return none to indicate the
311 /// analysis should bail out.
getOperandValuesImpl(Operation * op,function_ref<const Lattice<ConstantValue> * (Value)> getLattice)312 static Optional<SmallVector<Attribute>> getOperandValuesImpl(
313 Operation *op,
314 function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
315 SmallVector<Attribute> operands;
316 operands.reserve(op->getNumOperands());
317 for (Value operand : op->getOperands()) {
318 const Lattice<ConstantValue> *cv = getLattice(operand);
319 // If any of the operands' values are uninitialized, bail out.
320 if (cv->isUninitialized())
321 return {};
322 operands.push_back(cv->getValue().getConstantValue());
323 }
324 return operands;
325 }
326
327 Optional<SmallVector<Attribute>>
getOperandValues(Operation * op)328 DeadCodeAnalysis::getOperandValues(Operation *op) {
329 return getOperandValuesImpl(op, [&](Value value) {
330 auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
331 lattice->useDefSubscribe(this);
332 return lattice;
333 });
334 }
335
visitBranchOperation(BranchOpInterface branch)336 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
337 // Try to deduce a single successor for the branch.
338 Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
339 if (!operands)
340 return;
341
342 if (Block *successor = branch.getSuccessorForOperands(*operands)) {
343 markEdgeLive(branch->getBlock(), successor);
344 } else {
345 // Otherwise, mark all successors as executable and outgoing edges.
346 for (Block *successor : branch->getSuccessors())
347 markEdgeLive(branch->getBlock(), successor);
348 }
349 }
350
visitRegionBranchOperation(RegionBranchOpInterface branch)351 void DeadCodeAnalysis::visitRegionBranchOperation(
352 RegionBranchOpInterface branch) {
353 // Try to deduce which regions are executable.
354 Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
355 if (!operands)
356 return;
357
358 SmallVector<RegionSuccessor> successors;
359 branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
360 for (const RegionSuccessor &successor : successors) {
361 // The successor can be either an entry block or the parent operation.
362 ProgramPoint point = successor.getSuccessor()
363 ? &successor.getSuccessor()->front()
364 : ProgramPoint(branch);
365 // Mark the entry block as executable.
366 auto *state = getOrCreate<Executable>(point);
367 propagateIfChanged(state, state->setToLive());
368 // Add the parent op as a predecessor.
369 auto *predecessors = getOrCreate<PredecessorState>(point);
370 propagateIfChanged(
371 predecessors,
372 predecessors->join(branch, successor.getSuccessorInputs()));
373 }
374 }
375
visitRegionTerminator(Operation * op,RegionBranchOpInterface branch)376 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
377 RegionBranchOpInterface branch) {
378 Optional<SmallVector<Attribute>> operands = getOperandValues(op);
379 if (!operands)
380 return;
381
382 SmallVector<RegionSuccessor> successors;
383 branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
384 *operands, successors);
385
386 // Mark successor region entry blocks as executable and add this op to the
387 // list of predecessors.
388 for (const RegionSuccessor &successor : successors) {
389 PredecessorState *predecessors;
390 if (Region *region = successor.getSuccessor()) {
391 auto *state = getOrCreate<Executable>(®ion->front());
392 propagateIfChanged(state, state->setToLive());
393 predecessors = getOrCreate<PredecessorState>(®ion->front());
394 } else {
395 // Add this terminator as a predecessor to the parent op.
396 predecessors = getOrCreate<PredecessorState>(branch);
397 }
398 propagateIfChanged(predecessors,
399 predecessors->join(op, successor.getSuccessorInputs()));
400 }
401 }
402
visitCallableTerminator(Operation * op,CallableOpInterface callable)403 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
404 CallableOpInterface callable) {
405 // If there are no exiting values, we have nothing to do.
406 if (op->getNumOperands() == 0)
407 return;
408
409 // Add as predecessors to all callsites this return op.
410 auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
411 bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
412 for (Operation *predecessor : callsites->getKnownPredecessors()) {
413 assert(isa<CallOpInterface>(predecessor));
414 auto *predecessors = getOrCreate<PredecessorState>(predecessor);
415 if (canResolve) {
416 propagateIfChanged(predecessors, predecessors->join(op));
417 } else {
418 // If the terminator is not a return-like, then conservatively assume we
419 // can't resolve the predecessor.
420 propagateIfChanged(predecessors,
421 predecessors->setHasUnknownPredecessors());
422 }
423 }
424 }
425