1 //===- DataFlowAnalysis.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowAnalysis.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/ArrayRef.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 #include <queue>
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
21 namespace {
22 /// This class contains various state used when computing the lattice elements
23 /// of a callable operation.
24 class CallableLatticeState {
25 public:
26   /// Build a lattice state with a given callable region, and a specified number
27   /// of results to be initialized to the default lattice element.
28   CallableLatticeState(ForwardDataFlowAnalysisBase &analysis,
29                        Region *callableRegion, unsigned numResults)
30       : callableArguments(callableRegion->getArguments()),
31         resultLatticeElements(numResults) {
32     for (AbstractLatticeElement *&it : resultLatticeElements)
33       it = analysis.createLatticeElement();
34   }
35 
36   /// Returns the arguments to the callable region.
37   Block::BlockArgListType getCallableArguments() const {
38     return callableArguments;
39   }
40 
41   /// Returns the lattice element for the results of the callable region.
42   auto getResultLatticeElements() {
43     return llvm::make_pointee_range(resultLatticeElements);
44   }
45 
46   /// Add a call to this callable. This is only used if the callable defines a
47   /// symbol.
48   void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
49 
50   /// Return the calls that reference this callable. This is only used
51   /// if the callable defines a symbol.
52   ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
53 
54 private:
55   /// The arguments of the callable region.
56   Block::BlockArgListType callableArguments;
57 
58   /// The lattice state for each of the results of this region. The return
59   /// values of the callable aren't SSA values, so we need to track them
60   /// separately.
61   SmallVector<AbstractLatticeElement *, 4> resultLatticeElements;
62 
63   /// The calls referencing this callable if this callable defines a symbol.
64   /// This removes the need to recompute symbol references during propagation.
65   /// Value based references are trivial to resolve, so they can be done
66   /// in-place.
67   SmallVector<Operation *, 4> symbolCalls;
68 };
69 
/// This class represents the solver for a forward dataflow analysis. It acts
/// as the propagation engine: it tracks which blocks and control-flow edges
/// are executable, invokes the analysis hooks on operations inside executable
/// blocks, and keeps revisiting affected IR through its worklists until no
/// more work remains.
class ForwardDataFlowSolver {
public:
  /// Initialize the solver with the given top-level operation.
  ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op);

  /// Run the solver until it converges.
  void solve();

private:
  /// Initialize the set of symbol defining callables that can have their
  /// arguments and results tracked. 'op' is the top-level operation that the
  /// solver is operating on.
  void initializeSymbolCallables(Operation *op);

  /// Visit the users of the given IR that reside within executable blocks.
  /// `T` is any IR entity providing `getUsers()`, such as a Value or an
  /// Operation.
  template <typename T>
  void visitUsers(T &value) {
    for (Operation *user : value.getUsers())
      if (isBlockExecutable(user->getBlock()))
        visitOperation(user);
  }

  /// Visit the given operation and compute any necessary lattice state.
  void visitOperation(Operation *op);

  /// Visit the given call operation and compute any necessary lattice state.
  void visitCallOperation(CallOpInterface op);

  /// Visit the given callable operation and compute any necessary lattice
  /// state.
  void visitCallableOperation(Operation *op);

  /// Visit the given region branch operation, which defines regions, and
  /// compute any necessary lattice state. This also resolves the lattice state
  /// of both the operation results and any nested regions.
  void visitRegionBranchOperation(
      RegionBranchOpInterface branch,
      ArrayRef<AbstractLatticeElement *> operandLattices);

  /// Visit the given set of region successors, computing any necessary lattice
  /// state. The provided function returns the input operands to the region at
  /// the given index. If the index is 'None', the input operands correspond to
  /// the parent operation results.
  void visitRegionSuccessors(
      Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
      ArrayRef<AbstractLatticeElement *> operandLattices,
      function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);

  /// Visit the given terminator operation and compute any necessary lattice
  /// state.
  void
  visitTerminatorOperation(Operation *op,
                           ArrayRef<AbstractLatticeElement *> operandLattices);

  /// Visit the given terminator operation that exits a callable region. These
  /// are terminators with no CFG successors.
  void visitCallableTerminatorOperation(
      Operation *callable, Operation *terminator,
      ArrayRef<AbstractLatticeElement *> operandLattices);

  /// Visit the given block and compute any necessary lattice state.
  void visitBlock(Block *block);

  /// Visit argument #'i' of the given block and compute any necessary lattice
  /// state.
  void visitBlockArgument(Block *block, int i);

  /// Mark the entry block of the given region as executable. Returns NoChange
  /// if the block was already marked executable. If `markPessimisticFixpoint`
  /// is true, the arguments of the entry block are also marked as having
  /// reached the pessimistic fixpoint.
  ChangeResult markEntryBlockExecutable(Region *region,
                                        bool markPessimisticFixpoint);

  /// Mark the given block as executable. Returns NoChange if the block was
  /// already marked executable.
  ChangeResult markBlockExecutable(Block *block);

  /// Returns true if the given block is executable.
  bool isBlockExecutable(Block *block) const;

  /// Mark the edge between 'from' and 'to' as executable.
  void markEdgeExecutable(Block *from, Block *to);

  /// Return true if the edge between 'from' and 'to' is executable.
  bool isEdgeExecutable(Block *from, Block *to) const;

  /// Mark the given value as having reached the pessimistic fixpoint. This
  /// means that we cannot further refine the state of this value.
  void markPessimisticFixpoint(Value value);

  /// Mark all of the given values as having reached the pessimistic fixpoint.
  template <typename ValuesT>
  void markAllPessimisticFixpoint(ValuesT values) {
    for (auto value : values)
      markPessimisticFixpoint(value);
  }
  /// Same as above, but additionally enqueue `op` on the operation worklist so
  /// that its users are revisited by `solve`.
  template <typename ValuesT>
  void markAllPessimisticFixpoint(Operation *op, ValuesT values) {
    markAllPessimisticFixpoint(values);
    opWorklist.push(op);
  }
  /// Mark each of the given values as having reached the pessimistic fixpoint,
  /// and immediately visit the users of every value whose lattice changed.
  template <typename ValuesT>
  void markAllPessimisticFixpointAndVisitUsers(ValuesT values) {
    for (auto value : values) {
      AbstractLatticeElement &lattice = analysis.getLatticeElement(value);
      if (lattice.markPessimisticFixpoint() == ChangeResult::Change)
        visitUsers(value);
    }
  }

  /// Returns true if the given value was marked as having reached the
  /// pessimistic fixpoint.
  bool isAtFixpoint(Value value) const;

  /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
  /// corresponds to the parent operation of the lattice for 'to'.
  void join(Operation *owner, AbstractLatticeElement &to,
            const AbstractLatticeElement &from);

  /// A reference to the dataflow analysis being computed.
  ForwardDataFlowAnalysisBase &analysis;

  /// The set of blocks that are known to execute, or are intrinsically live.
  SmallPtrSet<Block *, 16> executableBlocks;

  /// The set of control flow edges that are known to execute.
  DenseSet<std::pair<Block *, Block *>> executableEdges;

  /// A worklist containing blocks that need to be processed.
  std::queue<Block *> blockWorklist;

  /// A worklist of operations whose users need to be revisited.
  std::queue<Operation *> opWorklist;

  /// The callable operations that have their argument/result state tracked.
  DenseMap<Operation *, CallableLatticeState> callableLatticeState;

  /// A map between a call operation and the resolved symbol callable. This
  /// avoids re-resolving symbol references during propagation. Value based
  /// callables are trivial to resolve, so they can be done in-place.
  DenseMap<Operation *, Operation *> callToSymbolCallable;

  /// A collection of symbol tables used to resolve symbol references (such as
  /// call targets) while initializing the tracked symbol callables.
  SymbolTableCollection symbolTable;
};
218 } // namespace
219 
220 ForwardDataFlowSolver::ForwardDataFlowSolver(
221     ForwardDataFlowAnalysisBase &analysis, Operation *op)
222     : analysis(analysis) {
223   /// Initialize the solver with the regions within this operation.
224   for (Region &region : op->getRegions()) {
225     // Mark the entry block as executable. The values passed to these regions
226     // are also invisible, so mark any arguments as reaching the pessimistic
227     // fixpoint.
228     markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
229   }
230   initializeSymbolCallables(op);
231 }
232 
233 void ForwardDataFlowSolver::solve() {
234   while (!blockWorklist.empty() || !opWorklist.empty()) {
235     // Process any operations in the op worklist.
236     while (!opWorklist.empty()) {
237       Operation *nextOp = opWorklist.front();
238       opWorklist.pop();
239       visitUsers(*nextOp);
240     }
241 
242     // Process any blocks in the block worklist.
243     while (!blockWorklist.empty()) {
244       Block *nextBlock = blockWorklist.front();
245       blockWorklist.pop();
246       visitBlock(nextBlock);
247     }
248   }
249 }
250 
251 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) {
252   // Initialize the set of symbol callables that can have their state tracked.
253   // This tracks which symbol callable operations we can propagate within and
254   // out of.
255   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
256     Region &symbolTableRegion = symTable->getRegion(0);
257     Block *symbolTableBlock = &symbolTableRegion.front();
258     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
259       // We won't be able to track external callables.
260       Region *callableRegion = callable.getCallableRegion();
261       if (!callableRegion)
262         continue;
263       // We only care about symbol defining callables here.
264       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
265       if (!symbol)
266         continue;
267       callableLatticeState.try_emplace(callable, analysis, callableRegion,
268                                        callable.getCallableResults().size());
269 
270       // If not all of the uses of this symbol are visible, we can't track the
271       // state of the arguments.
272       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
273         for (Region &region : callable->getRegions())
274           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
275       }
276     }
277     if (callableLatticeState.empty())
278       return;
279 
280     // After computing the valid callables, walk any symbol uses to check
281     // for non-call references. We won't be able to track the lattice state
282     // for arguments to these callables, as we can't guarantee that we can see
283     // all of its calls.
284     Optional<SymbolTable::UseRange> uses =
285         SymbolTable::getSymbolUses(&symbolTableRegion);
286     if (!uses) {
287       // If we couldn't gather the symbol uses, conservatively assume that
288       // we can't track information for any nested symbols.
289       op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
290       return;
291     }
292 
293     for (const SymbolTable::SymbolUse &use : *uses) {
294       // If the use is a call, track it to avoid the need to recompute the
295       // reference later.
296       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
297         Operation *symCallable = callOp.resolveCallable(&symbolTable);
298         auto callableLatticeIt = callableLatticeState.find(symCallable);
299         if (callableLatticeIt != callableLatticeState.end()) {
300           callToSymbolCallable.try_emplace(callOp, symCallable);
301 
302           // We only need to record the call in the lattice if it produces any
303           // values.
304           if (callOp->getNumResults())
305             callableLatticeIt->second.addSymbolCall(callOp);
306         }
307         continue;
308       }
309       // This use isn't a call, so don't we know all of the callers.
310       auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
311       auto it = callableLatticeState.find(symbol);
312       if (it != callableLatticeState.end()) {
313         for (Region &region : it->first->getRegions())
314           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
315       }
316     }
317   };
318   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
319                                 walkFn);
320 }
321 
322 void ForwardDataFlowSolver::visitOperation(Operation *op) {
323   // Collect all of the lattice elements feeding into this operation. If any are
324   // not yet resolved, bail out and wait for them to resolve.
325   SmallVector<AbstractLatticeElement *, 8> operandLattices;
326   operandLattices.reserve(op->getNumOperands());
327   for (Value operand : op->getOperands()) {
328     AbstractLatticeElement *operandLattice =
329         analysis.lookupLatticeElement(operand);
330     if (!operandLattice || operandLattice->isUninitialized())
331       return;
332     operandLattices.push_back(operandLattice);
333   }
334 
335   // If this is a terminator operation, process any control flow lattice state.
336   if (op->hasTrait<OpTrait::IsTerminator>())
337     visitTerminatorOperation(op, operandLattices);
338 
339   // Process call operations. The call visitor processes result values, so we
340   // can exit afterwards.
341   if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
342     return visitCallOperation(call);
343 
344   // Process callable operations. These are specially handled region operations
345   // that track dataflow via calls.
346   if (isa<CallableOpInterface>(op)) {
347     // If this callable has a tracked lattice state, it will be visited by calls
348     // that reference it instead. This way, we don't assume that it is
349     // executable unless there is a proper reference to it.
350     if (callableLatticeState.count(op))
351       return;
352     return visitCallableOperation(op);
353   }
354 
355   // Process region holding operations.
356   if (op->getNumRegions()) {
357     // Check to see if we can reason about the internal control flow of this
358     // region operation.
359     if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
360       return visitRegionBranchOperation(branch, operandLattices);
361 
362     for (Region &region : op->getRegions()) {
363       analysis.visitNonControlFlowArguments(op, RegionSuccessor(&region),
364                                             operandLattices);
365       // `visitNonControlFlowArguments` is required to define all of the region
366       // argument lattices.
367       assert(llvm::none_of(
368                  region.getArguments(),
369                  [&](Value value) {
370                    return analysis.getLatticeElement(value).isUninitialized();
371                  }) &&
372              "expected `visitNonControlFlowArguments` to define all argument "
373              "lattices");
374       markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/false);
375     }
376   }
377 
378   // If this op produces no results, it can't produce any constants.
379   if (op->getNumResults() == 0)
380     return;
381 
382   // If all of the results of this operation are already resolved, bail out
383   // early.
384   auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); };
385   if (llvm::all_of(op->getResults(), isAtFixpointFn))
386     return;
387 
388   // Visit the current operation.
389   if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change)
390     opWorklist.push(op);
391 
392   // `visitOperation` is required to define all of the result lattices.
393   assert(llvm::none_of(
394              op->getResults(),
395              [&](Value value) {
396                return analysis.getLatticeElement(value).isUninitialized();
397              }) &&
398          "expected `visitOperation` to define all result lattices");
399 }
400 
401 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) {
402   // Mark the regions as executable. If we aren't tracking lattice state for
403   // this callable, mark all of the region arguments as having reached a
404   // fixpoint.
405   bool isTrackingLatticeState = callableLatticeState.count(op);
406   for (Region &region : op->getRegions())
407     markEntryBlockExecutable(&region, !isTrackingLatticeState);
408 
409   // TODO: Add support for non-symbol callables when necessary. If the callable
410   // has non-call uses we would mark as having reached pessimistic fixpoint,
411   // otherwise allow for propagating the return values out.
412   markAllPessimisticFixpoint(op, op->getResults());
413 }
414 
415 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) {
416   ResultRange callResults = op->getResults();
417 
418   // Resolve the callable operation for this call.
419   Operation *callableOp = nullptr;
420   if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
421     callableOp = callableValue.getDefiningOp();
422   else
423     callableOp = callToSymbolCallable.lookup(op);
424 
425   // The callable of this call can't be resolved, mark any results overdefined.
426   if (!callableOp)
427     return markAllPessimisticFixpoint(op, callResults);
428 
429   // If this callable is tracking state, merge the argument operands with the
430   // arguments of the callable.
431   auto callableLatticeIt = callableLatticeState.find(callableOp);
432   if (callableLatticeIt == callableLatticeState.end())
433     return markAllPessimisticFixpoint(op, callResults);
434 
435   OperandRange callOperands = op.getArgOperands();
436   auto callableArgs = callableLatticeIt->second.getCallableArguments();
437   for (auto it : llvm::zip(callOperands, callableArgs)) {
438     BlockArgument callableArg = std::get<1>(it);
439     AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg);
440     AbstractLatticeElement &operandValue =
441         analysis.getLatticeElement(std::get<0>(it));
442     if (argValue.join(operandValue) == ChangeResult::Change)
443       visitUsers(callableArg);
444   }
445 
446   // Visit the callable.
447   visitCallableOperation(callableOp);
448 
449   // Merge in the lattice state for the callable results as well.
450   auto callableResults = callableLatticeIt->second.getResultLatticeElements();
451   for (auto it : llvm::zip(callResults, callableResults))
452     join(/*owner=*/op,
453          /*to=*/analysis.getLatticeElement(std::get<0>(it)),
454          /*from=*/std::get<1>(it));
455 }
456 
457 void ForwardDataFlowSolver::visitRegionBranchOperation(
458     RegionBranchOpInterface branch,
459     ArrayRef<AbstractLatticeElement *> operandLattices) {
460   // Check to see which regions are executable.
461   SmallVector<RegionSuccessor, 1> successors;
462   analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None,
463                                     operandLattices, successors);
464 
465   // If the interface identified that no region will be executed. Mark
466   // any results of this operation as overdefined, as we can't reason about
467   // them.
468   // TODO: If we had an interface to detect pass through operands, we could
469   // resolve some results based on the lattice state of the operands. We could
470   // also allow for the parent operation to have itself as a region successor.
471   if (successors.empty())
472     return markAllPessimisticFixpoint(branch, branch->getResults());
473   return visitRegionSuccessors(
474       branch, successors, operandLattices, [&](Optional<unsigned> index) {
475         assert(index && "expected valid region index");
476         return branch.getSuccessorEntryOperands(*index);
477       });
478 }
479 
/// Propagate lattice state along the given region control-flow successors of
/// `parentOp`. For each successor:
///  * if it is the parent operation itself (null region), join the forwarded
///    inputs into the parent's results; results not covered by the successor
///    inputs are marked as having reached the pessimistic fixpoint;
///  * otherwise, mark the region's entry block executable, let the analysis
///    define entry arguments that are not forwarded by control flow, and join
///    the forwarded inputs into the corresponding entry arguments.
/// `getInputsForRegion` returns the operands forwarded to the successor with
/// the given region index ('None' meaning the parent operation's results).
void ForwardDataFlowSolver::visitRegionSuccessors(
    Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
    ArrayRef<AbstractLatticeElement *> operandLattices,
    function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
  for (const RegionSuccessor &it : regionSuccessors) {
    Region *region = it.getSuccessor();
    ValueRange succArgs = it.getSuccessorInputs();

    // Check to see if this is the parent operation.
    if (!region) {
      ResultRange results = parentOp->getResults();
      // Nothing can be refined if every result is already at its fixpoint.
      if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); }))
        continue;

      // Mark the results outside of the input range as having reached the
      // pessimistic fixpoint.
      // TODO: This isn't exactly ideal. There may be situations in which a
      // region operation can provide information for certain results that
      // aren't part of the control flow.
      if (succArgs.size() != results.size()) {
        // Enqueue the parent so that users of the results marked below are
        // revisited.
        opWorklist.push(parentOp);
        if (succArgs.empty()) {
          markAllPessimisticFixpoint(results);
          continue;
        }

        // The successor inputs are treated as a contiguous run of results
        // starting at `firstResIdx`; everything before and after that run is
        // marked pessimistic.
        unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
        markAllPessimisticFixpoint(results.take_front(firstResIdx));
        markAllPessimisticFixpoint(
            results.drop_front(firstResIdx + succArgs.size()));
      }

      // Update the lattice for any operation results.
      OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
      for (auto it : llvm::zip(succArgs, operands))
        join(parentOp, analysis.getLatticeElement(std::get<0>(it)),
             analysis.getLatticeElement(std::get<1>(it)));
      continue;
    }
    assert(!region->empty() && "expected region to be non-empty");
    Block *entryBlock = &region->front();
    markBlockExecutable(entryBlock);

    // If all of the arguments have already reached a fixpoint, the arguments
    // have already been fully resolved.
    Block::BlockArgListType arguments = entryBlock->getArguments();
    if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
      continue;

    // Entry arguments not covered by the successor inputs are not defined by
    // control flow; the analysis is asked to define them. The successor inputs
    // are treated as a contiguous run of arguments starting at `firstArgIdx`.
    if (succArgs.size() != arguments.size()) {
      if (analysis.visitNonControlFlowArguments(
              parentOp, it, operandLattices) == ChangeResult::Change) {
        unsigned firstArgIdx =
            succArgs.empty() ? 0
                             : succArgs[0].cast<BlockArgument>().getArgNumber();
        for (Value v : arguments.take_front(firstArgIdx)) {
          assert(!analysis.getLatticeElement(v).isUninitialized() &&
                 "Non-control flow block arg has no lattice value after "
                 "analysis callback");
          visitUsers(v);
        }
        for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) {
          assert(!analysis.getLatticeElement(v).isUninitialized() &&
                 "Non-control flow block arg has no lattice value after "
                 "analysis callback");
          visitUsers(v);
        }
      }
    }

    // Update the lattice of arguments that have inputs from the predecessor.
    OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
    for (auto it : llvm::zip(succArgs, succOperands)) {
      AbstractLatticeElement &argValue =
          analysis.getLatticeElement(std::get<0>(it));
      AbstractLatticeElement &operandValue =
          analysis.getLatticeElement(std::get<1>(it));
      if (argValue.join(operandValue) == ChangeResult::Change)
        visitUsers(std::get<0>(it));
    }
  }
}
562 
/// Propagate control-flow state out of the terminator `op`.
///  * Terminators without CFG successors exit their region. They either feed
///    the results of a callable (handled by
///    `visitCallableTerminatorOperation`) or are routed to the region
///    successors reported for a `RegionBranchOpInterface` parent.
///  * Terminators with CFG successors mark the outgoing edges executable.
///    When the terminator implements `BranchOpInterface`, the analysis is
///    first given a chance to narrow down the live successors.
void ForwardDataFlowSolver::visitTerminatorOperation(
    Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) {
  // If this operation has no successors, we treat it as an exiting terminator.
  if (op->getNumSuccessors() == 0) {
    Region *parentRegion = op->getParentRegion();
    Operation *parentOp = parentRegion->getParentOp();

    // Check to see if this is a terminator for a callable region.
    if (isa<CallableOpInterface>(parentOp))
      return visitCallableTerminatorOperation(parentOp, op, operandLattices);

    // Otherwise, check to see if the parent tracks region control flow.
    auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
    if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
      return;

    // Query the set of successors of the current region using the current
    // optimistic lattice state.
    SmallVector<RegionSuccessor, 1> regionSuccessors;
    analysis.getSuccessorsForOperands(regionInterface,
                                      parentRegion->getRegionNumber(),
                                      operandLattices, regionSuccessors);
    if (regionSuccessors.empty())
      return;

    // Try to get "region-like" successor operands if possible in order to
    // propagate the operand states to the successors.
    if (isRegionReturnLike(op)) {
      auto getOperands = [&](Optional<unsigned> regionIndex) {
        // Determine the individual region successor operands for the given
        // region index (if any).
        return *getRegionBranchSuccessorOperands(op, regionIndex);
      };
      return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
                                   getOperands);
    }

    // If this terminator is not "region-like", conservatively mark all of the
    // successor values as having reached the pessimistic fixpoint.
    for (auto &it : regionSuccessors) {
      // If the successor is a region, mark the entry block as executable so
      // that we visit operations defined within. If the successor is the
      // parent operation, we simply mark the control flow results as having
      // reached the pessimistic state.
      if (Region *region = it.getSuccessor())
        markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true);
      else
        markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs());
    }
    // Note: this path falls through to the CFG-successor handling below. The
    // final loop over `op->getSuccessors()` has nothing to iterate here.
  }

  // Try to resolve to a specific set of successors with the current optimistic
  // lattice state.
  Block *block = op->getBlock();
  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    SmallVector<Block *> successors;
    if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices,
                                                    successors))) {
      for (Block *succ : successors)
        markEdgeExecutable(block, succ);
      return;
    }
  }

  // Otherwise, conservatively treat all edges as executable.
  for (Block *succ : op->getSuccessors())
    markEdgeExecutable(block, succ);
}
631 
632 void ForwardDataFlowSolver::visitCallableTerminatorOperation(
633     Operation *callable, Operation *terminator,
634     ArrayRef<AbstractLatticeElement *> operandLattices) {
635   // If there are no exiting values, we have nothing to track.
636   if (terminator->getNumOperands() == 0)
637     return;
638 
639   // If this callable isn't tracking any lattice state there is nothing to do.
640   auto latticeIt = callableLatticeState.find(callable);
641   if (latticeIt == callableLatticeState.end())
642     return;
643   assert(callable->getNumResults() == 0 && "expected symbol callable");
644 
645   // If this terminator is not "return-like", conservatively mark all of the
646   // call-site results as having reached the pessimistic fixpoint.
647   auto callableResultLattices = latticeIt->second.getResultLatticeElements();
648   if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
649     for (auto &it : callableResultLattices)
650       it.markPessimisticFixpoint();
651     for (Operation *call : latticeIt->second.getSymbolCalls())
652       markAllPessimisticFixpoint(call, call->getResults());
653     return;
654   }
655 
656   // Merge the lattice state for terminator operands into the results.
657   ChangeResult result = ChangeResult::NoChange;
658   for (auto it : llvm::zip(operandLattices, callableResultLattices))
659     result |= std::get<1>(it).join(*std::get<0>(it));
660   if (result == ChangeResult::NoChange)
661     return;
662 
663   // If any of the result lattices changed, update the callers.
664   for (Operation *call : latticeIt->second.getSymbolCalls())
665     for (auto it : llvm::zip(call->getResults(), callableResultLattices))
666       join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it));
667 }
668 
669 void ForwardDataFlowSolver::visitBlock(Block *block) {
670   // If the block is not the entry block we need to compute the lattice state
671   // for the block arguments. Entry block argument lattices are computed
672   // elsewhere, such as when visiting the parent operation.
673   if (!block->isEntryBlock()) {
674     for (int i : llvm::seq<int>(0, block->getNumArguments()))
675       visitBlockArgument(block, i);
676   }
677 
678   // Visit all of the operations within the block.
679   for (Operation &op : *block)
680     visitOperation(&op);
681 }
682 
683 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
684   BlockArgument arg = block->getArgument(i);
685   AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg);
686   if (argLattice.isAtFixpoint())
687     return;
688 
689   ChangeResult updatedLattice = ChangeResult::NoChange;
690   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
691     Block *pred = *it;
692 
693     // We only care about this predecessor if it is going to execute.
694     if (!isEdgeExecutable(pred, block))
695       continue;
696 
697     // Try to get the operand forwarded by the predecessor. If we can't reason
698     // about the terminator of the predecessor, mark as having reached a
699     // fixpoint.
700     auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
701     if (!branch) {
702       updatedLattice |= argLattice.markPessimisticFixpoint();
703       break;
704     }
705     Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
706     if (!operand) {
707       updatedLattice |= argLattice.markPessimisticFixpoint();
708       break;
709     }
710 
711     // If the operand hasn't been resolved, it is uninitialized and can merge
712     // with anything.
713     AbstractLatticeElement *operandLattice =
714         analysis.lookupLatticeElement(operand);
715     if (!operandLattice)
716       continue;
717 
718     // Otherwise, join the operand lattice into the argument lattice.
719     updatedLattice |= argLattice.join(*operandLattice);
720     if (argLattice.isAtFixpoint())
721       break;
722   }
723 
724   // If the lattice changed, visit users of the argument.
725   if (updatedLattice == ChangeResult::Change)
726     visitUsers(arg);
727 }
728 
729 ChangeResult
730 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region,
731                                                 bool markPessimisticFixpoint) {
732   if (!region->empty()) {
733     if (markPessimisticFixpoint)
734       markAllPessimisticFixpoint(region->front().getArguments());
735     return markBlockExecutable(&region->front());
736   }
737   return ChangeResult::NoChange;
738 }
739 
740 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) {
741   bool marked = executableBlocks.insert(block).second;
742   if (marked)
743     blockWorklist.push(block);
744   return marked ? ChangeResult::Change : ChangeResult::NoChange;
745 }
746 
747 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const {
748   return executableBlocks.count(block);
749 }
750 
751 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) {
752   executableEdges.insert(std::make_pair(from, to));
753 
754   // Mark the destination as executable, and reprocess its arguments if it was
755   // already executable.
756   if (markBlockExecutable(to) == ChangeResult::NoChange) {
757     for (int i : llvm::seq<int>(0, to->getNumArguments()))
758       visitBlockArgument(to, i);
759   }
760 }
761 
762 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const {
763   return executableEdges.count(std::make_pair(from, to));
764 }
765 
766 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) {
767   analysis.getLatticeElement(value).markPessimisticFixpoint();
768 }
769 
770 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const {
771   if (auto *lattice = analysis.lookupLatticeElement(value))
772     return lattice->isAtFixpoint();
773   return false;
774 }
775 
776 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to,
777                                  const AbstractLatticeElement &from) {
778   if (to.join(from) == ChangeResult::Change)
779     opWorklist.push(owner);
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // AbstractLatticeElement
784 //===----------------------------------------------------------------------===//
785 
/// Out-of-line defaulted destructor. Defining it in this .cpp file rather than
/// inline in the header presumably anchors the class's vtable in this
/// translation unit. NOTE(review): confirm the destructor is declared virtual
/// in DataFlowAnalysis.h.
AbstractLatticeElement::~AbstractLatticeElement() = default;
787 
788 //===----------------------------------------------------------------------===//
789 // ForwardDataFlowAnalysisBase
790 //===----------------------------------------------------------------------===//
791 
/// Out-of-line defaulted destructor. Defining it in this .cpp file rather than
/// inline in the header presumably anchors the class's vtable in this
/// translation unit. NOTE(review): confirm the destructor is declared virtual
/// in DataFlowAnalysis.h, and check how the lattice elements stored in
/// `latticeValues` are owned and freed.
ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() = default;
793 
794 AbstractLatticeElement &
795 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) {
796   AbstractLatticeElement *&latticeValue = latticeValues[value];
797   if (!latticeValue)
798     latticeValue = createLatticeElement(value);
799   return *latticeValue;
800 }
801 
802 AbstractLatticeElement *
803 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) {
804   return latticeValues.lookup(value);
805 }
806 
807 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) {
808   // Run the main dataflow solver.
809   ForwardDataFlowSolver solver(*this, topLevelOp);
810   solver.solve();
811 
812   // Any values that are still uninitialized now go to a pessimistic fixpoint,
813   // otherwise we assume an optimistic fixpoint has been reached.
814   for (auto &it : latticeValues)
815     if (it.second->isUninitialized())
816       it.second->markPessimisticFixpoint();
817     else
818       it.second->markOptimisticFixpoint();
819 }
820