1 //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 //===----------------------------------------------------------------------===//
18 // Executable
19 //===----------------------------------------------------------------------===//
20 
21 ChangeResult Executable::setToLive() {
22   if (live)
23     return ChangeResult::NoChange;
24   live = true;
25   return ChangeResult::Change;
26 }
27 
28 void Executable::print(raw_ostream &os) const {
29   os << (live ? "live" : "dead");
30 }
31 
32 void Executable::onUpdate(DataFlowSolver *solver) const {
33   if (auto *block = point.dyn_cast<Block *>()) {
34     // Re-invoke the analyses on the block itself.
35     for (DataFlowAnalysis *analysis : subscribers)
36       solver->enqueue({block, analysis});
37     // Re-invoke the analyses on all operations in the block.
38     for (DataFlowAnalysis *analysis : subscribers)
39       for (Operation &op : *block)
40         solver->enqueue({&op, analysis});
41   } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
42     // Re-invoke the analysis on the successor block.
43     if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
44       for (DataFlowAnalysis *analysis : subscribers)
45         solver->enqueue({edge->getTo(), analysis});
46     }
47   }
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // PredecessorState
52 //===----------------------------------------------------------------------===//
53 
54 void PredecessorState::print(raw_ostream &os) const {
55   if (allPredecessorsKnown())
56     os << "(all) ";
57   os << "predecessors:\n";
58   for (Operation *op : getKnownPredecessors())
59     os << "  " << *op << "\n";
60 }
61 
62 ChangeResult PredecessorState::join(Operation *predecessor) {
63   return knownPredecessors.insert(predecessor) ? ChangeResult::Change
64                                                : ChangeResult::NoChange;
65 }
66 
67 ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
68   ChangeResult result = join(predecessor);
69   if (!inputs.empty()) {
70     ValueRange &curInputs = successorInputs[predecessor];
71     if (curInputs != inputs) {
72       curInputs = inputs;
73       result |= ChangeResult::Change;
74     }
75   }
76   return result;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // CFGEdge
81 //===----------------------------------------------------------------------===//
82 
83 Location CFGEdge::getLoc() const {
84   return FusedLoc::get(
85       getFrom()->getParent()->getContext(),
86       {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
87 }
88 
89 void CFGEdge::print(raw_ostream &os) const {
90   getFrom()->print(os);
91   os << "\n -> \n";
92   getTo()->print(os);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // DeadCodeAnalysis
97 //===----------------------------------------------------------------------===//
98 
99 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
100     : DataFlowAnalysis(solver) {
101   registerPointKind<CFGEdge>();
102 }
103 
104 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
105   // Mark the top-level blocks as executable.
106   for (Region &region : top->getRegions()) {
107     if (region.empty())
108       continue;
109     auto *state = getOrCreate<Executable>(&region.front());
110     propagateIfChanged(state, state->setToLive());
111   }
112 
113   // Mark as overdefined the predecessors of symbol callables with potentially
114   // unknown predecessors.
115   initializeSymbolCallables(top);
116 
117   return initializeRecursively(top);
118 }
119 
120 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
121   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
122     Region &symbolTableRegion = symTable->getRegion(0);
123     Block *symbolTableBlock = &symbolTableRegion.front();
124 
125     bool foundSymbolCallable = false;
126     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
127       Region *callableRegion = callable.getCallableRegion();
128       if (!callableRegion)
129         continue;
130       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
131       if (!symbol)
132         continue;
133 
134       // Public symbol callables or those for which we can't see all uses have
135       // potentially unknown callsites.
136       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
137         auto *state = getOrCreate<PredecessorState>(callable);
138         propagateIfChanged(state, state->setHasUnknownPredecessors());
139       }
140       foundSymbolCallable = true;
141     }
142 
143     // Exit early if no eligible symbol callables were found in the table.
144     if (!foundSymbolCallable)
145       return;
146 
147     // Walk the symbol table to check for non-call uses of symbols.
148     Optional<SymbolTable::UseRange> uses =
149         SymbolTable::getSymbolUses(&symbolTableRegion);
150     if (!uses) {
151       // If we couldn't gather the symbol uses, conservatively assume that
152       // we can't track information for any nested symbols.
153       return top->walk([&](CallableOpInterface callable) {
154         auto *state = getOrCreate<PredecessorState>(callable);
155         propagateIfChanged(state, state->setHasUnknownPredecessors());
156       });
157     }
158 
159     for (const SymbolTable::SymbolUse &use : *uses) {
160       if (isa<CallOpInterface>(use.getUser()))
161         continue;
162       // If a callable symbol has a non-call use, then we can't be guaranteed to
163       // know all callsites.
164       Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
165       auto *state = getOrCreate<PredecessorState>(symbol);
166       propagateIfChanged(state, state->setHasUnknownPredecessors());
167     }
168   };
169   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
170                                 walkFn);
171 }
172 
173 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
174   // Initialize the analysis by visiting every op with control-flow semantics.
175   if (op->getNumRegions() || op->getNumSuccessors() ||
176       op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
177     // When the liveness of the parent block changes, make sure to re-invoke the
178     // analysis on the op.
179     if (op->getBlock())
180       getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
181     // Visit the op.
182     if (failed(visit(op)))
183       return failure();
184   }
185   // Recurse on nested operations.
186   for (Region &region : op->getRegions())
187     for (Operation &op : region.getOps())
188       if (failed(initializeRecursively(&op)))
189         return failure();
190   return success();
191 }
192 
193 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
194   auto *state = getOrCreate<Executable>(to);
195   propagateIfChanged(state, state->setToLive());
196   auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
197   propagateIfChanged(edgeState, edgeState->setToLive());
198 }
199 
200 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
201   for (Region &region : op->getRegions()) {
202     if (region.empty())
203       continue;
204     auto *state = getOrCreate<Executable>(&region.front());
205     propagateIfChanged(state, state->setToLive());
206   }
207 }
208 
/// Visit a program point on behalf of this analysis.
///
/// Visiting a block does nothing. For an operation, nothing is done unless its
/// parent block is live. Otherwise the op is handled according to the control
/// flow it takes part in:
///   * a call is registered as a live callsite of its callee;
///   * an op with regions marks the entry blocks it may enter as live. This is
///     precise for region-branch ops and callables, and conservative
///     otherwise;
///   * a terminator without successors is handed to its enclosing
///     region-branch op or callable as an exiting terminator;
///   * an op with successors marks the outgoing CFG edges it may take as live.
/// Any other kind of program point is reported as an error.
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
  if (point.is<Block *>())
    return success();
  auto *op = point.dyn_cast<Operation *>();
  if (!op)
    return emitError(point.getLoc(), "unknown program point kind");

  // If the parent block is not executable, there is nothing to do.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return success();

  // We have a live call op. Add this as a live predecessor of the callee.
  if (auto call = dyn_cast<CallOpInterface>(op))
    visitCallOperation(call);

  // Visit the regions.
  if (op->getNumRegions()) {
    // Check if we can reason about the region control-flow.
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
      visitRegionBranchOperation(branch);

      // Check if this is a callable operation.
    } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
      const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);

      // If the callsites could not be resolved or are known to be non-empty,
      // mark the callable as executable. A callable whose callsites are all
      // known and none are live is left dead.
      if (!callsites->allPredecessorsKnown() ||
          !callsites->getKnownPredecessors().empty())
        markEntryBlocksLive(callable);

      // Otherwise (neither a region-branch op nor a callable), conservatively
      // mark all entry blocks as executable.
    } else {
      markEntryBlocksLive(op);
    }
  }

  // A terminator without successors exits its region; let the enclosing
  // region-branch op or callable decide where control goes next.
  if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      // Visit the exiting terminator of a region.
      visitRegionTerminator(op, branch);
    } else if (auto callable =
                   dyn_cast<CallableOpInterface>(op->getParentOp())) {
      // Visit the exiting terminator of a callable.
      visitCallableTerminator(op, callable);
    }
  }
  // Visit the successors.
  if (op->getNumSuccessors()) {
    // Check if we can reason about the control-flow.
    if (auto branch = dyn_cast<BranchOpInterface>(op)) {
      visitBranchOperation(branch);

      // Otherwise, conservatively mark all successors as executable.
    } else {
      for (Block *successor : op->getSuccessors())
        markEdgeLive(op->getBlock(), successor);
    }
  }

  return success();
}
271 
272 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
273   Operation *callableOp = nullptr;
274   if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
275     callableOp = callableValue.getDefiningOp();
276   else
277     callableOp = call.resolveCallable(&symbolTable);
278 
279   // A call to a externally-defined callable has unknown predecessors.
280   const auto isExternalCallable = [](Operation *op) {
281     if (auto callable = dyn_cast<CallableOpInterface>(op))
282       return !callable.getCallableRegion();
283     return false;
284   };
285 
286   // TODO: Add support for non-symbol callables when necessary. If the
287   // callable has non-call uses we would mark as having reached pessimistic
288   // fixpoint, otherwise allow for propagating the return values out.
289   if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
290       !isExternalCallable(callableOp)) {
291     // Add the live callsite.
292     auto *callsites = getOrCreate<PredecessorState>(callableOp);
293     propagateIfChanged(callsites, callsites->join(call));
294   } else {
295     // Mark this call op's predecessors as overdefined.
296     auto *predecessors = getOrCreate<PredecessorState>(call);
297     propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
298   }
299 }
300 
301 /// Get the constant values of the operands of an operation. If any of the
302 /// constant value lattices are uninitialized, return none to indicate the
303 /// analysis should bail out.
304 static Optional<SmallVector<Attribute>> getOperandValuesImpl(
305     Operation *op,
306     function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
307   SmallVector<Attribute> operands;
308   operands.reserve(op->getNumOperands());
309   for (Value operand : op->getOperands()) {
310     const Lattice<ConstantValue> *cv = getLattice(operand);
311     // If any of the operands' values are uninitialized, bail out.
312     if (cv->isUninitialized())
313       return {};
314     operands.push_back(cv->getValue().getConstantValue());
315   }
316   return operands;
317 }
318 
319 Optional<SmallVector<Attribute>>
320 DeadCodeAnalysis::getOperandValues(Operation *op) {
321   return getOperandValuesImpl(op, [&](Value value) {
322     auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
323     lattice->useDefSubscribe(this);
324     return lattice;
325   });
326 }
327 
328 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
329   // Try to deduce a single successor for the branch.
330   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
331   if (!operands)
332     return;
333 
334   if (Block *successor = branch.getSuccessorForOperands(*operands)) {
335     markEdgeLive(branch->getBlock(), successor);
336   } else {
337     // Otherwise, mark all successors as executable and outgoing edges.
338     for (Block *successor : branch->getSuccessors())
339       markEdgeLive(branch->getBlock(), successor);
340   }
341 }
342 
343 void DeadCodeAnalysis::visitRegionBranchOperation(
344     RegionBranchOpInterface branch) {
345   // Try to deduce which regions are executable.
346   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
347   if (!operands)
348     return;
349 
350   SmallVector<RegionSuccessor> successors;
351   branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
352   for (const RegionSuccessor &successor : successors) {
353     // The successor can be either an entry block or the parent operation.
354     ProgramPoint point = successor.getSuccessor()
355                              ? &successor.getSuccessor()->front()
356                              : ProgramPoint(branch);
357     // Mark the entry block as executable.
358     auto *state = getOrCreate<Executable>(point);
359     propagateIfChanged(state, state->setToLive());
360     // Add the parent op as a predecessor.
361     auto *predecessors = getOrCreate<PredecessorState>(point);
362     propagateIfChanged(
363         predecessors,
364         predecessors->join(branch, successor.getSuccessorInputs()));
365   }
366 }
367 
368 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
369                                              RegionBranchOpInterface branch) {
370   Optional<SmallVector<Attribute>> operands = getOperandValues(op);
371   if (!operands)
372     return;
373 
374   SmallVector<RegionSuccessor> successors;
375   branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
376                              *operands, successors);
377 
378   // Mark successor region entry blocks as executable and add this op to the
379   // list of predecessors.
380   for (const RegionSuccessor &successor : successors) {
381     PredecessorState *predecessors;
382     if (Region *region = successor.getSuccessor()) {
383       auto *state = getOrCreate<Executable>(&region->front());
384       propagateIfChanged(state, state->setToLive());
385       predecessors = getOrCreate<PredecessorState>(&region->front());
386     } else {
387       // Add this terminator as a predecessor to the parent op.
388       predecessors = getOrCreate<PredecessorState>(branch);
389     }
390     propagateIfChanged(predecessors,
391                        predecessors->join(op, successor.getSuccessorInputs()));
392   }
393 }
394 
395 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
396                                                CallableOpInterface callable) {
397   // If there are no exiting values, we have nothing to do.
398   if (op->getNumOperands() == 0)
399     return;
400 
401   // Add as predecessors to all callsites this return op.
402   auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
403   bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
404   for (Operation *predecessor : callsites->getKnownPredecessors()) {
405     assert(isa<CallOpInterface>(predecessor));
406     auto *predecessors = getOrCreate<PredecessorState>(predecessor);
407     if (canResolve) {
408       propagateIfChanged(predecessors, predecessors->join(op));
409     } else {
410       // If the terminator is not a return-like, then conservatively assume we
411       // can't resolve the predecessor.
412       propagateIfChanged(predecessors,
413                          predecessors->setHasUnknownPredecessors());
414     }
415   }
416 }
417