1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "mlir/Analysis/DataFlowAnalysis.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/Interfaces/ControlFlowInterfaces.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/Passes.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "sccp"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // SCCP Analysis
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 struct SCCPLatticeValue {
38   SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
39       : constant(constant), constantDialect(dialect) {}
40 
41   /// The pessimistic state of SCCP is non-constant.
42   static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
43     return SCCPLatticeValue();
44   }
45   static SCCPLatticeValue getPessimisticValueState(Value value) {
46     return SCCPLatticeValue();
47   }
48 
49   /// Equivalence for SCCP only accounts for the constant, not the originating
50   /// dialect.
51   bool operator==(const SCCPLatticeValue &rhs) const {
52     return constant == rhs.constant;
53   }
54 
55   /// To join the state of two values, we simply check for equivalence.
56   static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
57                                const SCCPLatticeValue &rhs) {
58     return lhs == rhs ? lhs : SCCPLatticeValue();
59   }
60 
61   /// The constant attribute value.
62   Attribute constant;
63 
64   /// The dialect the constant originated from. This is not used as part of the
65   /// key, and is only needed to materialize the held constant if necessary.
66   Dialect *constantDialect;
67 };
68 
69 struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
70   using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
71   ~SCCPAnalysis() override = default;
72 
73   ChangeResult
74   visitOperation(Operation *op,
75                  ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
76 
77     LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n");
78 
79     // Don't try to simulate the results of a region operation as we can't
80     // guarantee that folding will be out-of-place. We don't allow in-place
81     // folds as the desire here is for simulated execution, and not general
82     // folding.
83     if (op->getNumRegions())
84       return markAllPessimisticFixpoint(op->getResults());
85 
86     SmallVector<Attribute> constantOperands(
87         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
88           return value->getValue().constant;
89         }));
90 
91     // Save the original operands and attributes just in case the operation
92     // folds in-place. The constant passed in may not correspond to the real
93     // runtime value, so in-place updates are not allowed.
94     SmallVector<Value, 8> originalOperands(op->getOperands());
95     DictionaryAttr originalAttrs = op->getAttrDictionary();
96 
97     // Simulate the result of folding this operation to a constant. If folding
98     // fails or was an in-place fold, mark the results as overdefined.
99     SmallVector<OpFoldResult, 8> foldResults;
100     foldResults.reserve(op->getNumResults());
101     if (failed(op->fold(constantOperands, foldResults)))
102       return markAllPessimisticFixpoint(op->getResults());
103 
104     // If the folding was in-place, mark the results as overdefined and reset
105     // the operation. We don't allow in-place folds as the desire here is for
106     // simulated execution, and not general folding.
107     if (foldResults.empty()) {
108       op->setOperands(originalOperands);
109       op->setAttrs(originalAttrs);
110       return markAllPessimisticFixpoint(op->getResults());
111     }
112 
113     // Merge the fold results into the lattice for this operation.
114     assert(foldResults.size() == op->getNumResults() && "invalid result size");
115     Dialect *dialect = op->getDialect();
116     ChangeResult result = ChangeResult::NoChange;
117     for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
118       LatticeElement<SCCPLatticeValue> &lattice =
119           getLatticeElement(op->getResult(i));
120 
121       // Merge in the result of the fold, either a constant or a value.
122       OpFoldResult foldResult = foldResults[i];
123       if (Attribute attr = foldResult.dyn_cast<Attribute>())
124         result |= lattice.join(SCCPLatticeValue(attr, dialect));
125       else
126         result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
127     }
128     return result;
129   }
130 
131   /// Implementation of `getSuccessorsForOperands` that uses constant operands
132   /// to potentially remove dead successors.
133   LogicalResult getSuccessorsForOperands(
134       BranchOpInterface branch,
135       ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
136       SmallVectorImpl<Block *> &successors) final {
137     SmallVector<Attribute> constantOperands(
138         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
139           return value->getValue().constant;
140         }));
141     if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
142       successors.push_back(singleSucc);
143       return success();
144     }
145     return failure();
146   }
147 
148   /// Implementation of `getSuccessorsForOperands` that uses constant operands
149   /// to potentially remove dead region successors.
150   void getSuccessorsForOperands(
151       RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
152       ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
153       SmallVectorImpl<RegionSuccessor> &successors) final {
154     SmallVector<Attribute> constantOperands(
155         llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
156           return value->getValue().constant;
157         }));
158     branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
159   }
160 };
161 } // namespace
162 
163 //===----------------------------------------------------------------------===//
164 // SCCP Rewrites
165 //===----------------------------------------------------------------------===//
166 
167 /// Replace the given value with a constant if the corresponding lattice
168 /// represents a constant. Returns success if the value was replaced, failure
169 /// otherwise.
170 static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
171                                          OpBuilder &builder,
172                                          OperationFolder &folder, Value value) {
173   LatticeElement<SCCPLatticeValue> *lattice =
174       analysis.lookupLatticeElement(value);
175   if (!lattice)
176     return failure();
177   SCCPLatticeValue &latticeValue = lattice->getValue();
178   if (!latticeValue.constant)
179     return failure();
180 
181   // Attempt to materialize a constant for the given value.
182   Dialect *dialect = latticeValue.constantDialect;
183   Value constant = folder.getOrCreateConstant(
184       builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
185   if (!constant)
186     return failure();
187 
188   value.replaceAllUsesWith(constant);
189   return success();
190 }
191 
192 /// Rewrite the given regions using the computing analysis. This replaces the
193 /// uses of all values that have been computed to be constant, and erases as
194 /// many newly dead operations.
195 static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
196                     MutableArrayRef<Region> initialRegions) {
197   SmallVector<Block *> worklist;
198   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
199     for (Region &region : regions)
200       for (Block &block : llvm::reverse(region))
201         worklist.push_back(&block);
202   };
203 
204   // An operation folder used to create and unique constants.
205   OperationFolder folder(context);
206   OpBuilder builder(context);
207 
208   addToWorklist(initialRegions);
209   while (!worklist.empty()) {
210     Block *block = worklist.pop_back_val();
211 
212     for (Operation &op : llvm::make_early_inc_range(*block)) {
213       builder.setInsertionPoint(&op);
214 
215       // Replace any result with constants.
216       bool replacedAll = op.getNumResults() != 0;
217       for (Value res : op.getResults())
218         replacedAll &=
219             succeeded(replaceWithConstant(analysis, builder, folder, res));
220 
221       // If all of the results of the operation were replaced, try to erase
222       // the operation completely.
223       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
224         assert(op.use_empty() && "expected all uses to be replaced");
225         op.erase();
226         continue;
227       }
228 
229       // Add any the regions of this operation to the worklist.
230       addToWorklist(op.getRegions());
231     }
232 
233     // Replace any block arguments with constants.
234     builder.setInsertionPointToStart(block);
235     for (BlockArgument arg : block->getArguments())
236       (void)replaceWithConstant(analysis, builder, folder, arg);
237   }
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // SCCP Pass
242 //===----------------------------------------------------------------------===//
243 
244 namespace {
245 struct SCCP : public SCCPBase<SCCP> {
246   void runOnOperation() override;
247 };
248 } // namespace
249 
250 void SCCP::runOnOperation() {
251   Operation *op = getOperation();
252 
253   SCCPAnalysis analysis(op->getContext());
254   analysis.run(op);
255   rewrite(analysis, op->getContext(), op->getRegions());
256 }
257 
258 std::unique_ptr<Pass> mlir::createSCCPPass() {
259   return std::make_unique<SCCP>();
260 }
261