1 //===- DataFlowAnalysis.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlowAnalysis.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 namespace {
19 /// This class contains various state used when computing the lattice elements
20 /// of a callable operation.
21 class CallableLatticeState {
22 public:
23   /// Build a lattice state with a given callable region, and a specified number
24   /// of results to be initialized to the default lattice element.
25   CallableLatticeState(ForwardDataFlowAnalysisBase &analysis,
26                        Region *callableRegion, unsigned numResults)
27       : callableArguments(callableRegion->getArguments()),
28         resultLatticeElements(numResults) {
29     for (AbstractLatticeElement *&it : resultLatticeElements)
30       it = analysis.createLatticeElement();
31   }
32 
33   /// Returns the arguments to the callable region.
34   Block::BlockArgListType getCallableArguments() const {
35     return callableArguments;
36   }
37 
38   /// Returns the lattice element for the results of the callable region.
39   auto getResultLatticeElements() {
40     return llvm::make_pointee_range(resultLatticeElements);
41   }
42 
43   /// Add a call to this callable. This is only used if the callable defines a
44   /// symbol.
45   void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
46 
47   /// Return the calls that reference this callable. This is only used
48   /// if the callable defines a symbol.
49   ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
50 
51 private:
52   /// The arguments of the callable region.
53   Block::BlockArgListType callableArguments;
54 
55   /// The lattice state for each of the results of this region. The return
56   /// values of the callable aren't SSA values, so we need to track them
57   /// separately.
58   SmallVector<AbstractLatticeElement *, 4> resultLatticeElements;
59 
60   /// The calls referencing this callable if this callable defines a symbol.
61   /// This removes the need to recompute symbol references during propagation.
62   /// Value based references are trivial to resolve, so they can be done
63   /// in-place.
64   SmallVector<Operation *, 4> symbolCalls;
65 };
66 
67 /// This class represents the solver for a forward dataflow analysis. This class
68 /// acts as the propagation engine for computing which lattice elements.
69 class ForwardDataFlowSolver {
70 public:
71   /// Initialize the solver with the given top-level operation.
72   ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op);
73 
74   /// Run the solver until it converges.
75   void solve();
76 
77 private:
78   /// Initialize the set of symbol defining callables that can have their
79   /// arguments and results tracked. 'op' is the top-level operation that the
80   /// solver is operating on.
81   void initializeSymbolCallables(Operation *op);
82 
83   /// Visit the users of the given IR that reside within executable blocks.
84   template <typename T>
85   void visitUsers(T &value) {
86     for (Operation *user : value.getUsers())
87       if (isBlockExecutable(user->getBlock()))
88         visitOperation(user);
89   }
90 
91   /// Visit the given operation and compute any necessary lattice state.
92   void visitOperation(Operation *op);
93 
94   /// Visit the given call operation and compute any necessary lattice state.
95   void visitCallOperation(CallOpInterface op);
96 
97   /// Visit the given callable operation and compute any necessary lattice
98   /// state.
99   void visitCallableOperation(Operation *op);
100 
101   /// Visit the given region branch operation, which defines regions, and
102   /// compute any necessary lattice state. This also resolves the lattice state
103   /// of both the operation results and any nested regions.
104   void visitRegionBranchOperation(
105       RegionBranchOpInterface branch,
106       ArrayRef<AbstractLatticeElement *> operandLattices);
107 
108   /// Visit the given set of region successors, computing any necessary lattice
109   /// state. The provided function returns the input operands to the region at
110   /// the given index. If the index is 'None', the input operands correspond to
111   /// the parent operation results.
112   void visitRegionSuccessors(
113       Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
114       function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
115 
116   /// Visit the given terminator operation and compute any necessary lattice
117   /// state.
118   void
119   visitTerminatorOperation(Operation *op,
120                            ArrayRef<AbstractLatticeElement *> operandLattices);
121 
122   /// Visit the given terminator operation that exits a callable region. These
123   /// are terminators with no CFG successors.
124   void visitCallableTerminatorOperation(
125       Operation *callable, Operation *terminator,
126       ArrayRef<AbstractLatticeElement *> operandLattices);
127 
128   /// Visit the given block and compute any necessary lattice state.
129   void visitBlock(Block *block);
130 
131   /// Visit argument #'i' of the given block and compute any necessary lattice
132   /// state.
133   void visitBlockArgument(Block *block, int i);
134 
135   /// Mark the entry block of the given region as executable. Returns NoChange
136   /// if the block was already marked executable. If `markPessimisticFixpoint`
137   /// is true, the arguments of the entry block are also marked as having
138   /// reached the pessimistic fixpoint.
139   ChangeResult markEntryBlockExecutable(Region *region,
140                                         bool markPessimisticFixpoint);
141 
142   /// Mark the given block as executable. Returns NoChange if the block was
143   /// already marked executable.
144   ChangeResult markBlockExecutable(Block *block);
145 
146   /// Returns true if the given block is executable.
147   bool isBlockExecutable(Block *block) const;
148 
149   /// Mark the edge between 'from' and 'to' as executable.
150   void markEdgeExecutable(Block *from, Block *to);
151 
152   /// Return true if the edge between 'from' and 'to' is executable.
153   bool isEdgeExecutable(Block *from, Block *to) const;
154 
155   /// Mark the given value as having reached the pessimistic fixpoint. This
156   /// means that we cannot further refine the state of this value.
157   void markPessimisticFixpoint(Value value);
158 
159   /// Mark all of the given values as having reaching the pessimistic fixpoint.
160   template <typename ValuesT>
161   void markAllPessimisticFixpoint(ValuesT values) {
162     for (auto value : values)
163       markPessimisticFixpoint(value);
164   }
165   template <typename ValuesT>
166   void markAllPessimisticFixpoint(Operation *op, ValuesT values) {
167     markAllPessimisticFixpoint(values);
168     opWorklist.push_back(op);
169   }
170   template <typename ValuesT>
171   void markAllPessimisticFixpointAndVisitUsers(ValuesT values) {
172     for (auto value : values) {
173       AbstractLatticeElement &lattice = analysis.getLatticeElement(value);
174       if (lattice.markPessimisticFixpoint() == ChangeResult::Change)
175         visitUsers(value);
176     }
177   }
178 
179   /// Returns true if the given value was marked as having reached the
180   /// pessimistic fixpoint.
181   bool isAtFixpoint(Value value) const;
182 
183   /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
184   /// corresponds to the parent operation of the lattice for 'to'.
185   void join(Operation *owner, AbstractLatticeElement &to,
186             const AbstractLatticeElement &from);
187 
188   /// A reference to the dataflow analysis being computed.
189   ForwardDataFlowAnalysisBase &analysis;
190 
191   /// The set of blocks that are known to execute, or are intrinsically live.
192   SmallPtrSet<Block *, 16> executableBlocks;
193 
194   /// The set of control flow edges that are known to execute.
195   DenseSet<std::pair<Block *, Block *>> executableEdges;
196 
197   /// A worklist containing blocks that need to be processed.
198   SmallVector<Block *, 64> blockWorklist;
199 
200   /// A worklist of operations that need to be processed.
201   SmallVector<Operation *, 64> opWorklist;
202 
203   /// The callable operations that have their argument/result state tracked.
204   DenseMap<Operation *, CallableLatticeState> callableLatticeState;
205 
206   /// A map between a call operation and the resolved symbol callable. This
207   /// avoids re-resolving symbol references during propagation. Value based
208   /// callables are trivial to resolve, so they can be done in-place.
209   DenseMap<Operation *, Operation *> callToSymbolCallable;
210 
211   /// A symbol table used for O(1) symbol lookups during simplification.
212   SymbolTableCollection symbolTable;
213 };
214 } // namespace
215 
216 ForwardDataFlowSolver::ForwardDataFlowSolver(
217     ForwardDataFlowAnalysisBase &analysis, Operation *op)
218     : analysis(analysis) {
219   /// Initialize the solver with the regions within this operation.
220   for (Region &region : op->getRegions()) {
221     // Mark the entry block as executable. The values passed to these regions
222     // are also invisible, so mark any arguments as reaching the pessimistic
223     // fixpoint.
224     markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
225   }
226   initializeSymbolCallables(op);
227 }
228 
229 void ForwardDataFlowSolver::solve() {
230   while (!blockWorklist.empty() || !opWorklist.empty()) {
231     // Process any operations in the op worklist.
232     while (!opWorklist.empty())
233       visitUsers(*opWorklist.pop_back_val());
234 
235     // Process any blocks in the block worklist.
236     while (!blockWorklist.empty())
237       visitBlock(blockWorklist.pop_back_val());
238   }
239 }
240 
241 void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) {
242   // Initialize the set of symbol callables that can have their state tracked.
243   // This tracks which symbol callable operations we can propagate within and
244   // out of.
245   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
246     Region &symbolTableRegion = symTable->getRegion(0);
247     Block *symbolTableBlock = &symbolTableRegion.front();
248     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
249       // We won't be able to track external callables.
250       Region *callableRegion = callable.getCallableRegion();
251       if (!callableRegion)
252         continue;
253       // We only care about symbol defining callables here.
254       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
255       if (!symbol)
256         continue;
257       callableLatticeState.try_emplace(callable, analysis, callableRegion,
258                                        callable.getCallableResults().size());
259 
260       // If not all of the uses of this symbol are visible, we can't track the
261       // state of the arguments.
262       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
263         for (Region &region : callable->getRegions())
264           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
265       }
266     }
267     if (callableLatticeState.empty())
268       return;
269 
270     // After computing the valid callables, walk any symbol uses to check
271     // for non-call references. We won't be able to track the lattice state
272     // for arguments to these callables, as we can't guarantee that we can see
273     // all of its calls.
274     Optional<SymbolTable::UseRange> uses =
275         SymbolTable::getSymbolUses(&symbolTableRegion);
276     if (!uses) {
277       // If we couldn't gather the symbol uses, conservatively assume that
278       // we can't track information for any nested symbols.
279       op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
280       return;
281     }
282 
283     for (const SymbolTable::SymbolUse &use : *uses) {
284       // If the use is a call, track it to avoid the need to recompute the
285       // reference later.
286       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
287         Operation *symCallable = callOp.resolveCallable(&symbolTable);
288         auto callableLatticeIt = callableLatticeState.find(symCallable);
289         if (callableLatticeIt != callableLatticeState.end()) {
290           callToSymbolCallable.try_emplace(callOp, symCallable);
291 
292           // We only need to record the call in the lattice if it produces any
293           // values.
294           if (callOp->getNumResults())
295             callableLatticeIt->second.addSymbolCall(callOp);
296         }
297         continue;
298       }
299       // This use isn't a call, so don't we know all of the callers.
300       auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
301       auto it = callableLatticeState.find(symbol);
302       if (it != callableLatticeState.end()) {
303         for (Region &region : it->first->getRegions())
304           markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
305       }
306     }
307   };
308   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
309                                 walkFn);
310 }
311 
312 void ForwardDataFlowSolver::visitOperation(Operation *op) {
313   // Collect all of the lattice elements feeding into this operation. If any are
314   // not yet resolved, bail out and wait for them to resolve.
315   SmallVector<AbstractLatticeElement *, 8> operandLattices;
316   operandLattices.reserve(op->getNumOperands());
317   for (Value operand : op->getOperands()) {
318     AbstractLatticeElement *operandLattice =
319         analysis.lookupLatticeElement(operand);
320     if (!operandLattice || operandLattice->isUninitialized())
321       return;
322     operandLattices.push_back(operandLattice);
323   }
324 
325   // If this is a terminator operation, process any control flow lattice state.
326   if (op->hasTrait<OpTrait::IsTerminator>())
327     visitTerminatorOperation(op, operandLattices);
328 
329   // Process call operations. The call visitor processes result values, so we
330   // can exit afterwards.
331   if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
332     return visitCallOperation(call);
333 
334   // Process callable operations. These are specially handled region operations
335   // that track dataflow via calls.
336   if (isa<CallableOpInterface>(op)) {
337     // If this callable has a tracked lattice state, it will be visited by calls
338     // that reference it instead. This way, we don't assume that it is
339     // executable unless there is a proper reference to it.
340     if (callableLatticeState.count(op))
341       return;
342     return visitCallableOperation(op);
343   }
344 
345   // Process region holding operations.
346   if (op->getNumRegions()) {
347     // Check to see if we can reason about the internal control flow of this
348     // region operation.
349     if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
350       return visitRegionBranchOperation(branch, operandLattices);
351 
352     // If we can't, conservatively mark all regions as executable.
353     // TODO: Let the `visitOperation` method decide how to propagate
354     // information to the block arguments.
355     for (Region &region : op->getRegions())
356       markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
357   }
358 
359   // If this op produces no results, it can't produce any constants.
360   if (op->getNumResults() == 0)
361     return;
362 
363   // If all of the results of this operation are already resolved, bail out
364   // early.
365   auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); };
366   if (llvm::all_of(op->getResults(), isAtFixpointFn))
367     return;
368 
369   // Visit the current operation.
370   if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change)
371     opWorklist.push_back(op);
372 
373   // `visitOperation` is required to define all of the result lattices.
374   assert(llvm::none_of(
375              op->getResults(),
376              [&](Value value) {
377                return analysis.getLatticeElement(value).isUninitialized();
378              }) &&
379          "expected `visitOperation` to define all result lattices");
380 }
381 
382 void ForwardDataFlowSolver::visitCallableOperation(Operation *op) {
383   // Mark the regions as executable. If we aren't tracking lattice state for
384   // this callable, mark all of the region arguments as having reached a
385   // fixpoint.
386   bool isTrackingLatticeState = callableLatticeState.count(op);
387   for (Region &region : op->getRegions())
388     markEntryBlockExecutable(&region, !isTrackingLatticeState);
389 
390   // TODO: Add support for non-symbol callables when necessary. If the callable
391   // has non-call uses we would mark as having reached pessimistic fixpoint,
392   // otherwise allow for propagating the return values out.
393   markAllPessimisticFixpoint(op, op->getResults());
394 }
395 
396 void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) {
397   ResultRange callResults = op->getResults();
398 
399   // Resolve the callable operation for this call.
400   Operation *callableOp = nullptr;
401   if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
402     callableOp = callableValue.getDefiningOp();
403   else
404     callableOp = callToSymbolCallable.lookup(op);
405 
406   // The callable of this call can't be resolved, mark any results overdefined.
407   if (!callableOp)
408     return markAllPessimisticFixpoint(op, callResults);
409 
410   // If this callable is tracking state, merge the argument operands with the
411   // arguments of the callable.
412   auto callableLatticeIt = callableLatticeState.find(callableOp);
413   if (callableLatticeIt == callableLatticeState.end())
414     return markAllPessimisticFixpoint(op, callResults);
415 
416   OperandRange callOperands = op.getArgOperands();
417   auto callableArgs = callableLatticeIt->second.getCallableArguments();
418   for (auto it : llvm::zip(callOperands, callableArgs)) {
419     BlockArgument callableArg = std::get<1>(it);
420     AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg);
421     AbstractLatticeElement &operandValue =
422         analysis.getLatticeElement(std::get<0>(it));
423     if (argValue.join(operandValue) == ChangeResult::Change)
424       visitUsers(callableArg);
425   }
426 
427   // Visit the callable.
428   visitCallableOperation(callableOp);
429 
430   // Merge in the lattice state for the callable results as well.
431   auto callableResults = callableLatticeIt->second.getResultLatticeElements();
432   for (auto it : llvm::zip(callResults, callableResults))
433     join(/*owner=*/op,
434          /*to=*/analysis.getLatticeElement(std::get<0>(it)),
435          /*from=*/std::get<1>(it));
436 }
437 
438 void ForwardDataFlowSolver::visitRegionBranchOperation(
439     RegionBranchOpInterface branch,
440     ArrayRef<AbstractLatticeElement *> operandLattices) {
441   // Check to see which regions are executable.
442   SmallVector<RegionSuccessor, 1> successors;
443   analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None,
444                                     operandLattices, successors);
445 
446   // If the interface identified that no region will be executed. Mark
447   // any results of this operation as overdefined, as we can't reason about
448   // them.
449   // TODO: If we had an interface to detect pass through operands, we could
450   // resolve some results based on the lattice state of the operands. We could
451   // also allow for the parent operation to have itself as a region successor.
452   if (successors.empty())
453     return markAllPessimisticFixpoint(branch, branch->getResults());
454   return visitRegionSuccessors(
455       branch, successors, [&](Optional<unsigned> index) {
456         assert(index && "expected valid region index");
457         return branch.getSuccessorEntryOperands(*index);
458       });
459 }
460 
461 void ForwardDataFlowSolver::visitRegionSuccessors(
462     Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
463     function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
464   for (const RegionSuccessor &it : regionSuccessors) {
465     Region *region = it.getSuccessor();
466     ValueRange succArgs = it.getSuccessorInputs();
467 
468     // Check to see if this is the parent operation.
469     if (!region) {
470       ResultRange results = parentOp->getResults();
471       if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); }))
472         continue;
473 
474       // Mark the results outside of the input range as having reached the
475       // pessimistic fixpoint.
476       // TODO: This isn't exactly ideal. There may be situations in which a
477       // region operation can provide information for certain results that
478       // aren't part of the control flow.
479       if (succArgs.size() != results.size()) {
480         opWorklist.push_back(parentOp);
481         if (succArgs.empty()) {
482           markAllPessimisticFixpoint(results);
483           continue;
484         }
485 
486         unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
487         markAllPessimisticFixpoint(results.take_front(firstResIdx));
488         markAllPessimisticFixpoint(
489             results.drop_front(firstResIdx + succArgs.size()));
490       }
491 
492       // Update the lattice for any operation results.
493       OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
494       for (auto it : llvm::zip(succArgs, operands))
495         join(parentOp, analysis.getLatticeElement(std::get<0>(it)),
496              analysis.getLatticeElement(std::get<1>(it)));
497       continue;
498     }
499     assert(!region->empty() && "expected region to be non-empty");
500     Block *entryBlock = &region->front();
501     markBlockExecutable(entryBlock);
502 
503     // If all of the arguments have already reached a fixpoint, the arguments
504     // have already been fully resolved.
505     Block::BlockArgListType arguments = entryBlock->getArguments();
506     if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
507       continue;
508 
509     // Mark any arguments that do not receive inputs as having reached a
510     // pessimistic fixpoint, we won't be able to discern if they are constant.
511     // TODO: This isn't exactly ideal. There may be situations in which a
512     // region operation can provide information for certain results that
513     // aren't part of the control flow.
514     if (succArgs.size() != arguments.size()) {
515       if (succArgs.empty()) {
516         markAllPessimisticFixpoint(arguments);
517         continue;
518       }
519 
520       unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
521       markAllPessimisticFixpointAndVisitUsers(
522           arguments.take_front(firstArgIdx));
523       markAllPessimisticFixpointAndVisitUsers(
524           arguments.drop_front(firstArgIdx + succArgs.size()));
525     }
526 
527     // Update the lattice of arguments that have inputs from the predecessor.
528     OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
529     for (auto it : llvm::zip(succArgs, succOperands)) {
530       AbstractLatticeElement &argValue =
531           analysis.getLatticeElement(std::get<0>(it));
532       AbstractLatticeElement &operandValue =
533           analysis.getLatticeElement(std::get<1>(it));
534       if (argValue.join(operandValue) == ChangeResult::Change)
535         visitUsers(std::get<0>(it));
536     }
537   }
538 }
539 
/// Process the control-flow effects of terminator `op`, given the lattices of
/// its operands (`operandLattices`).
///
/// A terminator with no block successors exits its parent region. If the
/// parent is a callable, the callable-exit logic handles it. If the parent
/// implements RegionBranchOpInterface and sits in an executable block, the
/// region successors it reports are updated. Any other parent is ignored.
/// Afterwards, CFG edges to block successors are marked executable. If `op` is
/// a BranchOpInterface and the analysis resolves a specific successor set,
/// only those edges are marked; otherwise all successor edges are.
void ForwardDataFlowSolver::visitTerminatorOperation(
    Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) {
  // If this operation has no successors, we treat it as an exiting terminator.
  if (op->getNumSuccessors() == 0) {
    Region *parentRegion = op->getParentRegion();
    Operation *parentOp = parentRegion->getParentOp();

    // Check to see if this is a terminator for a callable region.
    if (isa<CallableOpInterface>(parentOp))
      return visitCallableTerminatorOperation(parentOp, op, operandLattices);

    // Otherwise, check to see if the parent tracks region control flow.
    auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
    if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
      return;

    // Query the set of successors of the current region using the current
    // optimistic lattice state.
    SmallVector<RegionSuccessor, 1> regionSuccessors;
    analysis.getSuccessorsForOperands(regionInterface,
                                      parentRegion->getRegionNumber(),
                                      operandLattices, regionSuccessors);
    if (regionSuccessors.empty())
      return;

    // Try to get "region-like" successor operands if possible in order to
    // propagate the operand states to the successors.
    if (isRegionReturnLike(op)) {
      return visitRegionSuccessors(
          parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) {
            // Determine the individual region successor operands for the given
            // region index (if any).
            return *getRegionBranchSuccessorOperands(op, regionIndex);
          });
    }

    // If this terminator is not "region-like", conservatively mark all of the
    // successor values as having reached the pessimistic fixpoint.
    for (auto &it : regionSuccessors) {
      // If the successor is a region, mark the entry block as executable so
      // that we visit operations defined within. If the successor is the
      // parent operation, we simply mark the control flow results as having
      // reached the pessimistic state.
      if (Region *region = it.getSuccessor())
        markEntryBlockExecutable(region, /*markPessimisticFixpoint=*/true);
      else
        markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs());
    }
    // NOTE: there is no `return` here, so control falls through to the
    // branch handling below. `op` has no successors, so the fallback loop at
    // the end marks no edges. If `op` implements BranchOpInterface,
    // `analysis.getSuccessorsForOperands` is still queried.
  }

  // Try to resolve to a specific set of successors with the current optimistic
  // lattice state.
  Block *block = op->getBlock();
  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    SmallVector<Block *> successors;
    if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices,
                                                    successors))) {
      for (Block *succ : successors)
        markEdgeExecutable(block, succ);
      return;
    }
  }

  // Otherwise, conservatively treat all edges as executable.
  for (Block *succ : op->getSuccessors())
    markEdgeExecutable(block, succ);
}
607 
608 void ForwardDataFlowSolver::visitCallableTerminatorOperation(
609     Operation *callable, Operation *terminator,
610     ArrayRef<AbstractLatticeElement *> operandLattices) {
611   // If there are no exiting values, we have nothing to track.
612   if (terminator->getNumOperands() == 0)
613     return;
614 
615   // If this callable isn't tracking any lattice state there is nothing to do.
616   auto latticeIt = callableLatticeState.find(callable);
617   if (latticeIt == callableLatticeState.end())
618     return;
619   assert(callable->getNumResults() == 0 && "expected symbol callable");
620 
621   // If this terminator is not "return-like", conservatively mark all of the
622   // call-site results as having reached the pessimistic fixpoint.
623   auto callableResultLattices = latticeIt->second.getResultLatticeElements();
624   if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
625     for (auto &it : callableResultLattices)
626       it.markPessimisticFixpoint();
627     for (Operation *call : latticeIt->second.getSymbolCalls())
628       markAllPessimisticFixpoint(call, call->getResults());
629     return;
630   }
631 
632   // Merge the lattice state for terminator operands into the results.
633   ChangeResult result = ChangeResult::NoChange;
634   for (auto it : llvm::zip(operandLattices, callableResultLattices))
635     result |= std::get<1>(it).join(*std::get<0>(it));
636   if (result == ChangeResult::NoChange)
637     return;
638 
639   // If any of the result lattices changed, update the callers.
640   for (Operation *call : latticeIt->second.getSymbolCalls())
641     for (auto it : llvm::zip(call->getResults(), callableResultLattices))
642       join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it));
643 }
644 
645 void ForwardDataFlowSolver::visitBlock(Block *block) {
646   // If the block is not the entry block we need to compute the lattice state
647   // for the block arguments. Entry block argument lattices are computed
648   // elsewhere, such as when visiting the parent operation.
649   if (!block->isEntryBlock()) {
650     for (int i : llvm::seq<int>(0, block->getNumArguments()))
651       visitBlockArgument(block, i);
652   }
653 
654   // Visit all of the operations within the block.
655   for (Operation &op : *block)
656     visitOperation(&op);
657 }
658 
659 void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
660   BlockArgument arg = block->getArgument(i);
661   AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg);
662   if (argLattice.isAtFixpoint())
663     return;
664 
665   ChangeResult updatedLattice = ChangeResult::NoChange;
666   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
667     Block *pred = *it;
668 
669     // We only care about this predecessor if it is going to execute.
670     if (!isEdgeExecutable(pred, block))
671       continue;
672 
673     // Try to get the operand forwarded by the predecessor. If we can't reason
674     // about the terminator of the predecessor, mark as having reached a
675     // fixpoint.
676     Optional<OperandRange> branchOperands;
677     if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
678       branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
679     if (!branchOperands) {
680       updatedLattice |= argLattice.markPessimisticFixpoint();
681       break;
682     }
683 
684     // If the operand hasn't been resolved, it is uninitialized and can merge
685     // with anything.
686     AbstractLatticeElement *operandLattice =
687         analysis.lookupLatticeElement((*branchOperands)[i]);
688     if (!operandLattice)
689       continue;
690 
691     // Otherwise, join the operand lattice into the argument lattice.
692     updatedLattice |= argLattice.join(*operandLattice);
693     if (argLattice.isAtFixpoint())
694       break;
695   }
696 
697   // If the lattice changed, visit users of the argument.
698   if (updatedLattice == ChangeResult::Change)
699     visitUsers(arg);
700 }
701 
702 ChangeResult
703 ForwardDataFlowSolver::markEntryBlockExecutable(Region *region,
704                                                 bool markPessimisticFixpoint) {
705   if (!region->empty()) {
706     if (markPessimisticFixpoint)
707       markAllPessimisticFixpoint(region->front().getArguments());
708     return markBlockExecutable(&region->front());
709   }
710   return ChangeResult::NoChange;
711 }
712 
713 ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) {
714   bool marked = executableBlocks.insert(block).second;
715   if (marked)
716     blockWorklist.push_back(block);
717   return marked ? ChangeResult::Change : ChangeResult::NoChange;
718 }
719 
720 bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const {
721   return executableBlocks.count(block);
722 }
723 
724 void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) {
725   executableEdges.insert(std::make_pair(from, to));
726 
727   // Mark the destination as executable, and reprocess its arguments if it was
728   // already executable.
729   if (markBlockExecutable(to) == ChangeResult::NoChange) {
730     for (int i : llvm::seq<int>(0, to->getNumArguments()))
731       visitBlockArgument(to, i);
732   }
733 }
734 
735 bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const {
736   return executableEdges.count(std::make_pair(from, to));
737 }
738 
739 void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) {
740   analysis.getLatticeElement(value).markPessimisticFixpoint();
741 }
742 
743 bool ForwardDataFlowSolver::isAtFixpoint(Value value) const {
744   if (auto *lattice = analysis.lookupLatticeElement(value))
745     return lattice->isAtFixpoint();
746   return false;
747 }
748 
749 void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to,
750                                  const AbstractLatticeElement &from) {
751   if (to.join(from) == ChangeResult::Change)
752     opWorklist.push_back(owner);
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // AbstractLatticeElement
757 //===----------------------------------------------------------------------===//
758 
759 AbstractLatticeElement::~AbstractLatticeElement() {}
760 
761 //===----------------------------------------------------------------------===//
762 // ForwardDataFlowAnalysisBase
763 //===----------------------------------------------------------------------===//
764 
765 ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() {}
766 
767 AbstractLatticeElement &
768 ForwardDataFlowAnalysisBase::getLatticeElement(Value value) {
769   AbstractLatticeElement *&latticeValue = latticeValues[value];
770   if (!latticeValue)
771     latticeValue = createLatticeElement(value);
772   return *latticeValue;
773 }
774 
775 AbstractLatticeElement *
776 ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) {
777   return latticeValues.lookup(value);
778 }
779 
780 void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) {
781   // Run the main dataflow solver.
782   ForwardDataFlowSolver solver(*this, topLevelOp);
783   solver.solve();
784 
785   // Any values that are still uninitialized now go to a pessimistic fixpoint,
786   // otherwise we assume an optimistic fixpoint has been reached.
787   for (auto &it : latticeValues)
788     if (it.second->isUninitialized())
789       it.second->markPessimisticFixpoint();
790     else
791       it.second->markOptimisticFixpoint();
792 }
793