1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/IR/OpDefinition.h"
11 #include "llvm/Support/Debug.h"
12
13 #define DEBUG_TYPE "constant-propagation"
14
15 using namespace mlir;
16 using namespace mlir::dataflow;
17
18 //===----------------------------------------------------------------------===//
19 // ConstantValue
20 //===----------------------------------------------------------------------===//
21
print(raw_ostream & os) const22 void ConstantValue::print(raw_ostream &os) const {
23 if (constant)
24 return constant.print(os);
25 os << "<NO VALUE>";
26 }
27
28 //===----------------------------------------------------------------------===//
29 // SparseConstantPropagation
30 //===----------------------------------------------------------------------===//
31
visitOperation(Operation * op,ArrayRef<const Lattice<ConstantValue> * > operands,ArrayRef<Lattice<ConstantValue> * > results)32 void SparseConstantPropagation::visitOperation(
33 Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
34 ArrayRef<Lattice<ConstantValue> *> results) {
35 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
36
37 // Don't try to simulate the results of a region operation as we can't
38 // guarantee that folding will be out-of-place. We don't allow in-place
39 // folds as the desire here is for simulated execution, and not general
40 // folding.
41 if (op->getNumRegions())
42 return;
43
44 SmallVector<Attribute, 8> constantOperands;
45 constantOperands.reserve(op->getNumOperands());
46 for (auto *operandLattice : operands)
47 constantOperands.push_back(operandLattice->getValue().getConstantValue());
48
49 // Save the original operands and attributes just in case the operation
50 // folds in-place. The constant passed in may not correspond to the real
51 // runtime value, so in-place updates are not allowed.
52 SmallVector<Value, 8> originalOperands(op->getOperands());
53 DictionaryAttr originalAttrs = op->getAttrDictionary();
54
55 // Simulate the result of folding this operation to a constant. If folding
56 // fails or was an in-place fold, mark the results as overdefined.
57 SmallVector<OpFoldResult, 8> foldResults;
58 foldResults.reserve(op->getNumResults());
59 if (failed(op->fold(constantOperands, foldResults))) {
60 markAllPessimisticFixpoint(results);
61 return;
62 }
63
64 // If the folding was in-place, mark the results as overdefined and reset
65 // the operation. We don't allow in-place folds as the desire here is for
66 // simulated execution, and not general folding.
67 if (foldResults.empty()) {
68 op->setOperands(originalOperands);
69 op->setAttrs(originalAttrs);
70 markAllPessimisticFixpoint(results);
71 return;
72 }
73
74 // Merge the fold results into the lattice for this operation.
75 assert(foldResults.size() == op->getNumResults() && "invalid result size");
76 for (const auto it : llvm::zip(results, foldResults)) {
77 Lattice<ConstantValue> *lattice = std::get<0>(it);
78
79 // Merge in the result of the fold, either a constant or a value.
80 OpFoldResult foldResult = std::get<1>(it);
81 if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
82 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
83 propagateIfChanged(lattice,
84 lattice->join(ConstantValue(attr, op->getDialect())));
85 } else {
86 LLVM_DEBUG(llvm::dbgs()
87 << "Folded to value: " << foldResult.get<Value>() << "\n");
88 AbstractSparseDataFlowAnalysis::join(
89 lattice, *getLatticeElement(foldResult.get<Value>()));
90 }
91 }
92 }
93