1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass performs a sparse conditional constant propagation 10 // in MLIR. It identifies values known to be constant, propagates that 11 // information throughout the IR, and replaces them. This is done with an 12 // optimistic dataflow analysis that assumes that all values are constant until 13 // proven otherwise. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "mlir/Analysis/DataFlowAnalysis.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/Interfaces/ControlFlowInterfaces.h" 22 #include "mlir/Interfaces/SideEffectInterfaces.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 #include "mlir/Transforms/Passes.h" 26 #include "llvm/Support/Debug.h" 27 28 #define DEBUG_TYPE "sccp" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // SCCP Analysis 34 //===----------------------------------------------------------------------===// 35 36 namespace { 37 struct SCCPLatticeValue { 38 SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr) 39 : constant(constant), constantDialect(dialect) {} 40 41 /// The pessimistic state of SCCP is non-constant. 42 static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) { 43 return SCCPLatticeValue(); 44 } 45 static SCCPLatticeValue getPessimisticValueState(Value value) { 46 return SCCPLatticeValue(); 47 } 48 49 /// Equivalence for SCCP only accounts for the constant, not the originating 50 /// dialect. 51 bool operator==(const SCCPLatticeValue &rhs) const { 52 return constant == rhs.constant; 53 } 54 55 /// To join the state of two values, we simply check for equivalence. 56 static SCCPLatticeValue join(const SCCPLatticeValue &lhs, 57 const SCCPLatticeValue &rhs) { 58 return lhs == rhs ? lhs : SCCPLatticeValue(); 59 } 60 61 /// The constant attribute value. 62 Attribute constant; 63 64 /// The dialect the constant originated from. This is not used as part of the 65 /// key, and is only needed to materialize the held constant if necessary. 66 Dialect *constantDialect; 67 }; 68 69 struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> { 70 using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis; 71 ~SCCPAnalysis() override = default; 72 73 ChangeResult 74 visitOperation(Operation *op, 75 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final { 76 77 LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n"); 78 79 // Don't try to simulate the results of a region operation as we can't 80 // guarantee that folding will be out-of-place. We don't allow in-place 81 // folds as the desire here is for simulated execution, and not general 82 // folding. 83 if (op->getNumRegions()) 84 return markAllPessimisticFixpoint(op->getResults()); 85 86 SmallVector<Attribute> constantOperands( 87 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 88 return value->getValue().constant; 89 })); 90 91 // Save the original operands and attributes just in case the operation 92 // folds in-place. The constant passed in may not correspond to the real 93 // runtime value, so in-place updates are not allowed. 94 SmallVector<Value, 8> originalOperands(op->getOperands()); 95 DictionaryAttr originalAttrs = op->getAttrDictionary(); 96 97 // Simulate the result of folding this operation to a constant. If folding 98 // fails or was an in-place fold, mark the results as overdefined. 99 SmallVector<OpFoldResult, 8> foldResults; 100 foldResults.reserve(op->getNumResults()); 101 if (failed(op->fold(constantOperands, foldResults))) 102 return markAllPessimisticFixpoint(op->getResults()); 103 104 // If the folding was in-place, mark the results as overdefined and reset 105 // the operation. We don't allow in-place folds as the desire here is for 106 // simulated execution, and not general folding. 107 if (foldResults.empty()) { 108 op->setOperands(originalOperands); 109 op->setAttrs(originalAttrs); 110 return markAllPessimisticFixpoint(op->getResults()); 111 } 112 113 // Merge the fold results into the lattice for this operation. 114 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 115 Dialect *dialect = op->getDialect(); 116 ChangeResult result = ChangeResult::NoChange; 117 for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { 118 LatticeElement<SCCPLatticeValue> &lattice = 119 getLatticeElement(op->getResult(i)); 120 121 // Merge in the result of the fold, either a constant or a value. 122 OpFoldResult foldResult = foldResults[i]; 123 if (Attribute attr = foldResult.dyn_cast<Attribute>()) 124 result |= lattice.join(SCCPLatticeValue(attr, dialect)); 125 else 126 result |= lattice.join(getLatticeElement(foldResult.get<Value>())); 127 } 128 return result; 129 } 130 131 /// Implementation of `getSuccessorsForOperands` that uses constant operands 132 /// to potentially remove dead successors. 133 LogicalResult getSuccessorsForOperands( 134 BranchOpInterface branch, 135 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands, 136 SmallVectorImpl<Block *> &successors) final { 137 SmallVector<Attribute> constantOperands( 138 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 139 return value->getValue().constant; 140 })); 141 if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { 142 successors.push_back(singleSucc); 143 return success(); 144 } 145 return failure(); 146 } 147 148 /// Implementation of `getSuccessorsForOperands` that uses constant operands 149 /// to potentially remove dead region successors. 150 void getSuccessorsForOperands( 151 RegionBranchOpInterface branch, Optional<unsigned> sourceIndex, 152 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands, 153 SmallVectorImpl<RegionSuccessor> &successors) final { 154 SmallVector<Attribute> constantOperands( 155 llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) { 156 return value->getValue().constant; 157 })); 158 branch.getSuccessorRegions(sourceIndex, constantOperands, successors); 159 } 160 }; 161 } // namespace 162 163 //===----------------------------------------------------------------------===// 164 // SCCP Rewrites 165 //===----------------------------------------------------------------------===// 166 167 /// Replace the given value with a constant if the corresponding lattice 168 /// represents a constant. Returns success if the value was replaced, failure 169 /// otherwise. 170 static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, 171 OpBuilder &builder, 172 OperationFolder &folder, Value value) { 173 LatticeElement<SCCPLatticeValue> *lattice = 174 analysis.lookupLatticeElement(value); 175 if (!lattice) 176 return failure(); 177 SCCPLatticeValue &latticeValue = lattice->getValue(); 178 if (!latticeValue.constant) 179 return failure(); 180 181 // Attempt to materialize a constant for the given value. 182 Dialect *dialect = latticeValue.constantDialect; 183 Value constant = folder.getOrCreateConstant( 184 builder, dialect, latticeValue.constant, value.getType(), value.getLoc()); 185 if (!constant) 186 return failure(); 187 188 value.replaceAllUsesWith(constant); 189 return success(); 190 } 191 192 /// Rewrite the given regions using the computing analysis. This replaces the 193 /// uses of all values that have been computed to be constant, and erases as 194 /// many newly dead operations. 195 static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, 196 MutableArrayRef<Region> initialRegions) { 197 SmallVector<Block *> worklist; 198 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 199 for (Region ®ion : regions) 200 for (Block &block : llvm::reverse(region)) 201 worklist.push_back(&block); 202 }; 203 204 // An operation folder used to create and unique constants. 205 OperationFolder folder(context); 206 OpBuilder builder(context); 207 208 addToWorklist(initialRegions); 209 while (!worklist.empty()) { 210 Block *block = worklist.pop_back_val(); 211 212 for (Operation &op : llvm::make_early_inc_range(*block)) { 213 builder.setInsertionPoint(&op); 214 215 // Replace any result with constants. 216 bool replacedAll = op.getNumResults() != 0; 217 for (Value res : op.getResults()) 218 replacedAll &= 219 succeeded(replaceWithConstant(analysis, builder, folder, res)); 220 221 // If all of the results of the operation were replaced, try to erase 222 // the operation completely. 223 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 224 assert(op.use_empty() && "expected all uses to be replaced"); 225 op.erase(); 226 continue; 227 } 228 229 // Add any the regions of this operation to the worklist. 230 addToWorklist(op.getRegions()); 231 } 232 233 // Replace any block arguments with constants. 234 builder.setInsertionPointToStart(block); 235 for (BlockArgument arg : block->getArguments()) 236 (void)replaceWithConstant(analysis, builder, folder, arg); 237 } 238 } 239 240 //===----------------------------------------------------------------------===// 241 // SCCP Pass 242 //===----------------------------------------------------------------------===// 243 244 namespace { 245 struct SCCP : public SCCPBase<SCCP> { 246 void runOnOperation() override; 247 }; 248 } // namespace 249 250 void SCCP::runOnOperation() { 251 Operation *op = getOperation(); 252 253 SCCPAnalysis analysis(op->getContext()); 254 analysis.run(op); 255 rewrite(analysis, op->getContext(), op->getRegions()); 256 } 257 258 std::unique_ptr<Pass> mlir::createSCCPPass() { 259 return std::make_unique<SCCP>(); 260 } 261