1 //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 //===----------------------------------------------------------------------===//
18 // Executable
19 //===----------------------------------------------------------------------===//
20 
21 ChangeResult Executable::setToLive() {
22   if (live)
23     return ChangeResult::NoChange;
24   live = true;
25   return ChangeResult::Change;
26 }
27 
28 void Executable::print(raw_ostream &os) const {
29   os << (live ? "live" : "dead");
30 }
31 
32 void Executable::onUpdate(DataFlowSolver *solver) const {
33   if (auto *block = point.dyn_cast<Block *>()) {
34     // Re-invoke the analyses on the block itself.
35     for (DataFlowAnalysis *analysis : subscribers)
36       solver->enqueue({block, analysis});
37     // Re-invoke the analyses on all operations in the block.
38     for (DataFlowAnalysis *analysis : subscribers)
39       for (Operation &op : *block)
40         solver->enqueue({&op, analysis});
41   } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
42     // Re-invoke the analysis on the successor block.
43     if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
44       for (DataFlowAnalysis *analysis : subscribers)
45         solver->enqueue({edge->getTo(), analysis});
46     }
47   }
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // PredecessorState
52 //===----------------------------------------------------------------------===//
53 
54 void PredecessorState::print(raw_ostream &os) const {
55   if (allPredecessorsKnown())
56     os << "(all) ";
57   os << "predecessors:\n";
58   for (Operation *op : getKnownPredecessors())
59     os << "  " << *op << "\n";
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // CFGEdge
64 //===----------------------------------------------------------------------===//
65 
66 Location CFGEdge::getLoc() const {
67   return FusedLoc::get(
68       getFrom()->getParent()->getContext(),
69       {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
70 }
71 
72 void CFGEdge::print(raw_ostream &os) const {
73   getFrom()->print(os);
74   os << "\n -> \n";
75   getTo()->print(os);
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // DeadCodeAnalysis
80 //===----------------------------------------------------------------------===//
81 
82 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
83     : DataFlowAnalysis(solver) {
84   registerPointKind<CFGEdge>();
85 }
86 
87 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
88   // Mark the top-level blocks as executable.
89   for (Region &region : top->getRegions()) {
90     if (region.empty())
91       continue;
92     auto *state = getOrCreate<Executable>(&region.front());
93     propagateIfChanged(state, state->setToLive());
94   }
95 
96   // Mark as overdefined the predecessors of symbol callables with potentially
97   // unknown predecessors.
98   initializeSymbolCallables(top);
99 
100   return initializeRecursively(top);
101 }
102 
103 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
104   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
105     Region &symbolTableRegion = symTable->getRegion(0);
106     Block *symbolTableBlock = &symbolTableRegion.front();
107 
108     bool foundSymbolCallable = false;
109     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
110       Region *callableRegion = callable.getCallableRegion();
111       if (!callableRegion)
112         continue;
113       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
114       if (!symbol)
115         continue;
116 
117       // Public symbol callables or those for which we can't see all uses have
118       // potentially unknown callsites.
119       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
120         auto *state = getOrCreate<PredecessorState>(callable);
121         propagateIfChanged(state, state->setHasUnknownPredecessors());
122       }
123       foundSymbolCallable = true;
124     }
125 
126     // Exit early if no eligible symbol callables were found in the table.
127     if (!foundSymbolCallable)
128       return;
129 
130     // Walk the symbol table to check for non-call uses of symbols.
131     Optional<SymbolTable::UseRange> uses =
132         SymbolTable::getSymbolUses(&symbolTableRegion);
133     if (!uses) {
134       // If we couldn't gather the symbol uses, conservatively assume that
135       // we can't track information for any nested symbols.
136       return top->walk([&](CallableOpInterface callable) {
137         auto *state = getOrCreate<PredecessorState>(callable);
138         propagateIfChanged(state, state->setHasUnknownPredecessors());
139       });
140     }
141 
142     for (const SymbolTable::SymbolUse &use : *uses) {
143       if (isa<CallOpInterface>(use.getUser()))
144         continue;
145       // If a callable symbol has a non-call use, then we can't be guaranteed to
146       // know all callsites.
147       Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
148       auto *state = getOrCreate<PredecessorState>(symbol);
149       propagateIfChanged(state, state->setHasUnknownPredecessors());
150     }
151   };
152   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
153                                 walkFn);
154 }
155 
156 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
157   // Initialize the analysis by visiting every op with control-flow semantics.
158   if (op->getNumRegions() || op->getNumSuccessors() ||
159       op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
160     // When the liveness of the parent block changes, make sure to re-invoke the
161     // analysis on the op.
162     if (op->getBlock())
163       getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
164     // Visit the op.
165     if (failed(visit(op)))
166       return failure();
167   }
168   // Recurse on nested operations.
169   for (Region &region : op->getRegions())
170     for (Operation &op : region.getOps())
171       if (failed(initializeRecursively(&op)))
172         return failure();
173   return success();
174 }
175 
176 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
177   auto *state = getOrCreate<Executable>(to);
178   propagateIfChanged(state, state->setToLive());
179   auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
180   propagateIfChanged(edgeState, edgeState->setToLive());
181 }
182 
183 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
184   for (Region &region : op->getRegions()) {
185     if (region.empty())
186       continue;
187     auto *state = getOrCreate<Executable>(&region.front());
188     propagateIfChanged(state, state->setToLive());
189   }
190 }
191 
192 LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
193   if (point.is<Block *>())
194     return success();
195   auto *op = point.dyn_cast<Operation *>();
196   if (!op)
197     return emitError(point.getLoc(), "unknown program point kind");
198 
199   // If the parent block is not executable, there is nothing to do.
200   if (!getOrCreate<Executable>(op->getBlock())->isLive())
201     return success();
202 
203   // We have a live call op. Add this as a live predecessor of the callee.
204   if (auto call = dyn_cast<CallOpInterface>(op))
205     visitCallOperation(call);
206 
207   // Visit the regions.
208   if (op->getNumRegions()) {
209     // Check if we can reason about the region control-flow.
210     if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
211       visitRegionBranchOperation(branch);
212 
213       // Check if this is a callable operation.
214     } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
215       const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
216 
217       // If the callsites could not be resolved or are known to be non-empty,
218       // mark the callable as executable.
219       if (!callsites->allPredecessorsKnown() ||
220           !callsites->getKnownPredecessors().empty())
221         markEntryBlocksLive(callable);
222 
223       // Otherwise, conservatively mark all entry blocks as executable.
224     } else {
225       markEntryBlocksLive(op);
226     }
227   }
228 
229   if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
230     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
231       // Visit the exiting terminator of a region.
232       visitRegionTerminator(op, branch);
233     } else if (auto callable =
234                    dyn_cast<CallableOpInterface>(op->getParentOp())) {
235       // Visit the exiting terminator of a callable.
236       visitCallableTerminator(op, callable);
237     }
238   }
239   // Visit the successors.
240   if (op->getNumSuccessors()) {
241     // Check if we can reason about the control-flow.
242     if (auto branch = dyn_cast<BranchOpInterface>(op)) {
243       visitBranchOperation(branch);
244 
245       // Otherwise, conservatively mark all successors as exectuable.
246     } else {
247       for (Block *successor : op->getSuccessors())
248         markEdgeLive(op->getBlock(), successor);
249     }
250   }
251 
252   return success();
253 }
254 
255 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
256   Operation *callableOp = nullptr;
257   if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
258     callableOp = callableValue.getDefiningOp();
259   else
260     callableOp = call.resolveCallable(&symbolTable);
261 
262   // A call to a externally-defined callable has unknown predecessors.
263   const auto isExternalCallable = [](Operation *op) {
264     if (auto callable = dyn_cast<CallableOpInterface>(op))
265       return !callable.getCallableRegion();
266     return false;
267   };
268 
269   // TODO: Add support for non-symbol callables when necessary. If the
270   // callable has non-call uses we would mark as having reached pessimistic
271   // fixpoint, otherwise allow for propagating the return values out.
272   if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
273       !isExternalCallable(callableOp)) {
274     // Add the live callsite.
275     auto *callsites = getOrCreate<PredecessorState>(callableOp);
276     propagateIfChanged(callsites, callsites->join(call));
277   } else {
278     // Mark this call op's predecessors as overdefined.
279     auto *predecessors = getOrCreate<PredecessorState>(call);
280     propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
281   }
282 }
283 
284 /// Get the constant values of the operands of an operation. If any of the
285 /// constant value lattices are uninitialized, return none to indicate the
286 /// analysis should bail out.
287 static Optional<SmallVector<Attribute>> getOperandValuesImpl(
288     Operation *op,
289     function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
290   SmallVector<Attribute> operands;
291   operands.reserve(op->getNumOperands());
292   for (Value operand : op->getOperands()) {
293     const Lattice<ConstantValue> *cv = getLattice(operand);
294     // If any of the operands' values are uninitialized, bail out.
295     if (cv->isUninitialized())
296       return {};
297     operands.push_back(cv->getValue().getConstantValue());
298   }
299   return operands;
300 }
301 
302 Optional<SmallVector<Attribute>>
303 DeadCodeAnalysis::getOperandValues(Operation *op) {
304   return getOperandValuesImpl(op, [&](Value value) {
305     auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
306     lattice->useDefSubscribe(this);
307     return lattice;
308   });
309 }
310 
311 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
312   // Try to deduce a single successor for the branch.
313   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
314   if (!operands)
315     return;
316 
317   if (Block *successor = branch.getSuccessorForOperands(*operands)) {
318     markEdgeLive(branch->getBlock(), successor);
319   } else {
320     // Otherwise, mark all successors as executable and outgoing edges.
321     for (Block *successor : branch->getSuccessors())
322       markEdgeLive(branch->getBlock(), successor);
323   }
324 }
325 
326 void DeadCodeAnalysis::visitRegionBranchOperation(
327     RegionBranchOpInterface branch) {
328   // Try to deduce which regions are executable.
329   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
330   if (!operands)
331     return;
332 
333   SmallVector<RegionSuccessor> successors;
334   branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
335   for (const RegionSuccessor &successor : successors) {
336     // Mark the entry block as executable.
337     Region *region = successor.getSuccessor();
338     assert(region && "expected a region successor");
339     auto *state = getOrCreate<Executable>(&region->front());
340     propagateIfChanged(state, state->setToLive());
341     // Add the parent op as a predecessor.
342     auto *predecessors = getOrCreate<PredecessorState>(&region->front());
343     propagateIfChanged(predecessors, predecessors->join(branch));
344   }
345 }
346 
347 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
348                                              RegionBranchOpInterface branch) {
349   Optional<SmallVector<Attribute>> operands = getOperandValues(op);
350   if (!operands)
351     return;
352 
353   SmallVector<RegionSuccessor> successors;
354   branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
355                              *operands, successors);
356 
357   // Mark successor region entry blocks as executable and add this op to the
358   // list of predecessors.
359   for (const RegionSuccessor &successor : successors) {
360     PredecessorState *predecessors;
361     if (Region *region = successor.getSuccessor()) {
362       auto *state = getOrCreate<Executable>(&region->front());
363       propagateIfChanged(state, state->setToLive());
364       predecessors = getOrCreate<PredecessorState>(&region->front());
365     } else {
366       // Add this terminator as a predecessor to the parent op.
367       predecessors = getOrCreate<PredecessorState>(branch);
368     }
369     propagateIfChanged(predecessors, predecessors->join(op));
370   }
371 }
372 
373 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
374                                                CallableOpInterface callable) {
375   // If there are no exiting values, we have nothing to do.
376   if (op->getNumOperands() == 0)
377     return;
378 
379   // Add as predecessors to all callsites this return op.
380   auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
381   bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
382   for (Operation *predecessor : callsites->getKnownPredecessors()) {
383     assert(isa<CallOpInterface>(predecessor));
384     auto *predecessors = getOrCreate<PredecessorState>(predecessor);
385     if (canResolve) {
386       propagateIfChanged(predecessors, predecessors->join(op));
387     } else {
388       // If the terminator is not a return-like, then conservatively assume we
389       // can't resolve the predecessor.
390       propagateIfChanged(predecessors,
391                          predecessors->setHasUnknownPredecessors());
392     }
393   }
394 }
395