1 //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 //===----------------------------------------------------------------------===//
18 // Executable
19 //===----------------------------------------------------------------------===//
20 
setToLive()21 ChangeResult Executable::setToLive() {
22   if (live)
23     return ChangeResult::NoChange;
24   live = true;
25   return ChangeResult::Change;
26 }
27 
print(raw_ostream & os) const28 void Executable::print(raw_ostream &os) const {
29   os << (live ? "live" : "dead");
30 }
31 
onUpdate(DataFlowSolver * solver) const32 void Executable::onUpdate(DataFlowSolver *solver) const {
33   if (auto *block = point.dyn_cast<Block *>()) {
34     // Re-invoke the analyses on the block itself.
35     for (DataFlowAnalysis *analysis : subscribers)
36       solver->enqueue({block, analysis});
37     // Re-invoke the analyses on all operations in the block.
38     for (DataFlowAnalysis *analysis : subscribers)
39       for (Operation &op : *block)
40         solver->enqueue({&op, analysis});
41   } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
42     // Re-invoke the analysis on the successor block.
43     if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
44       for (DataFlowAnalysis *analysis : subscribers)
45         solver->enqueue({edge->getTo(), analysis});
46     }
47   }
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // PredecessorState
52 //===----------------------------------------------------------------------===//
53 
print(raw_ostream & os) const54 void PredecessorState::print(raw_ostream &os) const {
55   if (allPredecessorsKnown())
56     os << "(all) ";
57   os << "predecessors:\n";
58   for (Operation *op : getKnownPredecessors())
59     os << "  " << *op << "\n";
60 }
61 
join(Operation * predecessor)62 ChangeResult PredecessorState::join(Operation *predecessor) {
63   return knownPredecessors.insert(predecessor) ? ChangeResult::Change
64                                                : ChangeResult::NoChange;
65 }
66 
join(Operation * predecessor,ValueRange inputs)67 ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
68   ChangeResult result = join(predecessor);
69   if (!inputs.empty()) {
70     ValueRange &curInputs = successorInputs[predecessor];
71     if (curInputs != inputs) {
72       curInputs = inputs;
73       result |= ChangeResult::Change;
74     }
75   }
76   return result;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // CFGEdge
81 //===----------------------------------------------------------------------===//
82 
getLoc() const83 Location CFGEdge::getLoc() const {
84   return FusedLoc::get(
85       getFrom()->getParent()->getContext(),
86       {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
87 }
88 
print(raw_ostream & os) const89 void CFGEdge::print(raw_ostream &os) const {
90   getFrom()->print(os);
91   os << "\n -> \n";
92   getTo()->print(os);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // DeadCodeAnalysis
97 //===----------------------------------------------------------------------===//
98 
DeadCodeAnalysis(DataFlowSolver & solver)99 DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
100     : DataFlowAnalysis(solver) {
101   registerPointKind<CFGEdge>();
102 }
103 
initialize(Operation * top)104 LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
105   // Mark the top-level blocks as executable.
106   for (Region &region : top->getRegions()) {
107     if (region.empty())
108       continue;
109     auto *state = getOrCreate<Executable>(&region.front());
110     propagateIfChanged(state, state->setToLive());
111   }
112 
113   // Mark as overdefined the predecessors of symbol callables with potentially
114   // unknown predecessors.
115   initializeSymbolCallables(top);
116 
117   return initializeRecursively(top);
118 }
119 
initializeSymbolCallables(Operation * top)120 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
121   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
122     Region &symbolTableRegion = symTable->getRegion(0);
123     Block *symbolTableBlock = &symbolTableRegion.front();
124 
125     bool foundSymbolCallable = false;
126     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
127       Region *callableRegion = callable.getCallableRegion();
128       if (!callableRegion)
129         continue;
130       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
131       if (!symbol)
132         continue;
133 
134       // Public symbol callables or those for which we can't see all uses have
135       // potentially unknown callsites.
136       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
137         auto *state = getOrCreate<PredecessorState>(callable);
138         propagateIfChanged(state, state->setHasUnknownPredecessors());
139       }
140       foundSymbolCallable = true;
141     }
142 
143     // Exit early if no eligible symbol callables were found in the table.
144     if (!foundSymbolCallable)
145       return;
146 
147     // Walk the symbol table to check for non-call uses of symbols.
148     Optional<SymbolTable::UseRange> uses =
149         SymbolTable::getSymbolUses(&symbolTableRegion);
150     if (!uses) {
151       // If we couldn't gather the symbol uses, conservatively assume that
152       // we can't track information for any nested symbols.
153       return top->walk([&](CallableOpInterface callable) {
154         auto *state = getOrCreate<PredecessorState>(callable);
155         propagateIfChanged(state, state->setHasUnknownPredecessors());
156       });
157     }
158 
159     for (const SymbolTable::SymbolUse &use : *uses) {
160       if (isa<CallOpInterface>(use.getUser()))
161         continue;
162       // If a callable symbol has a non-call use, then we can't be guaranteed to
163       // know all callsites.
164       Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
165       auto *state = getOrCreate<PredecessorState>(symbol);
166       propagateIfChanged(state, state->setHasUnknownPredecessors());
167     }
168   };
169   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
170                                 walkFn);
171 }
172 
173 /// Returns true if the operation is a returning terminator in region
174 /// control-flow or the terminator of a callable region.
isRegionOrCallableReturn(Operation * op)175 static bool isRegionOrCallableReturn(Operation *op) {
176   return !op->getNumSuccessors() &&
177          isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
178          op->getBlock()->getTerminator() == op;
179 }
180 
initializeRecursively(Operation * op)181 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
182   // Initialize the analysis by visiting every op with control-flow semantics.
183   if (op->getNumRegions() || op->getNumSuccessors() ||
184       isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
185     // When the liveness of the parent block changes, make sure to re-invoke the
186     // analysis on the op.
187     if (op->getBlock())
188       getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
189     // Visit the op.
190     if (failed(visit(op)))
191       return failure();
192   }
193   // Recurse on nested operations.
194   for (Region &region : op->getRegions())
195     for (Operation &op : region.getOps())
196       if (failed(initializeRecursively(&op)))
197         return failure();
198   return success();
199 }
200 
markEdgeLive(Block * from,Block * to)201 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
202   auto *state = getOrCreate<Executable>(to);
203   propagateIfChanged(state, state->setToLive());
204   auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
205   propagateIfChanged(edgeState, edgeState->setToLive());
206 }
207 
markEntryBlocksLive(Operation * op)208 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
209   for (Region &region : op->getRegions()) {
210     if (region.empty())
211       continue;
212     auto *state = getOrCreate<Executable>(&region.front());
213     propagateIfChanged(state, state->setToLive());
214   }
215 }
216 
visit(ProgramPoint point)217 LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
218   if (point.is<Block *>())
219     return success();
220   auto *op = point.dyn_cast<Operation *>();
221   if (!op)
222     return emitError(point.getLoc(), "unknown program point kind");
223 
224   // If the parent block is not executable, there is nothing to do.
225   if (!getOrCreate<Executable>(op->getBlock())->isLive())
226     return success();
227 
228   // We have a live call op. Add this as a live predecessor of the callee.
229   if (auto call = dyn_cast<CallOpInterface>(op))
230     visitCallOperation(call);
231 
232   // Visit the regions.
233   if (op->getNumRegions()) {
234     // Check if we can reason about the region control-flow.
235     if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
236       visitRegionBranchOperation(branch);
237 
238       // Check if this is a callable operation.
239     } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
240       const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
241 
242       // If the callsites could not be resolved or are known to be non-empty,
243       // mark the callable as executable.
244       if (!callsites->allPredecessorsKnown() ||
245           !callsites->getKnownPredecessors().empty())
246         markEntryBlocksLive(callable);
247 
248       // Otherwise, conservatively mark all entry blocks as executable.
249     } else {
250       markEntryBlocksLive(op);
251     }
252   }
253 
254   if (isRegionOrCallableReturn(op)) {
255     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
256       // Visit the exiting terminator of a region.
257       visitRegionTerminator(op, branch);
258     } else if (auto callable =
259                    dyn_cast<CallableOpInterface>(op->getParentOp())) {
260       // Visit the exiting terminator of a callable.
261       visitCallableTerminator(op, callable);
262     }
263   }
264   // Visit the successors.
265   if (op->getNumSuccessors()) {
266     // Check if we can reason about the control-flow.
267     if (auto branch = dyn_cast<BranchOpInterface>(op)) {
268       visitBranchOperation(branch);
269 
270       // Otherwise, conservatively mark all successors as exectuable.
271     } else {
272       for (Block *successor : op->getSuccessors())
273         markEdgeLive(op->getBlock(), successor);
274     }
275   }
276 
277   return success();
278 }
279 
visitCallOperation(CallOpInterface call)280 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
281   Operation *callableOp = nullptr;
282   if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
283     callableOp = callableValue.getDefiningOp();
284   else
285     callableOp = call.resolveCallable(&symbolTable);
286 
287   // A call to a externally-defined callable has unknown predecessors.
288   const auto isExternalCallable = [](Operation *op) {
289     if (auto callable = dyn_cast<CallableOpInterface>(op))
290       return !callable.getCallableRegion();
291     return false;
292   };
293 
294   // TODO: Add support for non-symbol callables when necessary. If the
295   // callable has non-call uses we would mark as having reached pessimistic
296   // fixpoint, otherwise allow for propagating the return values out.
297   if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
298       !isExternalCallable(callableOp)) {
299     // Add the live callsite.
300     auto *callsites = getOrCreate<PredecessorState>(callableOp);
301     propagateIfChanged(callsites, callsites->join(call));
302   } else {
303     // Mark this call op's predecessors as overdefined.
304     auto *predecessors = getOrCreate<PredecessorState>(call);
305     propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
306   }
307 }
308 
309 /// Get the constant values of the operands of an operation. If any of the
310 /// constant value lattices are uninitialized, return none to indicate the
311 /// analysis should bail out.
getOperandValuesImpl(Operation * op,function_ref<const Lattice<ConstantValue> * (Value)> getLattice)312 static Optional<SmallVector<Attribute>> getOperandValuesImpl(
313     Operation *op,
314     function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
315   SmallVector<Attribute> operands;
316   operands.reserve(op->getNumOperands());
317   for (Value operand : op->getOperands()) {
318     const Lattice<ConstantValue> *cv = getLattice(operand);
319     // If any of the operands' values are uninitialized, bail out.
320     if (cv->isUninitialized())
321       return {};
322     operands.push_back(cv->getValue().getConstantValue());
323   }
324   return operands;
325 }
326 
327 Optional<SmallVector<Attribute>>
getOperandValues(Operation * op)328 DeadCodeAnalysis::getOperandValues(Operation *op) {
329   return getOperandValuesImpl(op, [&](Value value) {
330     auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
331     lattice->useDefSubscribe(this);
332     return lattice;
333   });
334 }
335 
visitBranchOperation(BranchOpInterface branch)336 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
337   // Try to deduce a single successor for the branch.
338   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
339   if (!operands)
340     return;
341 
342   if (Block *successor = branch.getSuccessorForOperands(*operands)) {
343     markEdgeLive(branch->getBlock(), successor);
344   } else {
345     // Otherwise, mark all successors as executable and outgoing edges.
346     for (Block *successor : branch->getSuccessors())
347       markEdgeLive(branch->getBlock(), successor);
348   }
349 }
350 
visitRegionBranchOperation(RegionBranchOpInterface branch)351 void DeadCodeAnalysis::visitRegionBranchOperation(
352     RegionBranchOpInterface branch) {
353   // Try to deduce which regions are executable.
354   Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
355   if (!operands)
356     return;
357 
358   SmallVector<RegionSuccessor> successors;
359   branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
360   for (const RegionSuccessor &successor : successors) {
361     // The successor can be either an entry block or the parent operation.
362     ProgramPoint point = successor.getSuccessor()
363                              ? &successor.getSuccessor()->front()
364                              : ProgramPoint(branch);
365     // Mark the entry block as executable.
366     auto *state = getOrCreate<Executable>(point);
367     propagateIfChanged(state, state->setToLive());
368     // Add the parent op as a predecessor.
369     auto *predecessors = getOrCreate<PredecessorState>(point);
370     propagateIfChanged(
371         predecessors,
372         predecessors->join(branch, successor.getSuccessorInputs()));
373   }
374 }
375 
visitRegionTerminator(Operation * op,RegionBranchOpInterface branch)376 void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
377                                              RegionBranchOpInterface branch) {
378   Optional<SmallVector<Attribute>> operands = getOperandValues(op);
379   if (!operands)
380     return;
381 
382   SmallVector<RegionSuccessor> successors;
383   branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
384                              *operands, successors);
385 
386   // Mark successor region entry blocks as executable and add this op to the
387   // list of predecessors.
388   for (const RegionSuccessor &successor : successors) {
389     PredecessorState *predecessors;
390     if (Region *region = successor.getSuccessor()) {
391       auto *state = getOrCreate<Executable>(&region->front());
392       propagateIfChanged(state, state->setToLive());
393       predecessors = getOrCreate<PredecessorState>(&region->front());
394     } else {
395       // Add this terminator as a predecessor to the parent op.
396       predecessors = getOrCreate<PredecessorState>(branch);
397     }
398     propagateIfChanged(predecessors,
399                        predecessors->join(op, successor.getSuccessorInputs()));
400   }
401 }
402 
visitCallableTerminator(Operation * op,CallableOpInterface callable)403 void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
404                                                CallableOpInterface callable) {
405   // If there are no exiting values, we have nothing to do.
406   if (op->getNumOperands() == 0)
407     return;
408 
409   // Add as predecessors to all callsites this return op.
410   auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
411   bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
412   for (Operation *predecessor : callsites->getKnownPredecessors()) {
413     assert(isa<CallOpInterface>(predecessor));
414     auto *predecessors = getOrCreate<PredecessorState>(predecessor);
415     if (canResolve) {
416       propagateIfChanged(predecessors, predecessors->join(op));
417     } else {
418       // If the terminator is not a return-like, then conservatively assume we
419       // can't resolve the predecessor.
420       propagateIfChanged(predecessors,
421                          predecessors->setHasUnknownPredecessors());
422     }
423   }
424 }
425