1c095afcbSMogball //===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
2c095afcbSMogball //
3c095afcbSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c095afcbSMogball // See https://llvm.org/LICENSE.txt for license information.
5c095afcbSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c095afcbSMogball //
7c095afcbSMogball //===----------------------------------------------------------------------===//
8c095afcbSMogball 
9c095afcbSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
109432fbfeSMogball #include "mlir/IR/OpDefinition.h"
119432fbfeSMogball #include "llvm/Support/Debug.h"
129432fbfeSMogball 
139432fbfeSMogball #define DEBUG_TYPE "constant-propagation"
14c095afcbSMogball 
15c095afcbSMogball using namespace mlir;
16c095afcbSMogball using namespace mlir::dataflow;
17c095afcbSMogball 
18c095afcbSMogball //===----------------------------------------------------------------------===//
19c095afcbSMogball // ConstantValue
20c095afcbSMogball //===----------------------------------------------------------------------===//
21c095afcbSMogball 
print(raw_ostream & os) const22c095afcbSMogball void ConstantValue::print(raw_ostream &os) const {
23c095afcbSMogball   if (constant)
24c095afcbSMogball     return constant.print(os);
25c095afcbSMogball   os << "<NO VALUE>";
26c095afcbSMogball }
279432fbfeSMogball 
289432fbfeSMogball //===----------------------------------------------------------------------===//
299432fbfeSMogball // SparseConstantPropagation
309432fbfeSMogball //===----------------------------------------------------------------------===//
319432fbfeSMogball 
visitOperation(Operation * op,ArrayRef<const Lattice<ConstantValue> * > operands,ArrayRef<Lattice<ConstantValue> * > results)329432fbfeSMogball void SparseConstantPropagation::visitOperation(
339432fbfeSMogball     Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
349432fbfeSMogball     ArrayRef<Lattice<ConstantValue> *> results) {
359432fbfeSMogball   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
369432fbfeSMogball 
379432fbfeSMogball   // Don't try to simulate the results of a region operation as we can't
389432fbfeSMogball   // guarantee that folding will be out-of-place. We don't allow in-place
399432fbfeSMogball   // folds as the desire here is for simulated execution, and not general
409432fbfeSMogball   // folding.
419432fbfeSMogball   if (op->getNumRegions())
429432fbfeSMogball     return;
439432fbfeSMogball 
449432fbfeSMogball   SmallVector<Attribute, 8> constantOperands;
459432fbfeSMogball   constantOperands.reserve(op->getNumOperands());
469432fbfeSMogball   for (auto *operandLattice : operands)
479432fbfeSMogball     constantOperands.push_back(operandLattice->getValue().getConstantValue());
489432fbfeSMogball 
499432fbfeSMogball   // Save the original operands and attributes just in case the operation
509432fbfeSMogball   // folds in-place. The constant passed in may not correspond to the real
519432fbfeSMogball   // runtime value, so in-place updates are not allowed.
529432fbfeSMogball   SmallVector<Value, 8> originalOperands(op->getOperands());
539432fbfeSMogball   DictionaryAttr originalAttrs = op->getAttrDictionary();
549432fbfeSMogball 
559432fbfeSMogball   // Simulate the result of folding this operation to a constant. If folding
569432fbfeSMogball   // fails or was an in-place fold, mark the results as overdefined.
579432fbfeSMogball   SmallVector<OpFoldResult, 8> foldResults;
589432fbfeSMogball   foldResults.reserve(op->getNumResults());
599432fbfeSMogball   if (failed(op->fold(constantOperands, foldResults))) {
609432fbfeSMogball     markAllPessimisticFixpoint(results);
619432fbfeSMogball     return;
629432fbfeSMogball   }
639432fbfeSMogball 
649432fbfeSMogball   // If the folding was in-place, mark the results as overdefined and reset
659432fbfeSMogball   // the operation. We don't allow in-place folds as the desire here is for
669432fbfeSMogball   // simulated execution, and not general folding.
679432fbfeSMogball   if (foldResults.empty()) {
689432fbfeSMogball     op->setOperands(originalOperands);
699432fbfeSMogball     op->setAttrs(originalAttrs);
70*13bc82b5SJacques Pienaar     markAllPessimisticFixpoint(results);
719432fbfeSMogball     return;
729432fbfeSMogball   }
739432fbfeSMogball 
749432fbfeSMogball   // Merge the fold results into the lattice for this operation.
759432fbfeSMogball   assert(foldResults.size() == op->getNumResults() && "invalid result size");
769432fbfeSMogball   for (const auto it : llvm::zip(results, foldResults)) {
779432fbfeSMogball     Lattice<ConstantValue> *lattice = std::get<0>(it);
789432fbfeSMogball 
799432fbfeSMogball     // Merge in the result of the fold, either a constant or a value.
809432fbfeSMogball     OpFoldResult foldResult = std::get<1>(it);
819432fbfeSMogball     if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
829432fbfeSMogball       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
839432fbfeSMogball       propagateIfChanged(lattice,
849432fbfeSMogball                          lattice->join(ConstantValue(attr, op->getDialect())));
859432fbfeSMogball     } else {
869432fbfeSMogball       LLVM_DEBUG(llvm::dbgs()
879432fbfeSMogball                  << "Folded to value: " << foldResult.get<Value>() << "\n");
889432fbfeSMogball       AbstractSparseDataFlowAnalysis::join(
899432fbfeSMogball           lattice, *getLatticeElement(foldResult.get<Value>()));
909432fbfeSMogball     }
919432fbfeSMogball   }
929432fbfeSMogball }
93