1 //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10 #include "mlir/IR/OpDefinition.h"
11 #include "llvm/Support/Debug.h"
12 
13 #define DEBUG_TYPE "constant-propagation"
14 
15 using namespace mlir;
16 using namespace mlir::dataflow;
17 
18 //===----------------------------------------------------------------------===//
19 // ConstantValue
20 //===----------------------------------------------------------------------===//
21 
print(raw_ostream & os) const22 void ConstantValue::print(raw_ostream &os) const {
23   if (constant)
24     return constant.print(os);
25   os << "<NO VALUE>";
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // SparseConstantPropagation
30 //===----------------------------------------------------------------------===//
31 
visitOperation(Operation * op,ArrayRef<const Lattice<ConstantValue> * > operands,ArrayRef<Lattice<ConstantValue> * > results)32 void SparseConstantPropagation::visitOperation(
33     Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
34     ArrayRef<Lattice<ConstantValue> *> results) {
35   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
36 
37   // Don't try to simulate the results of a region operation as we can't
38   // guarantee that folding will be out-of-place. We don't allow in-place
39   // folds as the desire here is for simulated execution, and not general
40   // folding.
41   if (op->getNumRegions())
42     return;
43 
44   SmallVector<Attribute, 8> constantOperands;
45   constantOperands.reserve(op->getNumOperands());
46   for (auto *operandLattice : operands)
47     constantOperands.push_back(operandLattice->getValue().getConstantValue());
48 
49   // Save the original operands and attributes just in case the operation
50   // folds in-place. The constant passed in may not correspond to the real
51   // runtime value, so in-place updates are not allowed.
52   SmallVector<Value, 8> originalOperands(op->getOperands());
53   DictionaryAttr originalAttrs = op->getAttrDictionary();
54 
55   // Simulate the result of folding this operation to a constant. If folding
56   // fails or was an in-place fold, mark the results as overdefined.
57   SmallVector<OpFoldResult, 8> foldResults;
58   foldResults.reserve(op->getNumResults());
59   if (failed(op->fold(constantOperands, foldResults))) {
60     markAllPessimisticFixpoint(results);
61     return;
62   }
63 
64   // If the folding was in-place, mark the results as overdefined and reset
65   // the operation. We don't allow in-place folds as the desire here is for
66   // simulated execution, and not general folding.
67   if (foldResults.empty()) {
68     op->setOperands(originalOperands);
69     op->setAttrs(originalAttrs);
70     markAllPessimisticFixpoint(results);
71     return;
72   }
73 
74   // Merge the fold results into the lattice for this operation.
75   assert(foldResults.size() == op->getNumResults() && "invalid result size");
76   for (const auto it : llvm::zip(results, foldResults)) {
77     Lattice<ConstantValue> *lattice = std::get<0>(it);
78 
79     // Merge in the result of the fold, either a constant or a value.
80     OpFoldResult foldResult = std::get<1>(it);
81     if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
82       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
83       propagateIfChanged(lattice,
84                          lattice->join(ConstantValue(attr, op->getDialect())));
85     } else {
86       LLVM_DEBUG(llvm::dbgs()
87                  << "Folded to value: " << foldResult.get<Value>() << "\n");
88       AbstractSparseDataFlowAnalysis::join(
89           lattice, *getLatticeElement(foldResult.get<Value>()));
90     }
91   }
92 }
93