1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 10 #include "mlir/IR/OpDefinition.h" 11 #include "llvm/Support/Debug.h" 12 13 #define DEBUG_TYPE "constant-propagation" 14 15 using namespace mlir; 16 using namespace mlir::dataflow; 17 18 //===----------------------------------------------------------------------===// 19 // ConstantValue 20 //===----------------------------------------------------------------------===// 21 22 void ConstantValue::print(raw_ostream &os) const { 23 if (constant) 24 return constant.print(os); 25 os << "<NO VALUE>"; 26 } 27 28 //===----------------------------------------------------------------------===// 29 // SparseConstantPropagation 30 //===----------------------------------------------------------------------===// 31 32 void SparseConstantPropagation::visitOperation( 33 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, 34 ArrayRef<Lattice<ConstantValue> *> results) { 35 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); 36 37 // Don't try to simulate the results of a region operation as we can't 38 // guarantee that folding will be out-of-place. We don't allow in-place 39 // folds as the desire here is for simulated execution, and not general 40 // folding. 41 if (op->getNumRegions()) 42 return; 43 44 SmallVector<Attribute, 8> constantOperands; 45 constantOperands.reserve(op->getNumOperands()); 46 for (auto *operandLattice : operands) 47 constantOperands.push_back(operandLattice->getValue().getConstantValue()); 48 49 // Save the original operands and attributes just in case the operation 50 // folds in-place. The constant passed in may not correspond to the real 51 // runtime value, so in-place updates are not allowed. 52 SmallVector<Value, 8> originalOperands(op->getOperands()); 53 DictionaryAttr originalAttrs = op->getAttrDictionary(); 54 55 // Simulate the result of folding this operation to a constant. If folding 56 // fails or was an in-place fold, mark the results as overdefined. 57 SmallVector<OpFoldResult, 8> foldResults; 58 foldResults.reserve(op->getNumResults()); 59 if (failed(op->fold(constantOperands, foldResults))) { 60 markAllPessimisticFixpoint(results); 61 return; 62 } 63 64 // If the folding was in-place, mark the results as overdefined and reset 65 // the operation. We don't allow in-place folds as the desire here is for 66 // simulated execution, and not general folding. 67 if (foldResults.empty()) { 68 op->setOperands(originalOperands); 69 op->setAttrs(originalAttrs); 70 return; 71 } 72 73 // Merge the fold results into the lattice for this operation. 74 assert(foldResults.size() == op->getNumResults() && "invalid result size"); 75 for (const auto it : llvm::zip(results, foldResults)) { 76 Lattice<ConstantValue> *lattice = std::get<0>(it); 77 78 // Merge in the result of the fold, either a constant or a value. 79 OpFoldResult foldResult = std::get<1>(it); 80 if (Attribute attr = foldResult.dyn_cast<Attribute>()) { 81 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); 82 propagateIfChanged(lattice, 83 lattice->join(ConstantValue(attr, op->getDialect()))); 84 } else { 85 LLVM_DEBUG(llvm::dbgs() 86 << "Folded to value: " << foldResult.get<Value>() << "\n"); 87 AbstractSparseDataFlowAnalysis::join( 88 lattice, *getLatticeElement(foldResult.get<Value>())); 89 } 90 } 91 } 92