1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/Interfaces/ControlFlowInterfaces.h"
21 #include "mlir/Interfaces/SideEffects.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/Passes.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// This class represents a single lattice value. A lattive value corresponds to
30 /// the various different states that a value in the SCCP dataflow anaylsis can
31 /// take. See 'Kind' below for more details on the different states a value can
32 /// take.
33 class LatticeValue {
34   enum Kind {
35     /// A value with a yet to be determined value. This state may be changed to
36     /// anything.
37     Unknown,
38 
39     /// A value that is known to be a constant. This state may be changed to
40     /// overdefined.
41     Constant,
42 
43     /// A value that cannot statically be determined to be a constant. This
44     /// state cannot be changed.
45     Overdefined
46   };
47 
48 public:
49   /// Initialize a lattice value with "Unknown".
50   LatticeValue()
51       : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
52   /// Initialize a lattice value with a constant.
53   LatticeValue(Attribute attr, Dialect *dialect)
54       : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
55 
56   /// Returns true if this lattice value is unknown.
57   bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
58 
59   /// Mark the lattice value as overdefined.
60   void markOverdefined() {
61     constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
62     constantDialect = nullptr;
63   }
64 
65   /// Returns true if the lattice is overdefined.
66   bool isOverdefined() const {
67     return constantAndTag.getInt() == Kind::Overdefined;
68   }
69 
70   /// Mark the lattice value as constant.
71   void markConstant(Attribute value, Dialect *dialect) {
72     constantAndTag.setPointerAndInt(value, Kind::Constant);
73     constantDialect = dialect;
74   }
75 
76   /// If this lattice is constant, return the constant. Returns nullptr
77   /// otherwise.
78   Attribute getConstant() const { return constantAndTag.getPointer(); }
79 
80   /// If this lattice is constant, return the dialect to use when materializing
81   /// the constant.
82   Dialect *getConstantDialect() const {
83     assert(getConstant() && "expected valid constant");
84     return constantDialect;
85   }
86 
87   /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
88   /// lattice value changed.
89   bool meet(const LatticeValue &rhs) {
90     // If we are already overdefined, or rhs is unknown, there is nothing to do.
91     if (isOverdefined() || rhs.isUnknown())
92       return false;
93     // If we are unknown, just take the value of rhs.
94     if (isUnknown()) {
95       constantAndTag = rhs.constantAndTag;
96       constantDialect = rhs.constantDialect;
97       return true;
98     }
99 
100     // Otherwise, if this value doesn't match rhs go straight to overdefined.
101     if (constantAndTag != rhs.constantAndTag) {
102       markOverdefined();
103       return true;
104     }
105     return false;
106   }
107 
108 private:
109   /// The attribute value if this is a constant and the tag for the element
110   /// kind.
111   llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
112 
113   /// The dialect the constant originated from. This is only valid if the
114   /// lattice is a constant. This is not used as part of the key, and is only
115   /// needed to materialize the held constant if necessary.
116   Dialect *constantDialect;
117 };
118 
/// This class represents the solver for the SCCP analysis. This class acts as
/// the propagation engine for computing which values form constants.
class SCCPSolver {
public:
  /// Initialize the solver with a given set of regions. The entry block of
  /// each non-empty region is marked executable, and its arguments are marked
  /// overdefined.
  SCCPSolver(MutableArrayRef<Region> regions);

  /// Run the solver until it converges, i.e. until both the block worklist and
  /// the operation worklist are empty.
  void solve();

  /// Rewrite the given regions using the computed analysis. This replaces the
  /// uses of all values that have been computed to be constant, and erases
  /// operations whose results were all replaced and that are trivially dead.
  void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);

