1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "mlir/Analysis/DataFlowAnalysis.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/Interfaces/ControlFlowInterfaces.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // SCCP Analysis
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 struct SCCPLatticeValue {
35   SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
36       : constant(constant), constantDialect(dialect) {}
37 
38   /// The pessimistic state of SCCP is non-constant.
39   static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
40     return SCCPLatticeValue();
41   }
42   static SCCPLatticeValue getPessimisticValueState(Value value) {
43     return SCCPLatticeValue();
44   }
45 
46   /// Equivalence for SCCP only accounts for the constant, not the originating
47   /// dialect.
48   bool operator==(const SCCPLatticeValue &rhs) const {
49     return constant == rhs.constant;
50   }
51 
52   /// To join the state of two values, we simply check for equivalence.
53   static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
54                                const SCCPLatticeValue &rhs) {
55     return lhs == rhs ? lhs : SCCPLatticeValue();
56   }
57 
58   /// The constant attribute value.
59   Attribute constant;
60 
61   /// The dialect the constant originated from. This is not used as part of the
62   /// key, and is only needed to materialize the held constant if necessary.
63   Dialect *constantDialect;
64 };
65 
66 struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
67   using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
68   ~SCCPAnalysis() override = default;
69 
70   ChangeResult
71   visitOperation(Operation *op,
72                  ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
73     // Don't try to simulate the results of a region operation as we can't
74     // guarantee that folding will be out-of-place. We don't allow in-place
75     // folds as the desire here is for simulated execution, and not general
76     // folding.
77     if (op->getNumRegions())
78       return markAllPessimisticFixpoint(op->getResults());
79 
80     SmallVector<Attribute> constantOperands(
81         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
82           return value->getValue().constant;
83         }));
84 
85     // Save the original operands and attributes just in case the operation
86     // folds in-place. The constant passed in may not correspond to the real
87     // runtime value, so in-place updates are not allowed.
88     SmallVector<Value, 8> originalOperands(op->getOperands());
89     DictionaryAttr originalAttrs = op->getAttrDictionary();
90 
91     // Simulate the result of folding this operation to a constant. If folding
92     // fails or was an in-place fold, mark the results as overdefined.
93     SmallVector<OpFoldResult, 8> foldResults;
94     foldResults.reserve(op->getNumResults());
95     if (failed(op->fold(constantOperands, foldResults)))
96       return markAllPessimisticFixpoint(op->getResults());
97 
98     // If the folding was in-place, mark the results as overdefined and reset
99     // the operation. We don't allow in-place folds as the desire here is for
100     // simulated execution, and not general folding.
101     if (foldResults.empty()) {
102       op->setOperands(originalOperands);
103       op->setAttrs(originalAttrs);
104       return markAllPessimisticFixpoint(op->getResults());
105     }
106 
107     // Merge the fold results into the lattice for this operation.
108     assert(foldResults.size() == op->getNumResults() && "invalid result size");
109     Dialect *dialect = op->getDialect();
110     ChangeResult result = ChangeResult::NoChange;
111     for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
112       LatticeElement<SCCPLatticeValue> &lattice =
113           getLatticeElement(op->getResult(i));
114 
115       // Merge in the result of the fold, either a constant or a value.
116       OpFoldResult foldResult = foldResults[i];
117       if (Attribute attr = foldResult.dyn_cast<Attribute>())
118         result |= lattice.join(SCCPLatticeValue(attr, dialect));
119       else
120         result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
121     }
122     return result;
123   }
124 
125   /// Implementation of `getSuccessorsForOperands` that uses constant operands
126   /// to potentially remove dead successors.
127   LogicalResult getSuccessorsForOperands(
128       BranchOpInterface branch,
129       ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
130       SmallVectorImpl<Block *> &successors) final {
131     SmallVector<Attribute> constantOperands(
132         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
133           return value->getValue().constant;
134         }));
135     if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
136       successors.push_back(singleSucc);
137       return success();
138     }
139     return failure();
140   }
141 
142   /// Implementation of `getSuccessorsForOperands` that uses constant operands
143   /// to potentially remove dead region successors.
144   void getSuccessorsForOperands(
145       RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
146       ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
147       SmallVectorImpl<RegionSuccessor> &successors) final {
148     SmallVector<Attribute> constantOperands(
149         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
150           return value->getValue().constant;
151         }));
152     branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
153   }
154 };
155 } // namespace
156 
157 //===----------------------------------------------------------------------===//
158 // SCCP Rewrites
159 //===----------------------------------------------------------------------===//
160 
161 /// Replace the given value with a constant if the corresponding lattice
162 /// represents a constant. Returns success if the value was replaced, failure
163 /// otherwise.
164 static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
165                                          OpBuilder &builder,
166                                          OperationFolder &folder, Value value) {
167   LatticeElement<SCCPLatticeValue> *lattice =
168       analysis.lookupLatticeElement(value);
169   if (!lattice)
170     return failure();
171   SCCPLatticeValue &latticeValue = lattice->getValue();
172   if (!latticeValue.constant)
173     return failure();
174 
175   // Attempt to materialize a constant for the given value.
176   Dialect *dialect = latticeValue.constantDialect;
177   Value constant = folder.getOrCreateConstant(
178       builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
179   if (!constant)
180     return failure();
181 
182   value.replaceAllUsesWith(constant);
183   return success();
184 }
185 
186 /// Rewrite the given regions using the computing analysis. This replaces the
187 /// uses of all values that have been computed to be constant, and erases as
188 /// many newly dead operations.
189 static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
190                     MutableArrayRef<Region> initialRegions) {
191   SmallVector<Block *> worklist;
192   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
193     for (Region &region : regions)
194       for (Block &block : llvm::reverse(region))
195         worklist.push_back(&block);
196   };
197 
198   // An operation folder used to create and unique constants.
199   OperationFolder folder(context);
200   OpBuilder builder(context);
201 
202   addToWorklist(initialRegions);
203   while (!worklist.empty()) {
204     Block *block = worklist.pop_back_val();
205 
206     for (Operation &op : llvm::make_early_inc_range(*block)) {
207       builder.setInsertionPoint(&op);
208 
209       // Replace any result with constants.
210       bool replacedAll = op.getNumResults() != 0;
211       for (Value res : op.getResults())
212         replacedAll &=
213             succeeded(replaceWithConstant(analysis, builder, folder, res));
214 
215       // If all of the results of the operation were replaced, try to erase
216       // the operation completely.
217       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
218         assert(op.use_empty() && "expected all uses to be replaced");
219         op.erase();
220         continue;
221       }
222 
223       // Add any the regions of this operation to the worklist.
224       addToWorklist(op.getRegions());
225     }
226 
227     // Replace any block arguments with constants.
228     builder.setInsertionPointToStart(block);
229     for (BlockArgument arg : block->getArguments())
230       (void)replaceWithConstant(analysis, builder, folder, arg);
231   }
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // SCCP Pass
236 //===----------------------------------------------------------------------===//
237 
238 namespace {
239 struct SCCP : public SCCPBase<SCCP> {
240   void runOnOperation() override;
241 };
242 } // end anonymous namespace
243 
244 void SCCP::runOnOperation() {
245   Operation *op = getOperation();
246 
247   SCCPAnalysis analysis(op->getContext());
248   analysis.run(op);
249   rewrite(analysis, op->getContext(), op->getRegions());
250 }
251 
252 std::unique_ptr<Pass> mlir::createSCCPPass() {
253   return std::make_unique<SCCP>();
254 }
255