1c095afcbSMogball //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball
9c095afcbSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
109432fbfeSMogball #include "mlir/IR/OpDefinition.h"
119432fbfeSMogball #include "llvm/Support/Debug.h"
129432fbfeSMogball
139432fbfeSMogball #define DEBUG_TYPE "constant-propagation"
14c095afcbSMogball
15c095afcbSMogball using namespace mlir;
16c095afcbSMogball using namespace mlir::dataflow;
17c095afcbSMogball
18c095afcbSMogball //===----------------------------------------------------------------------===//
19c095afcbSMogball // ConstantValue
20c095afcbSMogball //===----------------------------------------------------------------------===//
21c095afcbSMogball
print(raw_ostream & os) const22c095afcbSMogball void ConstantValue::print(raw_ostream &os) const {
23c095afcbSMogball if (constant)
24c095afcbSMogball return constant.print(os);
25c095afcbSMogball os << "<NO VALUE>";
26c095afcbSMogball }
279432fbfeSMogball
289432fbfeSMogball //===----------------------------------------------------------------------===//
299432fbfeSMogball // SparseConstantPropagation
309432fbfeSMogball //===----------------------------------------------------------------------===//
319432fbfeSMogball
visitOperation(Operation * op,ArrayRef<const Lattice<ConstantValue> * > operands,ArrayRef<Lattice<ConstantValue> * > results)329432fbfeSMogball void SparseConstantPropagation::visitOperation(
339432fbfeSMogball Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
349432fbfeSMogball ArrayRef<Lattice<ConstantValue> *> results) {
359432fbfeSMogball LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
369432fbfeSMogball
379432fbfeSMogball // Don't try to simulate the results of a region operation as we can't
389432fbfeSMogball // guarantee that folding will be out-of-place. We don't allow in-place
399432fbfeSMogball // folds as the desire here is for simulated execution, and not general
409432fbfeSMogball // folding.
419432fbfeSMogball if (op->getNumRegions())
429432fbfeSMogball return;
439432fbfeSMogball
449432fbfeSMogball SmallVector<Attribute, 8> constantOperands;
459432fbfeSMogball constantOperands.reserve(op->getNumOperands());
469432fbfeSMogball for (auto *operandLattice : operands)
479432fbfeSMogball constantOperands.push_back(operandLattice->getValue().getConstantValue());
489432fbfeSMogball
499432fbfeSMogball // Save the original operands and attributes just in case the operation
509432fbfeSMogball // folds in-place. The constant passed in may not correspond to the real
519432fbfeSMogball // runtime value, so in-place updates are not allowed.
529432fbfeSMogball SmallVector<Value, 8> originalOperands(op->getOperands());
539432fbfeSMogball DictionaryAttr originalAttrs = op->getAttrDictionary();
549432fbfeSMogball
559432fbfeSMogball // Simulate the result of folding this operation to a constant. If folding
569432fbfeSMogball // fails or was an in-place fold, mark the results as overdefined.
579432fbfeSMogball SmallVector<OpFoldResult, 8> foldResults;
589432fbfeSMogball foldResults.reserve(op->getNumResults());
599432fbfeSMogball if (failed(op->fold(constantOperands, foldResults))) {
609432fbfeSMogball markAllPessimisticFixpoint(results);
619432fbfeSMogball return;
629432fbfeSMogball }
639432fbfeSMogball
649432fbfeSMogball // If the folding was in-place, mark the results as overdefined and reset
659432fbfeSMogball // the operation. We don't allow in-place folds as the desire here is for
669432fbfeSMogball // simulated execution, and not general folding.
679432fbfeSMogball if (foldResults.empty()) {
689432fbfeSMogball op->setOperands(originalOperands);
699432fbfeSMogball op->setAttrs(originalAttrs);
70*13bc82b5SJacques Pienaar markAllPessimisticFixpoint(results);
719432fbfeSMogball return;
729432fbfeSMogball }
739432fbfeSMogball
749432fbfeSMogball // Merge the fold results into the lattice for this operation.
759432fbfeSMogball assert(foldResults.size() == op->getNumResults() && "invalid result size");
769432fbfeSMogball for (const auto it : llvm::zip(results, foldResults)) {
779432fbfeSMogball Lattice<ConstantValue> *lattice = std::get<0>(it);
789432fbfeSMogball
799432fbfeSMogball // Merge in the result of the fold, either a constant or a value.
809432fbfeSMogball OpFoldResult foldResult = std::get<1>(it);
819432fbfeSMogball if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
829432fbfeSMogball LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
839432fbfeSMogball propagateIfChanged(lattice,
849432fbfeSMogball lattice->join(ConstantValue(attr, op->getDialect())));
859432fbfeSMogball } else {
869432fbfeSMogball LLVM_DEBUG(llvm::dbgs()
879432fbfeSMogball << "Folded to value: " << foldResult.get<Value>() << "\n");
889432fbfeSMogball AbstractSparseDataFlowAnalysis::join(
899432fbfeSMogball lattice, *getLatticeElement(foldResult.get<Value>()));
909432fbfeSMogball }
919432fbfeSMogball }
929432fbfeSMogball }
93