1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 
14 using namespace mlir;
15 using namespace mlir::dataflow;
16 
17 //===----------------------------------------------------------------------===//
18 // AbstractSparseLattice
19 //===----------------------------------------------------------------------===//
20 
21 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
22   // Push all users of the value to the queue.
23   for (Operation *user : point.get<Value>().getUsers())
24     for (DataFlowAnalysis *analysis : useDefSubscribers)
25       solver->enqueue({user, analysis});
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // AbstractSparseDataFlowAnalysis
30 //===----------------------------------------------------------------------===//
31 
32 AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
33     DataFlowSolver &solver)
34     : DataFlowAnalysis(solver) {
35   registerPointKind<CFGEdge>();
36 }
37 
38 LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
39   // Mark the entry block arguments as having reached their pessimistic
40   // fixpoints.
41   for (Region &region : top->getRegions()) {
42     if (region.empty())
43       continue;
44     for (Value argument : region.front().getArguments())
45       markAllPessimisticFixpoint(getLatticeElement(argument));
46   }
47 
48   return initializeRecursively(top);
49 }
50 
51 LogicalResult
52 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
53   // Initialize the analysis by visiting every owner of an SSA value (all
54   // operations and blocks).
55   visitOperation(op);
56   for (Region &region : op->getRegions()) {
57     for (Block &block : region) {
58       getOrCreate<Executable>(&block)->blockContentSubscribe(this);
59       visitBlock(&block);
60       for (Operation &op : block)
61         if (failed(initializeRecursively(&op)))
62           return failure();
63     }
64   }
65 
66   return success();
67 }
68 
69 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
70   if (Operation *op = point.dyn_cast<Operation *>())
71     visitOperation(op);
72   else if (Block *block = point.dyn_cast<Block *>())
73     visitBlock(block);
74   else
75     return failure();
76   return success();
77 }
78 
79 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
80   // Exit early on operations with no results.
81   if (op->getNumResults() == 0)
82     return;
83 
84   // If the containing block is not executable, bail out.
85   if (!getOrCreate<Executable>(op->getBlock())->isLive())
86     return;
87 
88   // Get the result lattices.
89   SmallVector<AbstractSparseLattice *> resultLattices;
90   resultLattices.reserve(op->getNumResults());
91   // Track whether all results have reached their fixpoint.
92   bool allAtFixpoint = true;
93   for (Value result : op->getResults()) {
94     AbstractSparseLattice *resultLattice = getLatticeElement(result);
95     allAtFixpoint &= resultLattice->isAtFixpoint();
96     resultLattices.push_back(resultLattice);
97   }
98   // If all result lattices have reached a fixpoint, there is nothing to do.
99   if (allAtFixpoint)
100     return;
101 
102   // The results of a region branch operation are determined by control-flow.
103   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
104     return visitRegionSuccessors({branch}, branch,
105                                  /*successorIndex=*/llvm::None, resultLattices);
106   }
107 
108   // The results of a call operation are determined by the callgraph.
109   if (auto call = dyn_cast<CallOpInterface>(op)) {
110     const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
111     // If not all return sites are known, then conservatively assume we can't
112     // reason about the data-flow.
113     if (!predecessors->allPredecessorsKnown())
114       return markAllPessimisticFixpoint(resultLattices);
115     for (Operation *predecessor : predecessors->getKnownPredecessors())
116       for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
117         join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
118     return;
119   }
120 
121   // Grab the lattice elements of the operands.
122   SmallVector<const AbstractSparseLattice *> operandLattices;
123   operandLattices.reserve(op->getNumOperands());
124   for (Value operand : op->getOperands()) {
125     AbstractSparseLattice *operandLattice = getLatticeElement(operand);
126     operandLattice->useDefSubscribe(this);
127     // If any of the operand states are not initialized, bail out.
128     if (operandLattice->isUninitialized())
129       return;
130     operandLattices.push_back(operandLattice);
131   }
132 
133   // Invoke the operation transfer function.
134   visitOperationImpl(op, operandLattices, resultLattices);
135 }
136 
137 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
138   // Exit early on blocks with no arguments.
139   if (block->getNumArguments() == 0)
140     return;
141 
142   // If the block is not executable, bail out.
143   if (!getOrCreate<Executable>(block)->isLive())
144     return;
145 
146   // Get the argument lattices.
147   SmallVector<AbstractSparseLattice *> argLattices;
148   argLattices.reserve(block->getNumArguments());
149   bool allAtFixpoint = true;
150   for (BlockArgument argument : block->getArguments()) {
151     AbstractSparseLattice *argLattice = getLatticeElement(argument);
152     allAtFixpoint &= argLattice->isAtFixpoint();
153     argLattices.push_back(argLattice);
154   }
155   // If all argument lattices have reached their fixpoints, then there is
156   // nothing to do.
157   if (allAtFixpoint)
158     return;
159 
160   // The argument lattices of entry blocks are set by region control-flow or the
161   // callgraph.
162   if (block->isEntryBlock()) {
163     // Check if this block is the entry block of a callable region.
164     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
165     if (callable && callable.getCallableRegion() == block->getParent()) {
166       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
167       // If not all callsites are known, conservatively mark all lattices as
168       // having reached their pessimistic fixpoints.
169       if (!callsites->allPredecessorsKnown())
170         return markAllPessimisticFixpoint(argLattices);
171       for (Operation *callsite : callsites->getKnownPredecessors()) {
172         auto call = cast<CallOpInterface>(callsite);
173         for (auto it : llvm::zip(call.getArgOperands(), argLattices))
174           join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
175       }
176       return;
177     }
178 
179     // Check if the lattices can be determined from region control flow.
180     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
181       return visitRegionSuccessors(
182           block, branch, block->getParent()->getRegionNumber(), argLattices);
183     }
184 
185     // Otherwise, we can't reason about the data-flow.
186     return markAllPessimisticFixpoint(argLattices);
187   }
188 
189   // Iterate over the predecessors of the non-entry block.
190   for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
191        it != e; ++it) {
192     Block *predecessor = *it;
193 
194     // If the edge from the predecessor block to the current block is not live,
195     // bail out.
196     auto *edgeExecutable =
197         getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
198     edgeExecutable->blockContentSubscribe(this);
199     if (!edgeExecutable->isLive())
200       continue;
201 
202     // Check if we can reason about the data-flow from the predecessor.
203     if (auto branch =
204             dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
205       SuccessorOperands operands =
206           branch.getSuccessorOperands(it.getSuccessorIndex());
207       for (auto &it : llvm::enumerate(argLattices)) {
208         if (Value operand = operands[it.index()]) {
209           join(it.value(), *getLatticeElementFor(block, operand));
210         } else {
211           // Conservatively mark internally produced arguments as having reached
212           // their pessimistic fixpoint.
213           markAllPessimisticFixpoint(it.value());
214         }
215       }
216     } else {
217       return markAllPessimisticFixpoint(argLattices);
218     }
219   }
220 }
221 
222 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
223     ProgramPoint point, RegionBranchOpInterface branch,
224     Optional<unsigned> successorIndex,
225     ArrayRef<AbstractSparseLattice *> lattices) {
226   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
227   assert(predecessors->allPredecessorsKnown() &&
228          "unexpected unresolved region successors");
229 
230   for (Operation *op : predecessors->getKnownPredecessors()) {
231     // Get the incoming successor operands.
232     Optional<OperandRange> operands;
233 
234     // Check if the predecessor is the parent op.
235     if (op == branch) {
236       operands = branch.getSuccessorEntryOperands(successorIndex);
237       // Otherwise, try to deduce the operands from a region return-like op.
238     } else {
239       assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
240       if (isRegionReturnLike(op))
241         operands = getRegionBranchSuccessorOperands(op, successorIndex);
242     }
243 
244     if (!operands) {
245       // We can't reason about the data-flow.
246       return markAllPessimisticFixpoint(lattices);
247     }
248 
249     ValueRange inputs = predecessors->getSuccessorInputs(op);
250     assert(inputs.size() == operands->size() &&
251            "expected the same number of successor inputs as operands");
252 
253     // TODO: This was updated to be exposed upstream.
254     unsigned firstIndex = 0;
255     if (inputs.size() != lattices.size()) {
256       if (inputs.empty()) {
257         markAllPessimisticFixpoint(lattices);
258         return;
259       }
260       firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
261       markAllPessimisticFixpoint(lattices.take_front(firstIndex));
262       markAllPessimisticFixpoint(
263           lattices.drop_front(firstIndex + inputs.size()));
264     }
265 
266     for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
267       join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
268   }
269 }
270 
271 const AbstractSparseLattice *
272 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
273                                                      Value value) {
274   AbstractSparseLattice *state = getLatticeElement(value);
275   addDependency(state, point);
276   return state;
277 }
278 
279 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
280     ArrayRef<AbstractSparseLattice *> lattices) {
281   for (AbstractSparseLattice *lattice : lattices)
282     propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
283 }
284 
285 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
286                                           const AbstractSparseLattice &rhs) {
287   propagateIfChanged(lhs, lhs->join(rhs));
288 }
289