//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 
13 using namespace mlir;
14 using namespace mlir::dataflow;
15 
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
19 
onUpdate(DataFlowSolver * solver) const20 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
21   // Push all users of the value to the queue.
22   for (Operation *user : point.get<Value>().getUsers())
23     for (DataFlowAnalysis *analysis : useDefSubscribers)
24       solver->enqueue({user, analysis});
25 }
26 
//===----------------------------------------------------------------------===//
// AbstractSparseDataFlowAnalysis
//===----------------------------------------------------------------------===//
30 
AbstractSparseDataFlowAnalysis(DataFlowSolver & solver)31 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
32     DataFlowSolver &solver)
33     : DataFlowAnalysis(solver) {
34   registerPointKind<CFGEdge>();
35 }
36 
initialize(Operation * top)37 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
38   // Mark the entry block arguments as having reached their pessimistic
39   // fixpoints.
40   for (Region &region : top->getRegions()) {
41     if (region.empty())
42       continue;
43     for (Value argument : region.front().getArguments())
44       markAllPessimisticFixpoint(getLatticeElement(argument));
45   }
46 
47   return initializeRecursively(top);
48 }
49 
50 LogicalResult
initializeRecursively(Operation * op)51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
52   // Initialize the analysis by visiting every owner of an SSA value (all
53   // operations and blocks).
54   visitOperation(op);
55   for (Region &region : op->getRegions()) {
56     for (Block &block : region) {
57       getOrCreate<Executable>(&block)->blockContentSubscribe(this);
58       visitBlock(&block);
59       for (Operation &op : block)
60         if (failed(initializeRecursively(&op)))
61           return failure();
62     }
63   }
64 
65   return success();
66 }
67 
visit(ProgramPoint point)68 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
69   if (Operation *op = point.dyn_cast<Operation *>())
70     visitOperation(op);
71   else if (Block *block = point.dyn_cast<Block *>())
72     visitBlock(block);
73   else
74     return failure();
75   return success();
76 }
77 
visitOperation(Operation * op)78 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
79   // Exit early on operations with no results.
80   if (op->getNumResults() == 0)
81     return;
82 
83   // If the containing block is not executable, bail out.
84   if (!getOrCreate<Executable>(op->getBlock())->isLive())
85     return;
86 
87   // Get the result lattices.
88   SmallVector<AbstractSparseLattice *> resultLattices;
89   resultLattices.reserve(op->getNumResults());
90   // Track whether all results have reached their fixpoint.
91   bool allAtFixpoint = true;
92   for (Value result : op->getResults()) {
93     AbstractSparseLattice *resultLattice = getLatticeElement(result);
94     allAtFixpoint &= resultLattice->isAtFixpoint();
95     resultLattices.push_back(resultLattice);
96   }
97   // If all result lattices have reached a fixpoint, there is nothing to do.
98   if (allAtFixpoint)
99     return;
100 
101   // The results of a region branch operation are determined by control-flow.
102   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
103     return visitRegionSuccessors({branch}, branch,
104                                  /*successorIndex=*/llvm::None, resultLattices);
105   }
106 
107   // The results of a call operation are determined by the callgraph.
108   if (auto call = dyn_cast<CallOpInterface>(op)) {
109     const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
110     // If not all return sites are known, then conservatively assume we can't
111     // reason about the data-flow.
112     if (!predecessors->allPredecessorsKnown())
113       return markAllPessimisticFixpoint(resultLattices);
114     for (Operation *predecessor : predecessors->getKnownPredecessors())
115       for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
116         join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
117     return;
118   }
119 
120   // Grab the lattice elements of the operands.
121   SmallVector<const AbstractSparseLattice *> operandLattices;
122   operandLattices.reserve(op->getNumOperands());
123   for (Value operand : op->getOperands()) {
124     AbstractSparseLattice *operandLattice = getLatticeElement(operand);
125     operandLattice->useDefSubscribe(this);
126     // If any of the operand states are not initialized, bail out.
127     if (operandLattice->isUninitialized())
128       return;
129     operandLattices.push_back(operandLattice);
130   }
131 
132   // Invoke the operation transfer function.
133   visitOperationImpl(op, operandLattices, resultLattices);
134 }
135 
visitBlock(Block * block)136 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
137   // Exit early on blocks with no arguments.
138   if (block->getNumArguments() == 0)
139     return;
140 
141   // If the block is not executable, bail out.
142   if (!getOrCreate<Executable>(block)->isLive())
143     return;
144 
145   // Get the argument lattices.
146   SmallVector<AbstractSparseLattice *> argLattices;
147   argLattices.reserve(block->getNumArguments());
148   bool allAtFixpoint = true;
149   for (BlockArgument argument : block->getArguments()) {
150     AbstractSparseLattice *argLattice = getLatticeElement(argument);
151     allAtFixpoint &= argLattice->isAtFixpoint();
152     argLattices.push_back(argLattice);
153   }
154   // If all argument lattices have reached their fixpoints, then there is
155   // nothing to do.
156   if (allAtFixpoint)
157     return;
158 
159   // The argument lattices of entry blocks are set by region control-flow or the
160   // callgraph.
161   if (block->isEntryBlock()) {
162     // Check if this block is the entry block of a callable region.
163     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
164     if (callable && callable.getCallableRegion() == block->getParent()) {
165       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
166       // If not all callsites are known, conservatively mark all lattices as
167       // having reached their pessimistic fixpoints.
168       if (!callsites->allPredecessorsKnown())
169         return markAllPessimisticFixpoint(argLattices);
170       for (Operation *callsite : callsites->getKnownPredecessors()) {
171         auto call = cast<CallOpInterface>(callsite);
172         for (auto it : llvm::zip(call.getArgOperands(), argLattices))
173           join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
174       }
175       return;
176     }
177 
178     // Check if the lattices can be determined from region control flow.
179     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
180       return visitRegionSuccessors(
181           block, branch, block->getParent()->getRegionNumber(), argLattices);
182     }
183 
184     // Otherwise, we can't reason about the data-flow.
185     return visitNonControlFlowArgumentsImpl(block->getParentOp(),
186                                             RegionSuccessor(block->getParent()),
187                                             argLattices, /*firstIndex=*/0);
188   }
189 
190   // Iterate over the predecessors of the non-entry block.
191   for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
192        it != e; ++it) {
193     Block *predecessor = *it;
194 
195     // If the edge from the predecessor block to the current block is not live,
196     // bail out.
197     auto *edgeExecutable =
198         getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
199     edgeExecutable->blockContentSubscribe(this);
200     if (!edgeExecutable->isLive())
201       continue;
202 
203     // Check if we can reason about the data-flow from the predecessor.
204     if (auto branch =
205             dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
206       SuccessorOperands operands =
207           branch.getSuccessorOperands(it.getSuccessorIndex());
208       for (auto &it : llvm::enumerate(argLattices)) {
209         if (Value operand = operands[it.index()]) {
210           join(it.value(), *getLatticeElementFor(block, operand));
211         } else {
212           // Conservatively mark internally produced arguments as having reached
213           // their pessimistic fixpoint.
214           markAllPessimisticFixpoint(it.value());
215         }
216       }
217     } else {
218       return markAllPessimisticFixpoint(argLattices);
219     }
220   }
221 }
222 
visitRegionSuccessors(ProgramPoint point,RegionBranchOpInterface branch,Optional<unsigned> successorIndex,ArrayRef<AbstractSparseLattice * > lattices)223 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
224     ProgramPoint point, RegionBranchOpInterface branch,
225     Optional<unsigned> successorIndex,
226     ArrayRef<AbstractSparseLattice *> lattices) {
227   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
228   assert(predecessors->allPredecessorsKnown() &&
229          "unexpected unresolved region successors");
230 
231   for (Operation *op : predecessors->getKnownPredecessors()) {
232     // Get the incoming successor operands.
233     Optional<OperandRange> operands;
234 
235     // Check if the predecessor is the parent op.
236     if (op == branch) {
237       operands = branch.getSuccessorEntryOperands(successorIndex);
238       // Otherwise, try to deduce the operands from a region return-like op.
239     } else {
240       if (isRegionReturnLike(op))
241         operands = getRegionBranchSuccessorOperands(op, successorIndex);
242     }
243 
244     if (!operands) {
245       // We can't reason about the data-flow.
246       return markAllPessimisticFixpoint(lattices);
247     }
248 
249     ValueRange inputs = predecessors->getSuccessorInputs(op);
250     assert(inputs.size() == operands->size() &&
251            "expected the same number of successor inputs as operands");
252 
253     unsigned firstIndex = 0;
254     if (inputs.size() != lattices.size()) {
255       if (auto *op = point.dyn_cast<Operation *>()) {
256         if (!inputs.empty())
257           firstIndex = inputs.front().cast<OpResult>().getResultNumber();
258         visitNonControlFlowArgumentsImpl(
259             branch,
260             RegionSuccessor(
261                 branch->getResults().slice(firstIndex, inputs.size())),
262             lattices, firstIndex);
263       } else {
264         if (!inputs.empty())
265           firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
266         Region *region = point.get<Block *>()->getParent();
267         visitNonControlFlowArgumentsImpl(
268             branch,
269             RegionSuccessor(region, region->getArguments().slice(
270                                         firstIndex, inputs.size())),
271             lattices, firstIndex);
272       }
273     }
274 
275     for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
276       join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
277   }
278 }
279 
280 const AbstractSparseLattice *
getLatticeElementFor(ProgramPoint point,Value value)281 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
282                                                      Value value) {
283   AbstractSparseLattice *state = getLatticeElement(value);
284   addDependency(state, point);
285   return state;
286 }
287 
markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice * > lattices)288 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
289     ArrayRef<AbstractSparseLattice *> lattices) {
290   for (AbstractSparseLattice *lattice : lattices)
291     propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
292 }
293 
join(AbstractSparseLattice * lhs,const AbstractSparseLattice & rhs)294 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
295                                           const AbstractSparseLattice &rhs) {
296   propagateIfChanged(lhs, lhs->join(rhs));
297 }
298