private:
  /// Replace the given value with a constant if the corresponding lattice
  /// represents a constant. Returns success if the value was replaced, failure
  /// otherwise.
  LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
                                    Value value);

  /// Visit the given operation and compute any necessary lattice state.
  void visitOperation(Operation *op);

  /// Visit the given operation, which defines regions, and compute any
  /// necessary lattice state. This also resolves the lattice state of both the
  /// operation results and any nested regions.
  void visitRegionOperation(Operation *op);

  /// Visit the given terminator operation and compute any necessary lattice
  /// state. 'constantOperands' holds one attribute per operand, null when the
  /// operand is not a known constant.
  void visitTerminatorOperation(Operation *op,
                                ArrayRef<Attribute> constantOperands);

  /// Visit the given block and compute any necessary lattice state.
  void visitBlock(Block *block);

  /// Visit argument #'i' of the given block and compute any necessary lattice
  /// state by merging the operands forwarded along executable incoming edges.
  void visitBlockArgument(Block *block, int i);

  /// Mark the given block as executable, queueing it on the block worklist if
  /// it was not already executable. Returns false if the block was already
  /// marked executable.
  bool markBlockExecutable(Block *block);

  /// Returns true if the given block is executable.
  bool isBlockExecutable(Block *block) const;

  /// Mark the edge between 'from' and 'to' as executable.
  void markEdgeExecutable(Block *from, Block *to);

  /// Return true if the edge between 'from' and 'to' is executable.
  bool isEdgeExecutable(Block *from, Block *to) const;

  /// Mark the given value as overdefined. This means that we cannot refine a
  /// specific constant for this value.
  void markOverdefined(Value value);

  /// Mark all of the given values as overdefined.
  template <typename ValuesT>
  void markAllOverdefined(ValuesT values) {
    for (auto value : values)
      markOverdefined(value);
  }
  /// Mark all of the given values as overdefined, and push 'op' onto the
  /// operation worklist so that its users are revisited by the solver.
  template <typename ValuesT>
  void markAllOverdefined(Operation *op, ValuesT values) {
    markAllOverdefined(values);
    opWorklist.push_back(op);
  }

  /// Returns true if the given value was marked as overdefined. Values without
  /// a lattice entry are not considered overdefined.
  bool isOverdefined(Value value) const;

  /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' is the
  /// operation that defines 'to'; it is pushed onto the operation worklist if
  /// the merge changed 'to'.
  void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);

  /// The lattice for each SSA value. A value without an entry has not been
  /// resolved yet and is treated as unknown.
  DenseMap<Value, LatticeValue> latticeValues;

  /// The set of blocks that are known to execute, or are intrinsically live.
  SmallPtrSet<Block *, 16> executableBlocks;

  /// The set of control flow edges (from, to) that are known to execute.
  DenseSet<std::pair<Block *, Block *>> executableEdges;

  /// A worklist containing blocks that need to be processed.
  SmallVector<Block *, 64> blockWorklist;

  /// A worklist of operations whose lattice state changed; their users need to
  /// be revisited.
  SmallVector<Operation *, 64> opWorklist;
};
212 } // end anonymous namespace
213 
214 SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
215   for (Region &region : regions) {
216     if (region.empty())
217       continue;
218     Block *entryBlock = &region.front();
219 
220     // Mark the entry block as executable.
221     markBlockExecutable(entryBlock);
222 
223     // The values passed to these regions are invisible, so mark any arguments
224     // as overdefined.
225     markAllOverdefined(entryBlock->getArguments());
226   }
227 }
228 
229 void SCCPSolver::solve() {
230   while (!blockWorklist.empty() || !opWorklist.empty()) {
231     // Process any operations in the op worklist.
232     while (!opWorklist.empty()) {
233       Operation *op = opWorklist.pop_back_val();
234 
235       // Visit all of the live users to propagate changes to this operation.
236       for (Operation *user : op->getUsers()) {
237         if (isBlockExecutable(user->getBlock()))
238           visitOperation(user);
239       }
240     }
241 
242     // Process any blocks in the block worklist.
243     while (!blockWorklist.empty())
244       visitBlock(blockWorklist.pop_back_val());
245   }
246 }
247 
248 void SCCPSolver::rewrite(MLIRContext *context,
249                          MutableArrayRef<Region> initialRegions) {
250   SmallVector<Block *, 8> worklist;
251   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
252     for (Region &region : regions)
253       for (Block &block : region)
254         if (isBlockExecutable(&block))
255           worklist.push_back(&block);
256   };
257 
258   // An operation folder used to create and unique constants.
259   OperationFolder folder(context);
260   OpBuilder builder(context);
261 
262   addToWorklist(initialRegions);
263   while (!worklist.empty()) {
264     Block *block = worklist.pop_back_val();
265 
266     // Replace any block arguments with constants.
267     builder.setInsertionPointToStart(block);
268     for (BlockArgument arg : block->getArguments())
269       replaceWithConstant(builder, folder, arg);
270 
271     for (Operation &op : llvm::make_early_inc_range(*block)) {
272       builder.setInsertionPoint(&op);
273 
274       // Replace any result with constants.
275       bool replacedAll = op.getNumResults() != 0;
276       for (Value res : op.getResults())
277         replacedAll &= succeeded(replaceWithConstant(builder, folder, res));
278 
279       // If all of the results of the operation were replaced, try to erase
280       // the operation completely.
281       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
282         assert(op.use_empty() && "expected all uses to be replaced");
283         op.erase();
284         continue;
285       }
286 
287       // Add any the regions of this operation to the worklist.
288       addToWorklist(op.getRegions());
289     }
290   }
291 }
292 
293 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
294                                               OperationFolder &folder,
295                                               Value value) {
296   auto it = latticeValues.find(value);
297   auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant();
298   if (!attr)
299     return failure();
300 
301   // Attempt to materialize a constant for the given value.
302   Dialect *dialect = it->second.getConstantDialect();
303   Value constant = folder.getOrCreateConstant(builder, dialect, attr,
304                                               value.getType(), value.getLoc());
305   if (!constant)
306     return failure();
307 
308   value.replaceAllUsesWith(constant);
309   latticeValues.erase(it);
310   return success();
311 }
312 
313 void SCCPSolver::visitOperation(Operation *op) {
314   // Collect all of the constant operands feeding into this operation. If any
315   // are not ready to be resolved, bail out and wait for them to resolve.
316   SmallVector<Attribute, 8> operandConstants;
317   operandConstants.reserve(op->getNumOperands());
318   for (Value operand : op->getOperands()) {
319     // Make sure all of the operands are resolved first.
320     auto &operandLattice = latticeValues[operand];
321     if (operandLattice.isUnknown())
322       return;
323     operandConstants.push_back(operandLattice.getConstant());
324   }
325 
326   // If this is a terminator operation, process any control flow lattice state.
327   if (op->isKnownTerminator())
328     visitTerminatorOperation(op, operandConstants);
329 
330   // Process region holding operations. The region visitor processes result
331   // values, so we can exit afterwards.
332   if (op->getNumRegions())
333     return visitRegionOperation(op);
334 
335   // If this op produces no results, it can't produce any constants.
336   if (op->getNumResults() == 0)
337     return;
338 
339   // If all of the results of this operation are already overdefined, bail out
340   // early.
341   auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
342   if (llvm::all_of(op->getResults(), isOverdefinedFn))
343     return;
344 
345   // Save the original operands and attributes just in case the operation folds
346   // in-place. The constant passed in may not correspond to the real runtime
347   // value, so in-place updates are not allowed.
348   SmallVector<Value, 8> originalOperands(op->getOperands());
349   NamedAttributeList originalAttrs = op->getAttrList();
350 
351   // Simulate the result of folding this operation to a constant. If folding
352   // fails or was an in-place fold, mark the results as overdefined.
353   SmallVector<OpFoldResult, 8> foldResults;
354   foldResults.reserve(op->getNumResults());
355   if (failed(op->fold(operandConstants, foldResults)))
356     return markAllOverdefined(op, op->getResults());
357 
358   // If the folding was in-place, mark the results as overdefined and reset the
359   // operation. We don't allow in-place folds as the desire here is for
360   // simulated execution, and not general folding.
361   if (foldResults.empty()) {
362     op->setOperands(originalOperands);
363     op->setAttrs(originalAttrs);
364     return markAllOverdefined(op, op->getResults());
365   }
366 
367   // Merge the fold results into the lattice for this operation.
368   assert(foldResults.size() == op->getNumResults() && "invalid result size");
369   Dialect *opDialect = op->getDialect();
370   for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
371     LatticeValue &resultLattice = latticeValues[op->getResult(i)];
372 
373     // Merge in the result of the fold, either a constant or a value.
374     OpFoldResult foldResult = foldResults[i];
375     if (Attribute foldAttr = foldResult.dyn_cast<Attribute>())
376       meet(op, resultLattice, LatticeValue(foldAttr, opDialect));
377     else
378       meet(op, resultLattice, latticeValues[foldResult.get<Value>()]);
379   }
380 }
381 
382 void SCCPSolver::visitRegionOperation(Operation *op) {
383   for (Region &region : op->getRegions()) {
384     if (region.empty())
385       continue;
386     Block *entryBlock = &region.front();
387     markBlockExecutable(entryBlock);
388     markAllOverdefined(entryBlock->getArguments());
389   }
390 
391   // Don't try to simulate the results of a region operation as we can't
392   // guarantee that folding will be out-of-place. We don't allow in-place folds
393   // as the desire here is for simulated execution, and not general folding.
394   return markAllOverdefined(op, op->getResults());
395 }
396 
397 void SCCPSolver::visitTerminatorOperation(
398     Operation *op, ArrayRef<Attribute> constantOperands) {
399   if (op->getNumSuccessors() == 0)
400     return;
401 
402   // Try to resolve to a specific successor with the constant operands.
403   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
404     if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
405       markEdgeExecutable(op->getBlock(), singleSucc);
406       return;
407     }
408   }
409 
410   // Otherwise, conservatively treat all edges as executable.
411   Block *block = op->getBlock();
412   for (Block *succ : op->getSuccessors())
413     markEdgeExecutable(block, succ);
414 }
415 
416 void SCCPSolver::visitBlock(Block *block) {
417   // If the block is not the entry block we need to compute the lattice state
418   // for the block arguments. Entry block argument lattices are computed
419   // elsewhere, such as when visiting the parent operation.
420   if (!block->isEntryBlock()) {
421     for (int i : llvm::seq<int>(0, block->getNumArguments()))
422       visitBlockArgument(block, i);
423   }
424 
425   // Visit all of the operations within the block.
426   for (Operation &op : *block)
427     visitOperation(&op);
428 }
429 
430 void SCCPSolver::visitBlockArgument(Block *block, int i) {
431   BlockArgument arg = block->getArgument(i);
432   LatticeValue &argLattice = latticeValues[arg];
433   if (argLattice.isOverdefined())
434     return;
435 
436   bool updatedLattice = false;
437   for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
438     Block *pred = *it;
439 
440     // We only care about this predecessor if it is going to execute.
441     if (!isEdgeExecutable(pred, block))
442       continue;
443 
444     // Try to get the operand forwarded by the predecessor. If we can't reason
445     // about the terminator of the predecessor, mark overdefined.
446     Optional<OperandRange> branchOperands;
447     if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
448       branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
449     if (!branchOperands) {
450       updatedLattice = true;
451       argLattice.markOverdefined();
452       break;
453     }
454 
455     // If the operand hasn't been resolved, it is unknown which can merge with
456     // anything.
457     auto operandLattice = latticeValues.find((*branchOperands)[i]);
458     if (operandLattice == latticeValues.end())
459       continue;
460 
461     // Otherwise, meet the two lattice values.
462     updatedLattice |= argLattice.meet(operandLattice->second);
463     if (argLattice.isOverdefined())
464       break;
465   }
466 
467   // If the lattice was updated, visit any executable users of the argument.
468   if (updatedLattice) {
469     for (Operation *user : arg.getUsers())
470       if (isBlockExecutable(user->getBlock()))
471         visitOperation(user);
472   }
473 }
474 
475 bool SCCPSolver::markBlockExecutable(Block *block) {
476   bool marked = executableBlocks.insert(block).second;
477   if (marked)
478     blockWorklist.push_back(block);
479   return marked;
480 }
481 
482 bool SCCPSolver::isBlockExecutable(Block *block) const {
483   return executableBlocks.count(block);
484 }
485 
486 void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
487   if (!executableEdges.insert(std::make_pair(from, to)).second)
488     return;
489   // Mark the destination as executable, and reprocess its arguments if it was
490   // already executable.
491   if (!markBlockExecutable(to)) {
492     for (int i : llvm::seq<int>(0, to->getNumArguments()))
493       visitBlockArgument(to, i);
494   }
495 }
496 
497 bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
498   return executableEdges.count(std::make_pair(from, to));
499 }
500 
501 void SCCPSolver::markOverdefined(Value value) {
502   latticeValues[value].markOverdefined();
503 }
504 
505 bool SCCPSolver::isOverdefined(Value value) const {
506   auto it = latticeValues.find(value);
507   return it != latticeValues.end() && it->second.isOverdefined();
508 }
509 
510 void SCCPSolver::meet(Operation *owner, LatticeValue &to,
511                       const LatticeValue &from) {
512   if (to.meet(from))
513     opWorklist.push_back(owner);
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // SCCP Pass
518 //===----------------------------------------------------------------------===//
519 
520 namespace {
521 struct SCCP : public SCCPBase<SCCP> {
522   void runOnOperation() override;
523 };
524 } // end anonymous namespace
525 
526 void SCCP::runOnOperation() {
527   Operation *op = getOperation();
528 
529   // Solve for SCCP constraints within nested regions.
530   SCCPSolver solver(op->getRegions());
531   solver.solve();
532 
533   // Cleanup any operations using the solver analysis.
534   solver.rewrite(&getContext(), op->getRegions());
535 }
536 
537 std::unique_ptr<Pass> mlir::createSCCPPass() {
538   return std::make_unique<SCCP>();
539 }
540