1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass performs a sparse conditional constant propagation 10 // in MLIR. It identifies values known to be constant, propagates that 11 // information throughout the IR, and replaces them. This is done with an 12 // optimistic dataflow analysis that assumes that all values are constant until 13 // proven otherwise. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "mlir/Analysis/DataFlowAnalysis.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/Interfaces/ControlFlowInterfaces.h" 22 #include "mlir/Interfaces/SideEffectInterfaces.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 #include "mlir/Transforms/Passes.h" 26 27 using namespace mlir; 28 29 //===----------------------------------------------------------------------===// 30 // SCCP Analysis 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 struct SCCPLatticeValue { 35 SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr) 36 : constant(constant), constantDialect(dialect) {} 37 38 /// The pessimistic state of SCCP is non-constant. 39 static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) { 40 return SCCPLatticeValue(); 41 } 42 static SCCPLatticeValue getPessimisticValueState(Value value) { 43 return SCCPLatticeValue(); 44 } 45 46 /// Equivalence for SCCP only accounts for the constant, not the originating 47 /// dialect. 48 bool operator==(const SCCPLatticeValue &rhs) const { 49 return constant == rhs.constant; 50 } 51 52 /// To join the state of two values, we simply check for equivalence. 53 static SCCPLatticeValue join(const SCCPLatticeValue &lhs, 54 const SCCPLatticeValue &rhs) { 55 return lhs == rhs ? lhs : SCCPLatticeValue(); 56 } 57 58 /// The constant attribute value. 59 Attribute constant; 60 61 /// The dialect the constant originated from. This is not used as part of the 62 /// key, and is only needed to materialize the held constant if necessary. 63 Dialect *constantDialect; 64 }; 65 66 struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> { 67 using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis; 68 ~SCCPAnalysis() override = default; 69 70 ChangeResult 71 visitOperation(Operation *op, 72 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final { 73 // Don't try to simulate the results of a region operation as we can't 74 // guarantee that folding will be out-of-place. We don't allow in-place 75 // folds as the desire here is for simulated execution, and not general 76 // folding. 77 if (op->getNumRegions()) 78 return markAllPessimisticFixpoint(op->getResults()); 79 80 SmallVector<Attribute> constantOperands( 81 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 82 return value->getValue().constant; 83 })); 84 85 // Save the original operands and attributes just in case the operation 86 // folds in-place. The constant passed in may not correspond to the real 87 // runtime value, so in-place updates are not allowed. 88 SmallVector<Value, 8> originalOperands(op->getOperands()); 89 DictionaryAttr originalAttrs = op->getAttrDictionary(); 90 91 // Simulate the result of folding this operation to a constant. If folding 92 // fails or was an in-place fold, mark the results as overdefined. 93 SmallVector<OpFoldResult, 8> foldResults; 94 foldResults.reserve(op->getNumResults()); 95 if (failed(op->fold(constantOperands, foldResults))) 96 return markAllPessimisticFixpoint(op->getResults()); 97 98 // If the folding was in-place, mark the results as overdefined and reset 99 // the operation. We don't allow in-place folds as the desire here is for 100 // simulated execution, and not general folding. 101 if (foldResults.empty()) { 102 op->setOperands(originalOperands); 103 op->setAttrs(originalAttrs); 104 return markAllPessimisticFixpoint(op->getResults()); 105 } 106 107 // Merge the fold results into the lattice for this operation. 108 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 109 Dialect *dialect = op->getDialect(); 110 ChangeResult result = ChangeResult::NoChange; 111 for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { 112 LatticeElement<SCCPLatticeValue> &lattice = 113 getLatticeElement(op->getResult(i)); 114 115 // Merge in the result of the fold, either a constant or a value. 116 OpFoldResult foldResult = foldResults[i]; 117 if (Attribute attr = foldResult.dyn_cast<Attribute>()) 118 result |= lattice.join(SCCPLatticeValue(attr, dialect)); 119 else 120 result |= lattice.join(getLatticeElement(foldResult.get<Value>())); 121 } 122 return result; 123 } 124 125 /// Implementation of `getSuccessorsForOperands` that uses constant operands 126 /// to potentially remove dead successors. 127 LogicalResult getSuccessorsForOperands( 128 BranchOpInterface branch, 129 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands, 130 SmallVectorImpl<Block *> &successors) final { 131 SmallVector<Attribute> constantOperands( 132 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 133 return value->getValue().constant; 134 })); 135 if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { 136 successors.push_back(singleSucc); 137 return success(); 138 } 139 return failure(); 140 } 141 142 /// Implementation of `getSuccessorsForOperands` that uses constant operands 143 /// to potentially remove dead region successors. 144 void getSuccessorsForOperands( 145 RegionBranchOpInterface branch, Optional<unsigned> sourceIndex, 146 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands, 147 SmallVectorImpl<RegionSuccessor> &successors) final { 148 SmallVector<Attribute> constantOperands( 149 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 150 return value->getValue().constant; 151 })); 152 branch.getSuccessorRegions(sourceIndex, constantOperands, successors); 153 } 154 }; 155 } // namespace 156 157 //===----------------------------------------------------------------------===// 158 // SCCP Rewrites 159 //===----------------------------------------------------------------------===// 160 161 /// Replace the given value with a constant if the corresponding lattice 162 /// represents a constant. Returns success if the value was replaced, failure 163 /// otherwise. 164 static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, 165 OpBuilder &builder, 166 OperationFolder &folder, Value value) { 167 LatticeElement<SCCPLatticeValue> *lattice = 168 analysis.lookupLatticeElement(value); 169 if (!lattice) 170 return failure(); 171 SCCPLatticeValue &latticeValue = lattice->getValue(); 172 if (!latticeValue.constant) 173 return failure(); 174 175 // Attempt to materialize a constant for the given value. 176 Dialect *dialect = latticeValue.constantDialect; 177 Value constant = folder.getOrCreateConstant( 178 builder, dialect, latticeValue.constant, value.getType(), value.getLoc()); 179 if (!constant) 180 return failure(); 181 182 value.replaceAllUsesWith(constant); 183 return success(); 184 } 185 186 /// Rewrite the given regions using the computing analysis. This replaces the 187 /// uses of all values that have been computed to be constant, and erases as 188 /// many newly dead operations. 189 static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, 190 MutableArrayRef<Region> initialRegions) { 191 SmallVector<Block *> worklist; 192 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 193 for (Region ®ion : regions) 194 for (Block &block : llvm::reverse(region)) 195 worklist.push_back(&block); 196 }; 197 198 // An operation folder used to create and unique constants. 199 OperationFolder folder(context); 200 OpBuilder builder(context); 201 202 addToWorklist(initialRegions); 203 while (!worklist.empty()) { 204 Block *block = worklist.pop_back_val(); 205 206 for (Operation &op : llvm::make_early_inc_range(*block)) { 207 builder.setInsertionPoint(&op); 208 209 // Replace any result with constants. 210 bool replacedAll = op.getNumResults() != 0; 211 for (Value res : op.getResults()) 212 replacedAll &= 213 succeeded(replaceWithConstant(analysis, builder, folder, res)); 214 215 // If all of the results of the operation were replaced, try to erase 216 // the operation completely. 217 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 218 assert(op.use_empty() && "expected all uses to be replaced"); 219 op.erase(); 220 continue; 221 } 222 223 // Add any the regions of this operation to the worklist. 224 addToWorklist(op.getRegions()); 225 } 226 227 // Replace any block arguments with constants. 228 builder.setInsertionPointToStart(block); 229 for (BlockArgument arg : block->getArguments()) 230 (void)replaceWithConstant(analysis, builder, folder, arg); 231 } 232 } 233 234 //===----------------------------------------------------------------------===// 235 // SCCP Pass 236 //===----------------------------------------------------------------------===// 237 238 namespace { 239 struct SCCP : public SCCPBase<SCCP> { 240 void runOnOperation() override; 241 }; 242 } // end anonymous namespace 243 244 void SCCP::runOnOperation() { 245 Operation *op = getOperation(); 246 247 SCCPAnalysis analysis(op->getContext()); 248 analysis.run(op); 249 rewrite(analysis, op->getContext(), op->getRegions()); 250 } 251 252 std::unique_ptr<Pass> mlir::createSCCPPass() { 253 return std::make_unique<SCCP>(); 254 } 255