1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/Interfaces/ControlFlowInterfaces.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/Passes.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// This class represents a single lattice value. A lattive value corresponds to
30 /// the various different states that a value in the SCCP dataflow analysis can
31 /// take. See 'Kind' below for more details on the different states a value can
32 /// take.
33 class LatticeValue {
34   enum Kind {
35     /// A value with a yet to be determined value. This state may be changed to
36     /// anything.
37     Unknown,
38 
39     /// A value that is known to be a constant. This state may be changed to
40     /// overdefined.
41     Constant,
42 
43     /// A value that cannot statically be determined to be a constant. This
44     /// state cannot be changed.
45     Overdefined
46   };
47 
48 public:
49   /// Initialize a lattice value with "Unknown".
50   LatticeValue()
51       : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
52   /// Initialize a lattice value with a constant.
53   LatticeValue(Attribute attr, Dialect *dialect)
54       : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
55 
56   /// Returns true if this lattice value is unknown.
57   bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
58 
59   /// Mark the lattice value as overdefined.
60   void markOverdefined() {
61     constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
62     constantDialect = nullptr;
63   }
64 
65   /// Returns true if the lattice is overdefined.
66   bool isOverdefined() const {
67     return constantAndTag.getInt() == Kind::Overdefined;
68   }
69 
70   /// Mark the lattice value as constant.
71   void markConstant(Attribute value, Dialect *dialect) {
72     constantAndTag.setPointerAndInt(value, Kind::Constant);
73     constantDialect = dialect;
74   }
75 
76   /// If this lattice is constant, return the constant. Returns nullptr
77   /// otherwise.
78   Attribute getConstant() const { return constantAndTag.getPointer(); }
79 
80   /// If this lattice is constant, return the dialect to use when materializing
81   /// the constant.
82   Dialect *getConstantDialect() const {
83     assert(getConstant() && "expected valid constant");
84     return constantDialect;
85   }
86 
87   /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
88   /// lattice value changed.
89   bool meet(const LatticeValue &rhs) {
90     // If we are already overdefined, or rhs is unknown, there is nothing to do.
91     if (isOverdefined() || rhs.isUnknown())
92       return false;
93     // If we are unknown, just take the value of rhs.
94     if (isUnknown()) {
95       constantAndTag = rhs.constantAndTag;
96       constantDialect = rhs.constantDialect;
97       return true;
98     }
99 
100     // Otherwise, if this value doesn't match rhs go straight to overdefined.
101     if (constantAndTag != rhs.constantAndTag) {
102       markOverdefined();
103       return true;
104     }
105     return false;
106   }
107 
108 private:
109   /// The attribute value if this is a constant and the tag for the element
110   /// kind.
111   llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
112 
113   /// The dialect the constant originated from. This is only valid if the
114   /// lattice is a constant. This is not used as part of the key, and is only
115   /// needed to materialize the held constant if necessary.
116   Dialect *constantDialect;
117 };
118 
119 /// This class contains various state used when computing the lattice of a
120 /// callable operation.
121 class CallableLatticeState {
122 public:
123   /// Build a lattice state with a given callable region, and a specified number
124   /// of results to be initialized to the default lattice value (Unknown).
125   CallableLatticeState(Region *callableRegion, unsigned numResults)
126       : callableArguments(callableRegion->getArguments()),
127         resultLatticeValues(numResults) {}
128 
129   /// Returns the arguments to the callable region.
130   Block::BlockArgListType getCallableArguments() const {
131     return callableArguments;
132   }
133 
134   /// Returns the lattice value for the results of the callable region.
135   MutableArrayRef<LatticeValue> getResultLatticeValues() {
136     return resultLatticeValues;
137   }
138 
139   /// Add a call to this callable. This is only used if the callable defines a
140   /// symbol.
141   void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
142 
143   /// Return the calls that reference this callable. This is only used
144   /// if the callable defines a symbol.
145   ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
146 
147 private:
148   /// The arguments of the callable region.
149   Block::BlockArgListType callableArguments;
150 
151   /// The lattice state for each of the results of this region. The return
152   /// values of the callable aren't SSA values, so we need to track them
153   /// separately.
154   SmallVector<LatticeValue, 4> resultLatticeValues;
155 
156   /// The calls referencing this callable if this callable defines a symbol.
157   /// This removes the need to recompute symbol references during propagation.
158   /// Value based references are trivial to resolve, so they can be done
159   /// in-place.
160   SmallVector<Operation *, 4> symbolCalls;
161 };
162 
163 /// This class represents the solver for the SCCP analysis. This class acts as
164 /// the propagation engine for computing which values form constants.
165 class SCCPSolver {
166 public:
167   /// Initialize the solver with the given top-level operation.
168   SCCPSolver(Operation *op);
169 
170   /// Run the solver until it converges.
171   void solve();
172 
173   /// Rewrite the given regions using the computing analysis. This replaces the
174   /// uses of all values that have been computed to be constant, and erases as
175   /// many newly dead operations.
176   void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
177 
178 private:
179   /// Initialize the set of symbol defining callables that can have their
180   /// arguments and results tracked. 'op' is the top-level operation that SCCP
181   /// is operating on.
182   void initializeSymbolCallables(Operation *op);
183 
184   /// Replace the given value with a constant if the corresponding lattice
185   /// represents a constant. Returns success if the value was replaced, failure
186   /// otherwise.
187   LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
188                                     Value value);
189 
190   /// Visit the users of the given IR that reside within executable blocks.
191   template <typename T>
192   void visitUsers(T &value) {
193     for (Operation *user : value.getUsers())
194       if (isBlockExecutable(user->getBlock()))
195         visitOperation(user);
196   }
197 
198   /// Visit the given operation and compute any necessary lattice state.
199   void visitOperation(Operation *op);
200 
201   /// Visit the given call operation and compute any necessary lattice state.
202   void visitCallOperation(CallOpInterface op);
203 
204   /// Visit the given callable operation and compute any necessary lattice
205   /// state.
206   void visitCallableOperation(Operation *op);
207 
208   /// Visit the given operation, which defines regions, and compute any
209   /// necessary lattice state. This also resolves the lattice state of both the
210   /// operation results and any nested regions.
211   void visitRegionOperation(Operation *op,
212                             ArrayRef<Attribute> constantOperands);
213 
214   /// Visit the given set of region successors, computing any necessary lattice
215   /// state. The provided function returns the input operands to the region at
216   /// the given index. If the index is 'None', the input operands correspond to
217   /// the parent operation results.
218   void visitRegionSuccessors(
219       Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
220       function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
221 
222   /// Visit the given terminator operation and compute any necessary lattice
223   /// state.
224   void visitTerminatorOperation(Operation *op,
225                                 ArrayRef<Attribute> constantOperands);
226 
227   /// Visit the given terminator operation that exits a callable region. These
228   /// are terminators with no CFG successors.
229   void visitCallableTerminatorOperation(Operation *callable,
230                                         Operation *terminator);
231 
232   /// Visit the given block and compute any necessary lattice state.
233   void visitBlock(Block *block);
234 
235   /// Visit argument #'i' of the given block and compute any necessary lattice
236   /// state.
237   void visitBlockArgument(Block *block, int i);
238 
239   /// Mark the entry block of the given region as executable. Returns false if
240   /// the block was already marked executable. If `markArgsOverdefined` is true,
241   /// the arguments of the entry block are also set to overdefined.
242   bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined);
243 
244   /// Mark the given block as executable. Returns false if the block was already
245   /// marked executable.
246   bool markBlockExecutable(Block *block);
247 
248   /// Returns true if the given block is executable.
249   bool isBlockExecutable(Block *block) const;
250 
251   /// Mark the edge between 'from' and 'to' as executable.
252   void markEdgeExecutable(Block *from, Block *to);
253 
254   /// Return true if the edge between 'from' and 'to' is executable.
255   bool isEdgeExecutable(Block *from, Block *to) const;
256 
257   /// Mark the given value as overdefined. This means that we cannot refine a
258   /// specific constant for this value.
259   void markOverdefined(Value value);
260 
261   /// Mark all of the given values as overdefined.
262   template <typename ValuesT>
263   void markAllOverdefined(ValuesT values) {
264     for (auto value : values)
265       markOverdefined(value);
266   }
267   template <typename ValuesT>
268   void markAllOverdefined(Operation *op, ValuesT values) {
269     markAllOverdefined(values);
270     opWorklist.push_back(op);
271   }
272   template <typename ValuesT>
273   void markAllOverdefinedAndVisitUsers(ValuesT values) {
274     for (auto value : values) {
275       auto &lattice = latticeValues[value];
276       if (!lattice.isOverdefined()) {
277         lattice.markOverdefined();
278         visitUsers(value);
279       }
280     }
281   }
282 
283   /// Returns true if the given value was marked as overdefined.
284   bool isOverdefined(Value value) const;
285 
286   /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
287   /// corresponds to the parent operation of 'to'.
288   void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);
289 
290   /// The lattice for each SSA value.
291   DenseMap<Value, LatticeValue> latticeValues;
292 
293   /// The set of blocks that are known to execute, or are intrinsically live.
294   SmallPtrSet<Block *, 16> executableBlocks;
295 
296   /// The set of control flow edges that are known to execute.
297   DenseSet<std::pair<Block *, Block *>> executableEdges;
298 
299   /// A worklist containing blocks that need to be processed.
300   SmallVector<Block *, 64> blockWorklist;
301 
302   /// A worklist of operations that need to be processed.
303   SmallVector<Operation *, 64> opWorklist;
304 
305   /// The callable operations that have their argument/result state tracked.
306   DenseMap<Operation *, CallableLatticeState> callableLatticeState;
307 
308   /// A map between a call operation and the resolved symbol callable. This
309   /// avoids re-resolving symbol references during propagation. Value based
310   /// callables are trivial to resolve, so they can be done in-place.
311   DenseMap<Operation *, Operation *> callToSymbolCallable;
312 
313   /// A symbol table used for O(1) symbol lookups during simplification.
314   SymbolTableCollection symbolTable;
315 };
316 } // end anonymous namespace
317 
318 SCCPSolver::SCCPSolver(Operation *op) {
319   /// Initialize the solver with the regions within this operation.
320   for (Region &region : op->getRegions()) {
321     // Mark the entry block as executable. The values passed to these regions
322     // are also invisible, so mark any arguments as overdefined.
323     markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
324   }
325   initializeSymbolCallables(op);
326 }
327 
328 void SCCPSolver::solve() {
329   while (!blockWorklist.empty() || !opWorklist.empty()) {
330     // Process any operations in the op worklist.
331     while (!opWorklist.empty())
332       visitUsers(*opWorklist.pop_back_val());
333 
334     // Process any blocks in the block worklist.
335     while (!blockWorklist.empty())
336       visitBlock(blockWorklist.pop_back_val());
337   }
338 }
339 
340 void SCCPSolver::rewrite(MLIRContext *context,
341                          MutableArrayRef<Region> initialRegions) {
342   SmallVector<Block *, 8> worklist;
343   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
344     for (Region &region : regions)
345       for (Block &block : region)
346         if (isBlockExecutable(&block))
347           worklist.push_back(&block);
348   };
349 
350   // An operation folder used to create and unique constants.
351   OperationFolder folder(context);
352   OpBuilder builder(context);
353 
354   addToWorklist(initialRegions);
355   while (!worklist.empty()) {
356     Block *block = worklist.pop_back_val();
357 
358     // Replace any block arguments with constants.
359     builder.setInsertionPointToStart(block);
360     for (BlockArgument arg : block->getArguments())
361       (void)replaceWithConstant(builder, folder, arg);
362 
363     for (Operation &op : llvm::make_early_inc_range(*block)) {
364       builder.setInsertionPoint(&op);
365 
366       // Replace any result with constants.
367       bool replacedAll = op.getNumResults() != 0;
368       for (Value res : op.getResults())
369         replacedAll &= succeeded(replaceWithConstant(builder, folder, res));
370 
371       // If all of the results of the operation were replaced, try to erase
372       // the operation completely.
373       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
374         assert(op.use_empty() && "expected all uses to be replaced");
375         op.erase();
376         continue;
377       }
378 
379       // Add any the regions of this operation to the worklist.
380       addToWorklist(op.getRegions());
381     }
382   }
383 }
384 
385 void SCCPSolver::initializeSymbolCallables(Operation *op) {
386   // Initialize the set of symbol callables that can have their state tracked.
387   // This tracks which symbol callable operations we can propagate within and
388   // out of.
389   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
390     Region &symbolTableRegion = symTable->getRegion(0);
391     Block *symbolTableBlock = &symbolTableRegion.front();
392     for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
393       // We won't be able to track external callables.
394       Region *callableRegion = callable.getCallableRegion();
395       if (!callableRegion)
396         continue;
397       // We only care about symbol defining callables here.
398       auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
399       if (!symbol)
400         continue;
401       callableLatticeState.try_emplace(callable, callableRegion,
402                                        callable.getCallableResults().size());
403 
404       // If not all of the uses of this symbol are visible, we can't track the
405       // state of the arguments.
406       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
407         for (Region &region : callable->getRegions())
408           markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
409       }
410     }
411     if (callableLatticeState.empty())
412       return;
413 
414     // After computing the valid callables, walk any symbol uses to check
415     // for non-call references. We won't be able to track the lattice state
416     // for arguments to these callables, as we can't guarantee that we can see
417     // all of its calls.
418     Optional<SymbolTable::UseRange> uses =
419         SymbolTable::getSymbolUses(&symbolTableRegion);
420     if (!uses) {
421       // If we couldn't gather the symbol uses, conservatively assume that
422       // we can't track information for any nested symbols.
423       op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
424       return;
425     }
426 
427     for (const SymbolTable::SymbolUse &use : *uses) {
428       // If the use is a call, track it to avoid the need to recompute the
429       // reference later.
430       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
431         Operation *symCallable = callOp.resolveCallable(&symbolTable);
432         auto callableLatticeIt = callableLatticeState.find(symCallable);
433         if (callableLatticeIt != callableLatticeState.end()) {
434           callToSymbolCallable.try_emplace(callOp, symCallable);
435 
436           // We only need to record the call in the lattice if it produces any
437           // values.
438           if (callOp->getNumResults())
439             callableLatticeIt->second.addSymbolCall(callOp);
440         }
441         continue;
442       }
443       // This use isn't a call, so don't we know all of the callers.
444       auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
445       auto it = callableLatticeState.find(symbol);
446       if (it != callableLatticeState.end()) {
447         for (Region &region : it->first->getRegions())
448           markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
449       }
450     }
451   };
452   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
453                                 walkFn);
454 }
455 
456 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
457                                               OperationFolder &folder,
458                                               Value value) {
459   auto it = latticeValues.find(value);
460   auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant();
461   if (!attr)
462     return failure();
463 
464   // Attempt to materialize a constant for the given value.
465   Dialect *dialect = it->second.getConstantDialect();
466   Value constant = folder.getOrCreateConstant(builder, dialect, attr,
467                                               value.getType(), value.getLoc());
468   if (!constant)
469     return failure();
470 
471   value.replaceAllUsesWith(constant);
472   latticeValues.erase(it);
473   return success();
474 }
475 
476 void SCCPSolver::visitOperation(Operation *op) {
477   // Collect all of the constant operands feeding into this operation. If any
478   // are not ready to be resolved, bail out and wait for them to resolve.
479   SmallVector<Attribute, 8> operandConstants;
480   operandConstants.reserve(op->getNumOperands());
481   for (Value operand : op->getOperands()) {
482     // Make sure all of the operands are resolved first.
483     auto &operandLattice = latticeValues[operand];
484     if (operandLattice.isUnknown())
485       return;
486     operandConstants.push_back(operandLattice.getConstant());
487   }
488 
489   // If this is a terminator operation, process any control flow lattice state.
490   if (op->hasTrait<OpTrait::IsTerminator>())
491     visitTerminatorOperation(op, operandConstants);
492 
493   // Process call operations. The call visitor processes result values, so we
494   // can exit afterwards.
495   if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
496     return visitCallOperation(call);
497 
498   // Process callable operations. These are specially handled region operations
499   // that track dataflow via calls.
500   if (isa<CallableOpInterface>(op)) {
501     // If this callable has a tracked lattice state, it will be visited by calls
502     // that reference it instead. This way, we don't assume that it is
503     // executable unless there is a proper reference to it.
504     if (callableLatticeState.count(op))
505       return;
506     return visitCallableOperation(op);
507   }
508 
509   // Process region holding operations. The region visitor processes result
510   // values, so we can exit afterwards.
511   if (op->getNumRegions())
512     return visitRegionOperation(op, operandConstants);
513 
514   // If this op produces no results, it can't produce any constants.
515   if (op->getNumResults() == 0)
516     return;
517 
518   // If all of the results of this operation are already overdefined, bail out
519   // early.
520   auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
521   if (llvm::all_of(op->getResults(), isOverdefinedFn))
522     return;
523 
524   // Save the original operands and attributes just in case the operation folds
525   // in-place. The constant passed in may not correspond to the real runtime
526   // value, so in-place updates are not allowed.
527   SmallVector<Value, 8> originalOperands(op->getOperands());
528   DictionaryAttr originalAttrs = op->getAttrDictionary();
529 
530   // Simulate the result of folding this operation to a constant. If folding
531   // fails or was an in-place fold, mark the results as overdefined.
532   SmallVector<OpFoldResult, 8> foldResults;
533   foldResults.reserve(op->getNumResults());
534   if (failed(op->fold(operandConstants, foldResults)))
535     return markAllOverdefined(op, op->getResults());
536 
537   // If the folding was in-place, mark the results as overdefined and reset the
538   // operation. We don't allow in-place folds as the desire here is for
539   // simulated execution, and not general folding.
540   if (foldResults.empty()) {
541     op->setOperands(originalOperands);
542     op->setAttrs(originalAttrs);
543     return markAllOverdefined(op, op->getResults());
544   }
545 
546   // Merge the fold results into the lattice for this operation.
547   assert(foldResults.size() == op->getNumResults() && "invalid result size");
548   Dialect *opDialect = op->getDialect();
549   for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
550     LatticeValue &resultLattice = latticeValues[op->getResult(i)];
551 
552     // Merge in the result of the fold, either a constant or a value.
553     OpFoldResult foldResult = foldResults[i];
554     if (Attribute foldAttr = foldResult.dyn_cast<Attribute>())
555       meet(op, resultLattice, LatticeValue(foldAttr, opDialect));
556     else
557       meet(op, resultLattice, latticeValues[foldResult.get<Value>()]);
558   }
559 }
560 
561 void SCCPSolver::visitCallableOperation(Operation *op) {
562   // Mark the regions as executable. If we aren't tracking lattice state for
563   // this callable, mark all of the region arguments as overdefined.
564   bool isTrackingLatticeState = callableLatticeState.count(op);
565   for (Region &region : op->getRegions())
566     markEntryBlockExecutable(&region, !isTrackingLatticeState);
567 
568   // TODO: Add support for non-symbol callables when necessary. If the callable
569   // has non-call uses we would mark overdefined, otherwise allow for
570   // propagating the return values out.
571   markAllOverdefined(op, op->getResults());
572 }
573 
574 void SCCPSolver::visitCallOperation(CallOpInterface op) {
575   ResultRange callResults = op->getResults();
576 
577   // Resolve the callable operation for this call.
578   Operation *callableOp = nullptr;
579   if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
580     callableOp = callableValue.getDefiningOp();
581   else
582     callableOp = callToSymbolCallable.lookup(op);
583 
584   // The callable of this call can't be resolved, mark any results overdefined.
585   if (!callableOp)
586     return markAllOverdefined(op, callResults);
587 
588   // If this callable is tracking state, merge the argument operands with the
589   // arguments of the callable.
590   auto callableLatticeIt = callableLatticeState.find(callableOp);
591   if (callableLatticeIt == callableLatticeState.end())
592     return markAllOverdefined(op, callResults);
593 
594   OperandRange callOperands = op.getArgOperands();
595   auto callableArgs = callableLatticeIt->second.getCallableArguments();
596   for (auto it : llvm::zip(callOperands, callableArgs)) {
597     BlockArgument callableArg = std::get<1>(it);
598     if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
599       visitUsers(callableArg);
600   }
601 
602   // Visit the callable.
603   visitCallableOperation(callableOp);
604 
605   // Merge in the lattice state for the callable results as well.
606   auto callableResults = callableLatticeIt->second.getResultLatticeValues();
607   for (auto it : llvm::zip(callResults, callableResults))
608     meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
609          /*from=*/std::get<1>(it));
610 }
611 
612 void SCCPSolver::visitRegionOperation(Operation *op,
613                                       ArrayRef<Attribute> constantOperands) {
614   // Check to see if we can reason about the internal control flow of this
615   // region operation.
616   auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
617   if (!regionInterface) {
618     // If we can't, conservatively mark all regions as executable.
619     for (Region &region : op->getRegions())
620       markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
621 
622     // Don't try to simulate the results of a region operation as we can't
623     // guarantee that folding will be out-of-place. We don't allow in-place
624     // folds as the desire here is for simulated execution, and not general
625     // folding.
626     return markAllOverdefined(op, op->getResults());
627   }
628 
629   // Check to see which regions are executable.
630   SmallVector<RegionSuccessor, 1> successors;
631   regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
632                                       successors);
633 
634   // If the interface identified that no region will be executed. Mark
635   // any results of this operation as overdefined, as we can't reason about
636   // them.
637   // TODO: If we had an interface to detect pass through operands, we could
638   // resolve some results based on the lattice state of the operands. We could
639   // also allow for the parent operation to have itself as a region successor.
640   if (successors.empty())
641     return markAllOverdefined(op, op->getResults());
642   return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
643     assert(index && "expected valid region index");
644     return regionInterface.getSuccessorEntryOperands(*index);
645   });
646 }
647 
648 void SCCPSolver::visitRegionSuccessors(
649     Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
650     function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
651   for (const RegionSuccessor &it : regionSuccessors) {
652     Region *region = it.getSuccessor();
653     ValueRange succArgs = it.getSuccessorInputs();
654 
655     // Check to see if this is the parent operation.
656     if (!region) {
657       ResultRange results = parentOp->getResults();
658       if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
659         continue;
660 
661       // Mark the results outside of the input range as overdefined.
662       if (succArgs.size() != results.size()) {
663         opWorklist.push_back(parentOp);
664         if (succArgs.empty())
665           return markAllOverdefined(results);
666 
667         unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
668         markAllOverdefined(results.take_front(firstResIdx));
669         markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
670       }
671 
672       // Update the lattice for any operation results.
673       OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
674       for (auto it : llvm::zip(succArgs, operands))
675         meet(parentOp, latticeValues[std::get<0>(it)],
676              latticeValues[std::get<1>(it)]);
677       return;
678     }
679     assert(!region->empty() && "expected region to be non-empty");
680     Block *entryBlock = &region->front();
681     markBlockExecutable(entryBlock);
682 
683     // If all of the arguments are already overdefined, the arguments have
684     // already been fully resolved.
685     auto arguments = entryBlock->getArguments();
686     if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
687       continue;
688 
689     // Mark any arguments that do not receive inputs as overdefined, we won't be
690     // able to discern if they are constant.
691     if (succArgs.size() != arguments.size()) {
692       if (succArgs.empty()) {
693         markAllOverdefined(arguments);
694         continue;
695       }
696 
697       unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
698       markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
699       markAllOverdefinedAndVisitUsers(
700           arguments.drop_front(firstArgIdx + succArgs.size()));
701     }
702 
703     // Update the lattice for arguments that have inputs from the predecessor.
704     OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
705     for (auto it : llvm::zip(succArgs, succOperands)) {
706       LatticeValue &argLattice = latticeValues[std::get<0>(it)];
707       if (argLattice.meet(latticeValues[std::get<1>(it)]))
708         visitUsers(std::get<0>(it));
709     }
710   }
711 }
712 
713 void SCCPSolver::visitTerminatorOperation(
714     Operation *op, ArrayRef<Attribute> constantOperands) {
715   // If this operation has no successors, we treat it as an exiting terminator.
716   if (op->getNumSuccessors() == 0) {
717     Region *parentRegion = op->getParentRegion();
718     Operation *parentOp = parentRegion->getParentOp();
719 
720     // Check to see if this is a terminator for a callable region.
721     if (isa<CallableOpInterface>(parentOp))
722       return visitCallableTerminatorOperation(parentOp, op);
723 
724     // Otherwise, check to see if the parent tracks region control flow.
725     auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
726     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
727       return;
728 
729     // Query the set of successors from the current region.
730     SmallVector<RegionSuccessor, 1> regionSuccessors;
731     regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(),
732                                         constantOperands, regionSuccessors);
733     if (regionSuccessors.empty())
734       return;
735 
736     // If this terminator is not "region-like", conservatively mark all of the
737     // successor values as overdefined.
738     if (!op->hasTrait<OpTrait::ReturnLike>()) {
739       for (auto &it : regionSuccessors)
740         markAllOverdefinedAndVisitUsers(it.getSuccessorInputs());
741       return;
742     }
743 
744     // Otherwise, propagate the operand lattice states to each of the
745     // successors.
746     OperandRange operands = op->getOperands();
747     return visitRegionSuccessors(parentOp, regionSuccessors,
748                                  [&](Optional<unsigned>) { return operands; });
749   }
750 
751   // Try to resolve to a specific successor with the constant operands.
752   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
753     if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
754       markEdgeExecutable(op->getBlock(), singleSucc);
755       return;
756     }
757   }
758 
759   // Otherwise, conservatively treat all edges as executable.
760   Block *block = op->getBlock();
761   for (Block *succ : op->getSuccessors())
762     markEdgeExecutable(block, succ);
763 }
764 
765 void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
766                                                   Operation *terminator) {
767   // If there are no exiting values, we have nothing to track.
768   if (terminator->getNumOperands() == 0)
769     return;
770 
771   // If this callable isn't tracking any lattice state there is nothing to do.
772   auto latticeIt = callableLatticeState.find(callable);
773   if (latticeIt == callableLatticeState.end())
774     return;
775   assert(callable->getNumResults() == 0 && "expected symbol callable");
776 
777   // If this terminator is not "return-like", conservatively mark all of the
778   // call-site results as overdefined.
779   auto callableResultLattices = latticeIt->second.getResultLatticeValues();
780   if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
781     for (auto &it : callableResultLattices)
782       it.markOverdefined();
783     for (Operation *call : latticeIt->second.getSymbolCalls())
784       markAllOverdefined(call, call->getResults());
785     return;
786   }
787 
788   // Merge the terminator operands into the results.
789   bool anyChanged = false;
790   for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
791     anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
792   if (!anyChanged)
793     return;
794 
795   // If any of the result lattices changed, update the callers.
796   for (Operation *call : latticeIt->second.getSymbolCalls())
797     for (auto it : llvm::zip(call->getResults(), callableResultLattices))
798       meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
799 }
800 
801 void SCCPSolver::visitBlock(Block *block) {
802   // If the block is not the entry block we need to compute the lattice state
803   // for the block arguments. Entry block argument lattices are computed
804   // elsewhere, such as when visiting the parent operation.
805   if (!block->isEntryBlock()) {
806     for (int i : llvm::seq<int>(0, block->getNumArguments()))
807       visitBlockArgument(block, i);
808   }
809 
810   // Visit all of the operations within the block.
811   for (Operation &op : *block)
812     visitOperation(&op);
813 }
814 
815 void SCCPSolver::visitBlockArgument(Block *block, int i) {
816   BlockArgument arg = block->getArgument(i);
817   LatticeValue &argLattice = latticeValues[arg];
818   if (argLattice.isOverdefined())
819     return;
820 
821   bool updatedLattice = false;
822   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
823     Block *pred = *it;
824 
825     // We only care about this predecessor if it is going to execute.
826     if (!isEdgeExecutable(pred, block))
827       continue;
828 
829     // Try to get the operand forwarded by the predecessor. If we can't reason
830     // about the terminator of the predecessor, mark overdefined.
831     Optional<OperandRange> branchOperands;
832     if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
833       branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
834     if (!branchOperands) {
835       updatedLattice = true;
836       argLattice.markOverdefined();
837       break;
838     }
839 
840     // If the operand hasn't been resolved, it is unknown which can merge with
841     // anything.
842     auto operandLattice = latticeValues.find((*branchOperands)[i]);
843     if (operandLattice == latticeValues.end())
844       continue;
845 
846     // Otherwise, meet the two lattice values.
847     updatedLattice |= argLattice.meet(operandLattice->second);
848     if (argLattice.isOverdefined())
849       break;
850   }
851 
852   // If the lattice was updated, visit any executable users of the argument.
853   if (updatedLattice)
854     visitUsers(arg);
855 }
856 
857 bool SCCPSolver::markEntryBlockExecutable(Region *region,
858                                           bool markArgsOverdefined) {
859   if (!region->empty()) {
860     if (markArgsOverdefined)
861       markAllOverdefined(region->front().getArguments());
862     return markBlockExecutable(&region->front());
863   }
864   return false;
865 }
866 
867 bool SCCPSolver::markBlockExecutable(Block *block) {
868   bool marked = executableBlocks.insert(block).second;
869   if (marked)
870     blockWorklist.push_back(block);
871   return marked;
872 }
873 
874 bool SCCPSolver::isBlockExecutable(Block *block) const {
875   return executableBlocks.count(block);
876 }
877 
878 void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
879   if (!executableEdges.insert(std::make_pair(from, to)).second)
880     return;
881   // Mark the destination as executable, and reprocess its arguments if it was
882   // already executable.
883   if (!markBlockExecutable(to)) {
884     for (int i : llvm::seq<int>(0, to->getNumArguments()))
885       visitBlockArgument(to, i);
886   }
887 }
888 
889 bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
890   return executableEdges.count(std::make_pair(from, to));
891 }
892 
893 void SCCPSolver::markOverdefined(Value value) {
894   latticeValues[value].markOverdefined();
895 }
896 
897 bool SCCPSolver::isOverdefined(Value value) const {
898   auto it = latticeValues.find(value);
899   return it != latticeValues.end() && it->second.isOverdefined();
900 }
901 
902 void SCCPSolver::meet(Operation *owner, LatticeValue &to,
903                       const LatticeValue &from) {
904   if (to.meet(from))
905     opWorklist.push_back(owner);
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // SCCP Pass
910 //===----------------------------------------------------------------------===//
911 
912 namespace {
913 struct SCCP : public SCCPBase<SCCP> {
914   void runOnOperation() override;
915 };
916 } // end anonymous namespace
917 
918 void SCCP::runOnOperation() {
919   Operation *op = getOperation();
920 
921   // Solve for SCCP constraints within nested regions.
922   SCCPSolver solver(op);
923   solver.solve();
924 
925   // Cleanup any operations using the solver analysis.
926   solver.rewrite(&getContext(), op->getRegions());
927 }
928 
929 std::unique_ptr<Pass> mlir::createSCCPPass() {
930   return std::make_unique<SCCP>();
931 }
932