1 //===- DataFlowAnalysis.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowAnalysis.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 #include <queue>
16 
17 using namespace mlir;
18 using namespace mlir::detail;
19 
20 namespace {
21 /// This class contains various state used when computing the lattice elements
22 /// of a callable operation.
23 class CallableLatticeState {
24 public:
25   /// Build a lattice state with a given callable region, and a specified number
26   /// of results to be initialized to the default lattice element.
27   CallableLatticeState(ForwardDataFlowAnalysisBase &analysis,
28                        Region *callableRegion, unsigned numResults)
29       : callableArguments(callableRegion->getArguments()),
30         resultLatticeElements(numResults) {
31     for (AbstractLatticeElement *&it : resultLatticeElements)
32       it = analysis.createLatticeElement();
33   }
34 
35   /// Returns the arguments to the callable region.
36   Block::BlockArgListType getCallableArguments() const {
37     return callableArguments;
38   }
39 
40   /// Returns the lattice element for the results of the callable region.
41   auto getResultLatticeElements() {
42     return llvm::make_pointee_range(resultLatticeElements);
43   }
44 
45   /// Add a call to this callable. This is only used if the callable defines a
46   /// symbol.
47   void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
48 
49   /// Return the calls that reference this callable. This is only used
50   /// if the callable defines a symbol.
51   ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
52 
53 private:
54   /// The arguments of the callable region.
55   Block::BlockArgListType callableArguments;
56 
57   /// The lattice state for each of the results of this region. The return
58   /// values of the callable aren't SSA values, so we need to track them
59   /// separately.
60   SmallVector<AbstractLatticeElement *, 4> resultLatticeElements;
61 
62   /// The calls referencing this callable if this callable defines a symbol.
63   /// This removes the need to recompute symbol references during propagation.
64   /// Value based references are trivial to resolve, so they can be done
65   /// in-place.
66   SmallVector<Operation *, 4> symbolCalls;
67 };
68 
69 /// This class represents the solver for a forward dataflow analysis. This class
70 /// acts as the propagation engine for computing which lattice elements.
71 class ForwardDataFlowSolver {
72 public:
73   /// Initialize the solver with the given top-level operation.
74   ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op);
75 
76   /// Run the solver until it converges.
77   void solve();
78 
79 private:
80   /// Initialize the set of symbol defining callables that can have their
81   /// arguments and results tracked. 'op' is the top-level operation that the
82   /// solver is operating on.
83   void initializeSymbolCallables(Operation *op);
84 
85   /// Visit the users of the given IR that reside within executable blocks.
86   template <typename T>
87   void visitUsers(T &value) {
88     for (Operation *user : value.getUsers())
89       if (isBlockExecutable(user->getBlock()))
90         visitOperation(user);
91   }
92 
93   /// Visit the given operation and compute any necessary lattice state.
94   void visitOperation(Operation *op);
95 
96   /// Visit the given call operation and compute any necessary lattice state.
97   void visitCallOperation(CallOpInterface op);
98 
99   /// Visit the given callable operation and compute any necessary lattice
100   /// state.
101   void visitCallableOperation(Operation *op);
102 
103   /// Visit the given region branch operation, which defines regions, and
104   /// compute any necessary lattice state. This also resolves the lattice state
105   /// of both the operation results and any nested regions.
106   void visitRegionBranchOperation(
107       RegionBranchOpInterface branch,
108       ArrayRef<AbstractLatticeElement *> operandLattices);
109 
110   /// Visit the given set of region successors, computing any necessary lattice
111   /// state. The provided function returns the input operands to the region at
112   /// the given index. If the index is 'None', the input operands correspond to
113   /// the parent operation results.
114   void visitRegionSuccessors(
115       Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
116       function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
117 
118   /// Visit the given terminator operation and compute any necessary lattice
119   /// state.
120   void
121   visitTerminatorOperation(Operation *op,
122                            ArrayRef<AbstractLatticeElement *> operandLattices);
123 
124   /// Visit the given terminator operation that exits a callable region. These
125   /// are terminators with no CFG successors.
126   void visitCallableTerminatorOperation(
127       Operation *callable, Operation *terminator,
128       ArrayRef<AbstractLatticeElement *> operandLattices);
129 
130   /// Visit the given block and compute any necessary lattice state.
131   void visitBlock(Block *block);
132 
133   /// Visit argument #'i' of the given block and compute any necessary lattice
134   /// state.
135   void visitBlockArgument(Block *block, int i);
136 
137   /// Mark the entry block of the given region as executable. Returns NoChange
138   /// if the block was already marked executable. If `markPessimisticFixpoint`
139   /// is true, the arguments of the entry block are also marked as having
140   /// reached the pessimistic fixpoint.
141   ChangeResult markEntryBlockExecutable(Region *region,
142                                         bool markPessimisticFixpoint);
143 
144   /// Mark the given block as executable. Returns NoChange if the block was
145   /// already marked executable.
146   ChangeResult markBlockExecutable(Block *block);
147 
148   /// Returns true if the given block is executable.
149   bool isBlockExecutable(Block *block) const;
150 
151   /// Mark the edge between 'from' and 'to' as executable.
152   void markEdgeExecutable(Block *from, Block *to);
153 
154   /// Return true if the edge between 'from' and 'to' is executable.
155   bool isEdgeExecutable(Block *from, Block *to) const;
156 
157   /// Mark the given value as having reached the pessimistic fixpoint. This
158   /// means that we cannot further refine the state of this value.
159   void markPessimisticFixpoint(Value value);
160 
161   /// Mark all of the given values as having reaching the pessimistic fixpoint.
162   template <typename ValuesT>
163   void markAllPessimisticFixpoint(ValuesT values) {
164     for (auto value : values)
165       markPessimisticFixpoint(value);
166   }
167   template <typename ValuesT>
168   void markAllPessimisticFixpoint(Operation *op, ValuesT values) {
169     markAllPessimisticFixpoint(values);
170     opWorklist.push(op);
171   }
172   template <typename ValuesT>
173   void markAllPessimisticFixpointAndVisitUsers(ValuesT values) {
174     for (auto value : values) {
175       AbstractLatticeElement &lattice = analysis.getLatticeElement(value);
176       if (lattice.markPessimisticFixpoint() == ChangeResult::Change)
177         visitUsers(value);
178     }
179   }
180 
181   /// Returns true if the given value was marked as having reached the
182   /// pessimistic fixpoint.
183   bool isAtFixpoint(Value value) const;
184 
185   /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
186   /// corresponds to the parent operation of the lattice for 'to'.
187   void join(Operation *owner, AbstractLatticeElement &to,
188             const AbstractLatticeElement &from);
189 
190   /// A reference to the dataflow analysis being computed.
191   ForwardDataFlowAnalysisBase &analysis;
192 
193   /// The set of blocks that are known to execute, or are intrinsically live.
194   SmallPtrSet<Block *, 16> executableBlocks;
195 
196   /// The set of control flow edges that are known to execute.
197   DenseSet<std::pair<Block *, Block *>> executableEdges;
198 
199   /// A worklist containing blocks that need to be processed.
200   std::queue<Block *> blockWorklist;
201 
202   /// A worklist of operations that need to be processed.
203   std::queue<Operation *> opWorklist;
204 
205   /// The callable operations that have their argument/result state tracked.
206   DenseMap<Operation *, CallableLatticeState> callableLatticeState;
207 
208   /// A map between a call operation and the resolved symbol callable. This
209   /// avoids re-resolving symbol references during propagation. Value based
210   /// callables are trivial to resolve, so they can be done in-place.
211   DenseMap<Operation *, Operation *> callToSymbolCallable;
212 
213   /// A symbol table used for O(1) symbol lookups during simplification.
214   SymbolTableCollection symbolTable;
215 };
216 } // namespace
217 
218 ForwardDataFlowSolver::ForwardDataFlowSolver(
219     ForwardDataFlowAnalysisBase &analysis, Operation *op)
220     : analysis(analysis) {
221   /// Initialize the solver with the regions within this operation.
222   for (Region &region : op->getRegions()) {
223     // Mark the entry block as executable. The values passed to these regions
224     // are also invisible, so mark any arguments as reaching the pessimistic
225     // fixpoint.
226     markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
227   }
228   initializeSymbolCallables(op);
229 }
230 
231 void ForwardDataFlowSolver::solve() {
232   while (!blockWorklist.empty() || !opWorklist.empty()) {
233     // Process any operations in the op worklist.
234     while (!opWorklist.empty()) {
235       Operation *nextOp = opWorklist.front();
236       opWorklist.pop();
237       visitUsers(*nextOp);
238     }
239 
240     // Process any blocks in the block worklist.
241     while (!blockWorklist.empty()) {
242       Block *nextBlock = blockWorklist.front();
243       blockWorklist.pop();
244       visitBlock(nextBlock);
245     }
246   }
247 }
248 
249 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) {
250   // Initialize the set of symbol callables that can have their state tracked.
251   // This tracks which symbol callable operations we can propagate within and
252   // out of.
253   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
254     Region &symbolTableRegion = symTable->getRegion(0);
255     Block *symbolTableBlock = &symbolTableRegion.front();
256     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
257       // We won't be able to track external callables.
258       Region *callableRegion = callable.getCallableRegion();
259       if (!callableRegion)
260         continue;
261       // We only care about symbol defining callables here.
262       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
263       if (!symbol)
264         continue;
265       callableLatticeState.try_emplace(callable, analysis, callableRegion,
266                                        callable.getCallableResults().size());
267 
268       // If not all of the uses of this symbol are visible, we can't track the
269       // state of the arguments.
270       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
271         for (Region &region : callable->getRegions())
272           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
273       }
274     }
275     if (callableLatticeState.empty())
276       return;
277 
278     // After computing the valid callables, walk any symbol uses to check
279     // for non-call references. We won't be able to track the lattice state
280     // for arguments to these callables, as we can't guarantee that we can see
281     // all of its calls.
282     Optional<SymbolTable::UseRange> uses =
283         SymbolTable::getSymbolUses(&symbolTableRegion);
284     if (!uses) {
285       // If we couldn't gather the symbol uses, conservatively assume that
286       // we can't track information for any nested symbols.
287       op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
288       return;
289     }
290 
291     for (const SymbolTable::SymbolUse &use : *uses) {
292       // If the use is a call, track it to avoid the need to recompute the
293       // reference later.
294       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
295         Operation *symCallable = callOp.resolveCallable(&symbolTable);
296         auto callableLatticeIt = callableLatticeState.find(symCallable);
297         if (callableLatticeIt != callableLatticeState.end()) {
298           callToSymbolCallable.try_emplace(callOp, symCallable);
299 
300           // We only need to record the call in the lattice if it produces any
301           // values.
302           if (callOp->getNumResults())
303             callableLatticeIt->second.addSymbolCall(callOp);
304         }
305         continue;
306       }
307       // This use isn't a call, so don't we know all of the callers.
308       auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
309       auto it = callableLatticeState.find(symbol);
310       if (it != callableLatticeState.end()) {
311         for (Region &region : it->first->getRegions())
312           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
313       }
314     }
315   };
316   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
317                                 walkFn);
318 }
319 
320 void ForwardDataFlowSolver::visitOperation(Operation *op) {
321   // Collect all of the lattice elements feeding into this operation. If any are
322   // not yet resolved, bail out and wait for them to resolve.
323   SmallVector<AbstractLatticeElement *, 8> operandLattices;
324   operandLattices.reserve(op->getNumOperands());
325   for (Value operand : op->getOperands()) {
326     AbstractLatticeElement *operandLattice =
327         analysis.lookupLatticeElement(operand);
328     if (!operandLattice || operandLattice->isUninitialized())
329       return;
330     operandLattices.push_back(operandLattice);
331   }
332 
333   // If this is a terminator operation, process any control flow lattice state.
334   if (op->hasTrait<OpTrait::IsTerminator>())
335     visitTerminatorOperation(op, operandLattices);
336 
337   // Process call operations. The call visitor processes result values, so we
338   // can exit afterwards.
339   if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
340     return visitCallOperation(call);
341 
342   // Process callable operations. These are specially handled region operations
343   // that track dataflow via calls.
344   if (isa<CallableOpInterface>(op)) {
345     // If this callable has a tracked lattice state, it will be visited by calls
346     // that reference it instead. This way, we don't assume that it is
347     // executable unless there is a proper reference to it.
348     if (callableLatticeState.count(op))
349       return;
350     return visitCallableOperation(op);
351   }
352 
353   // Process region holding operations.
354   if (op->getNumRegions()) {
355     // Check to see if we can reason about the internal control flow of this
356     // region operation.
357     if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
358       return visitRegionBranchOperation(branch, operandLattices);
359 
360     // If we can't, conservatively mark all regions as executable.
361     // TODO: Let the `visitOperation` method decide how to propagate
362     // information to the block arguments.
363     for (Region &region : op->getRegions())
364       markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
365   }
366 
367   // If this op produces no results, it can't produce any constants.
368   if (op->getNumResults() == 0)
369     return;
370 
371   // If all of the results of this operation are already resolved, bail out
372   // early.
373   auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); };
374   if (llvm::all_of(op->getResults(), isAtFixpointFn))
375     return;
376 
377   // Visit the current operation.
378   if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change)
379     opWorklist.push(op);
380 
381   // `visitOperation` is required to define all of the result lattices.
382   assert(llvm::none_of(
383              op->getResults(),
384              [&](Value value) {
385                return analysis.getLatticeElement(value).isUninitialized();
386              }) &&
387          "expected `visitOperation` to define all result lattices");
388 }
389 
390 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) {
391   // Mark the regions as executable. If we aren't tracking lattice state for
392   // this callable, mark all of the region arguments as having reached a
393   // fixpoint.
394   bool isTrackingLatticeState = callableLatticeState.count(op);
395   for (Region &region : op->getRegions())
396     markEntryBlockExecutable(&region, !isTrackingLatticeState);
397 
398   // TODO: Add support for non-symbol callables when necessary. If the callable
399   // has non-call uses we would mark as having reached pessimistic fixpoint,
400   // otherwise allow for propagating the return values out.
401   markAllPessimisticFixpoint(op, op->getResults());
402 }
403 
404 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) {
405   ResultRange callResults = op->getResults();
406 
407   // Resolve the callable operation for this call.
408   Operation *callableOp = nullptr;
409   if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
410     callableOp = callableValue.getDefiningOp();
411   else
412     callableOp = callToSymbolCallable.lookup(op);
413 
414   // The callable of this call can't be resolved, mark any results overdefined.
415   if (!callableOp)
416     return markAllPessimisticFixpoint(op, callResults);
417 
418   // If this callable is tracking state, merge the argument operands with the
419   // arguments of the callable.
420   auto callableLatticeIt = callableLatticeState.find(callableOp);
421   if (callableLatticeIt == callableLatticeState.end())
422     return markAllPessimisticFixpoint(op, callResults);
423 
424   OperandRange callOperands = op.getArgOperands();
425   auto callableArgs = callableLatticeIt->second.getCallableArguments();
426   for (auto it : llvm::zip(callOperands, callableArgs)) {
427     BlockArgument callableArg = std::get<1>(it);
428     AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg);
429     AbstractLatticeElement &operandValue =
430         analysis.getLatticeElement(std::get<0>(it));
431     if (argValue.join(operandValue) == ChangeResult::Change)
432       visitUsers(callableArg);
433   }
434 
435   // Visit the callable.
436   visitCallableOperation(callableOp);
437 
438   // Merge in the lattice state for the callable results as well.
439   auto callableResults = callableLatticeIt->second.getResultLatticeElements();
440   for (auto it : llvm::zip(callResults, callableResults))
441     join(/*owner=*/op,
442          /*to=*/analysis.getLatticeElement(std::get<0>(it)),
443          /*from=*/std::get<1>(it));
444 }
445 
446 void ForwardDataFlowSolver::visitRegionBranchOperation(
447     RegionBranchOpInterface branch,
448     ArrayRef<AbstractLatticeElement *> operandLattices) {
449   // Check to see which regions are executable.
450   SmallVector<RegionSuccessor, 1> successors;
451   analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None,
452                                     operandLattices, successors);
453 
454   // If the interface identified that no region will be executed. Mark
455   // any results of this operation as overdefined, as we can't reason about
456   // them.
457   // TODO: If we had an interface to detect pass through operands, we could
458   // resolve some results based on the lattice state of the operands. We could
459   // also allow for the parent operation to have itself as a region successor.
460   if (successors.empty())
461     return markAllPessimisticFixpoint(branch, branch->getResults());
462   return visitRegionSuccessors(
463       branch, successors, [&](Optional<unsigned> index) {
464         assert(index && "expected valid region index");
465         return branch.getSuccessorEntryOperands(*index);
466       });
467 }
468 
469 void ForwardDataFlowSolver::visitRegionSuccessors(
470     Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
471     function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
472   for (const RegionSuccessor &it : regionSuccessors) {
473     Region *region = it.getSuccessor();
474     ValueRange succArgs = it.getSuccessorInputs();
475 
476     // Check to see if this is the parent operation.
477     if (!region) {
478       ResultRange results = parentOp->getResults();
479       if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); }))
480         continue;
481 
482       // Mark the results outside of the input range as having reached the
483       // pessimistic fixpoint.
484       // TODO: This isn't exactly ideal. There may be situations in which a
485       // region operation can provide information for certain results that
486       // aren't part of the control flow.
487       if (succArgs.size() != results.size()) {
488         opWorklist.push(parentOp);
489         if (succArgs.empty()) {
490           markAllPessimisticFixpoint(results);
491           continue;
492         }
493 
494         unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
495         markAllPessimisticFixpoint(results.take_front(firstResIdx));
496         markAllPessimisticFixpoint(
497             results.drop_front(firstResIdx + succArgs.size()));
498       }
499 
500       // Update the lattice for any operation results.
501       OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
502       for (auto it : llvm::zip(succArgs, operands))
503         join(parentOp, analysis.getLatticeElement(std::get<0>(it)),
504              analysis.getLatticeElement(std::get<1>(it)));
505       continue;
506     }
507     assert(!region->empty() && "expected region to be non-empty");
508     Block *entryBlock = &region->front();
509     markBlockExecutable(entryBlock);
510 
511     // If all of the arguments have already reached a fixpoint, the arguments
512     // have already been fully resolved.
513     Block::BlockArgListType arguments = entryBlock->getArguments();
514     if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
515       continue;
516 
517     // Mark any arguments that do not receive inputs as having reached a
518     // pessimistic fixpoint, we won't be able to discern if they are constant.
519     // TODO: This isn't exactly ideal. There may be situations in which a
520     // region operation can provide information for certain results that
521     // aren't part of the control flow.
522     if (succArgs.size() != arguments.size()) {
523       if (succArgs.empty()) {
524         markAllPessimisticFixpoint(arguments);
525         continue;
526       }
527 
528       unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
529       markAllPessimisticFixpointAndVisitUsers(
530           arguments.take_front(firstArgIdx));
531       markAllPessimisticFixpointAndVisitUsers(
532           arguments.drop_front(firstArgIdx + succArgs.size()));
533     }
534 
535     // Update the lattice of arguments that have inputs from the predecessor.
536     OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
537     for (auto it : llvm::zip(succArgs, succOperands)) {
538       AbstractLatticeElement &argValue =
539           analysis.getLatticeElement(std::get<0>(it));
540       AbstractLatticeElement &operandValue =
541           analysis.getLatticeElement(std::get<1>(it));
542       if (argValue.join(operandValue) == ChangeResult::Change)
543         visitUsers(std::get<0>(it));
544     }
545   }
546 }
547 
548 void ForwardDataFlowSolver::visitTerminatorOperation(
549     Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) {
550   // If this operation has no successors, we treat it as an exiting terminator.
551   if (op->getNumSuccessors() == 0) {
552     Region *parentRegion = op->getParentRegion();
553     Operation *parentOp = parentRegion->getParentOp();
554 
555     // Check to see if this is a terminator for a callable region.
556     if (isa<CallableOpInterface>(parentOp))
557       return visitCallableTerminatorOperation(parentOp, op, operandLattices);
558 
559     // Otherwise, check to see if the parent tracks region control flow.
560     auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
561     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
562       return;
563 
564     // Query the set of successors of the current region using the current
565     // optimistic lattice state.
566     SmallVector<RegionSuccessor, 1> regionSuccessors;
567     analysis.getSuccessorsForOperands(regionInterface,
568                                       parentRegion->getRegionNumber(),
569                                       operandLattices, regionSuccessors);
570     if (regionSuccessors.empty())
571       return;
572 
573     // Try to get "region-like" successor operands if possible in order to
574     // propagate the operand states to the successors.
575     if (isRegionReturnLike(op)) {
576       return visitRegionSuccessors(
577           parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) {
578             // Determine the individual region successor operands for the given
579             // region index (if any).
580             return *getRegionBranchSuccessorOperands(op, regionIndex);
581           });
582     }
583 
584     // If this terminator is not "region-like", conservatively mark all of the
585     // successor values as having reached the pessimistic fixpoint.
586     for (auto &it : regionSuccessors) {
587       // If the successor is a region, mark the entry block as executable so
588       // that we visit operations defined within. If the successor is the
589       // parent operation, we simply mark the control flow results as having
590       // reached the pessimistic state.
591       if (Region *region = it.getSuccessor())
592         markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true);
593       else
594         markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs());
595     }
596   }
597 
598   // Try to resolve to a specific set of successors with the current optimistic
599   // lattice state.
600   Block *block = op->getBlock();
601   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
602     SmallVector<Block *> successors;
603     if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices,
604                                                     successors))) {
605       for (Block *succ : successors)
606         markEdgeExecutable(block, succ);
607       return;
608     }
609   }
610 
611   // Otherwise, conservatively treat all edges as executable.
612   for (Block *succ : op->getSuccessors())
613     markEdgeExecutable(block, succ);
614 }
615 
616 void ForwardDataFlowSolver::visitCallableTerminatorOperation(
617     Operation *callable, Operation *terminator,
618     ArrayRef<AbstractLatticeElement *> operandLattices) {
619   // If there are no exiting values, we have nothing to track.
620   if (terminator->getNumOperands() == 0)
621     return;
622 
623   // If this callable isn't tracking any lattice state there is nothing to do.
624   auto latticeIt = callableLatticeState.find(callable);
625   if (latticeIt == callableLatticeState.end())
626     return;
627   assert(callable->getNumResults() == 0 && "expected symbol callable");
628 
629   // If this terminator is not "return-like", conservatively mark all of the
630   // call-site results as having reached the pessimistic fixpoint.
631   auto callableResultLattices = latticeIt->second.getResultLatticeElements();
632   if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
633     for (auto &it : callableResultLattices)
634       it.markPessimisticFixpoint();
635     for (Operation *call : latticeIt->second.getSymbolCalls())
636       markAllPessimisticFixpoint(call, call->getResults());
637     return;
638   }
639 
640   // Merge the lattice state for terminator operands into the results.
641   ChangeResult result = ChangeResult::NoChange;
642   for (auto it : llvm::zip(operandLattices, callableResultLattices))
643     result |= std::get<1>(it).join(*std::get<0>(it));
644   if (result == ChangeResult::NoChange)
645     return;
646 
647   // If any of the result lattices changed, update the callers.
648   for (Operation *call : latticeIt->second.getSymbolCalls())
649     for (auto it : llvm::zip(call->getResults(), callableResultLattices))
650       join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it));
651 }
652 
653 void ForwardDataFlowSolver::visitBlock(Block *block) {
654   // If the block is not the entry block we need to compute the lattice state
655   // for the block arguments. Entry block argument lattices are computed
656   // elsewhere, such as when visiting the parent operation.
657   if (!block->isEntryBlock()) {
658     for (int i : llvm::seq<int>(0, block->getNumArguments()))
659       visitBlockArgument(block, i);
660   }
661 
662   // Visit all of the operations within the block.
663   for (Operation &op : *block)
664     visitOperation(&op);
665 }
666 
667 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
668   BlockArgument arg = block->getArgument(i);
669   AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg);
670   if (argLattice.isAtFixpoint())
671     return;
672 
673   ChangeResult updatedLattice = ChangeResult::NoChange;
674   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
675     Block *pred = *it;
676 
677     // We only care about this predecessor if it is going to execute.
678     if (!isEdgeExecutable(pred, block))
679       continue;
680 
681     // Try to get the operand forwarded by the predecessor. If we can't reason
682     // about the terminator of the predecessor, mark as having reached a
683     // fixpoint.
684     auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
685     if (!branch) {
686       updatedLattice |= argLattice.markPessimisticFixpoint();
687       break;
688     }
689     Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
690     if (!operand) {
691       updatedLattice |= argLattice.markPessimisticFixpoint();
692       break;
693     }
694 
695     // If the operand hasn't been resolved, it is uninitialized and can merge
696     // with anything.
697     AbstractLatticeElement *operandLattice =
698         analysis.lookupLatticeElement(operand);
699     if (!operandLattice)
700       continue;
701 
702     // Otherwise, join the operand lattice into the argument lattice.
703     updatedLattice |= argLattice.join(*operandLattice);
704     if (argLattice.isAtFixpoint())
705       break;
706   }
707 
708   // If the lattice changed, visit users of the argument.
709   if (updatedLattice == ChangeResult::Change)
710     visitUsers(arg);
711 }
712 
713 ChangeResult
714 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region,
715                                                 bool markPessimisticFixpoint) {
716   if (!region->empty()) {
717     if (markPessimisticFixpoint)
718       markAllPessimisticFixpoint(region->front().getArguments());
719     return markBlockExecutable(&region->front());
720   }
721   return ChangeResult::NoChange;
722 }
723 
724 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) {
725   bool marked = executableBlocks.insert(block).second;
726   if (marked)
727     blockWorklist.push(block);
728   return marked ? ChangeResult::Change : ChangeResult::NoChange;
729 }
730 
731 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const {
732   return executableBlocks.count(block);
733 }
734 
735 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) {
736   executableEdges.insert(std::make_pair(from, to));
737 
738   // Mark the destination as executable, and reprocess its arguments if it was
739   // already executable.
740   if (markBlockExecutable(to) == ChangeResult::NoChange) {
741     for (int i : llvm::seq<int>(0, to->getNumArguments()))
742       visitBlockArgument(to, i);
743   }
744 }
745 
746 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const {
747   return executableEdges.count(std::make_pair(from, to));
748 }
749 
750 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) {
751   analysis.getLatticeElement(value).markPessimisticFixpoint();
752 }
753 
754 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const {
755   if (auto *lattice = analysis.lookupLatticeElement(value))
756     return lattice->isAtFixpoint();
757   return false;
758 }
759 
760 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to,
761                                  const AbstractLatticeElement &from) {
762   if (to.join(from) == ChangeResult::Change)
763     opWorklist.push(owner);
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // AbstractLatticeElement
768 //===----------------------------------------------------------------------===//
769 
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here rather than in the header to anchor
// the class's vtable in this translation unit - confirm against the header.
AbstractLatticeElement::~AbstractLatticeElement() = default;
771 
772 //===----------------------------------------------------------------------===//
773 // ForwardDataFlowAnalysisBase
774 //===----------------------------------------------------------------------===//
775 
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here rather than in the header to anchor
// the class's vtable in this translation unit - confirm against the header.
ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() = default;
777 
778 AbstractLatticeElement &
779 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) {
780   AbstractLatticeElement *&latticeValue = latticeValues[value];
781   if (!latticeValue)
782     latticeValue = createLatticeElement(value);
783   return *latticeValue;
784 }
785 
786 AbstractLatticeElement *
787 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) {
788   return latticeValues.lookup(value);
789 }
790 
791 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) {
792   // Run the main dataflow solver.
793   ForwardDataFlowSolver solver(*this, topLevelOp);
794   solver.solve();
795 
796   // Any values that are still uninitialized now go to a pessimistic fixpoint,
797   // otherwise we assume an optimistic fixpoint has been reached.
798   for (auto &it : latticeValues)
799     if (it.second->isUninitialized())
800       it.second->markPessimisticFixpoint();
801     else
802       it.second->markOptimisticFixpoint();
803 }
804