1*c2e7e759SEric Schweitz //===-- AnnotateConstant.cpp ----------------------------------------------===// 2*c2e7e759SEric Schweitz // 3*c2e7e759SEric Schweitz // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*c2e7e759SEric Schweitz // See https://llvm.org/LICENSE.txt for license information. 5*c2e7e759SEric Schweitz // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*c2e7e759SEric Schweitz // 7*c2e7e759SEric Schweitz //===----------------------------------------------------------------------===// 8*c2e7e759SEric Schweitz 9*c2e7e759SEric Schweitz #include "PassDetail.h" 10*c2e7e759SEric Schweitz #include "flang/Optimizer/Dialect/FIROps.h" 11*c2e7e759SEric Schweitz #include "flang/Optimizer/Transforms/Passes.h" 12*c2e7e759SEric Schweitz #include "mlir/IR/BuiltinAttributes.h" 13*c2e7e759SEric Schweitz 14*c2e7e759SEric Schweitz #define DEBUG_TYPE "flang-annotate-constant" 15*c2e7e759SEric Schweitz 16*c2e7e759SEric Schweitz using namespace fir; 17*c2e7e759SEric Schweitz 18*c2e7e759SEric Schweitz namespace { 19*c2e7e759SEric Schweitz struct AnnotateConstantOperands 20*c2e7e759SEric Schweitz : AnnotateConstantOperandsBase<AnnotateConstantOperands> { runOnOperation__anonbbd3f5210111::AnnotateConstantOperands21*c2e7e759SEric Schweitz void runOnOperation() override { 22*c2e7e759SEric Schweitz auto *context = &getContext(); 23*c2e7e759SEric Schweitz mlir::Dialect *firDialect = context->getLoadedDialect("fir"); 24*c2e7e759SEric Schweitz getOperation()->walk([&](mlir::Operation *op) { 25*c2e7e759SEric Schweitz // We filter out other dialects even though they may undergo merging of 26*c2e7e759SEric Schweitz // non-equal constant values by the canonicalizer as well. 27*c2e7e759SEric Schweitz if (op->getDialect() == firDialect) { 28*c2e7e759SEric Schweitz llvm::SmallVector<mlir::Attribute> attrs; 29*c2e7e759SEric Schweitz bool hasOneOrMoreConstOpnd = false; 30*c2e7e759SEric Schweitz for (mlir::Value opnd : op->getOperands()) { 31*c2e7e759SEric Schweitz if (auto constOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>( 32*c2e7e759SEric Schweitz opnd.getDefiningOp())) { 33*c2e7e759SEric Schweitz attrs.push_back(constOp.getValue()); 34*c2e7e759SEric Schweitz hasOneOrMoreConstOpnd = true; 35*c2e7e759SEric Schweitz } else if (auto addrOp = mlir::dyn_cast_or_null<fir::AddrOfOp>( 36*c2e7e759SEric Schweitz opnd.getDefiningOp())) { 37*c2e7e759SEric Schweitz attrs.push_back(addrOp.getSymbol()); 38*c2e7e759SEric Schweitz hasOneOrMoreConstOpnd = true; 39*c2e7e759SEric Schweitz } else { 40*c2e7e759SEric Schweitz attrs.push_back(mlir::UnitAttr::get(context)); 41*c2e7e759SEric Schweitz } 42*c2e7e759SEric Schweitz } 43*c2e7e759SEric Schweitz if (hasOneOrMoreConstOpnd) 44*c2e7e759SEric Schweitz op->setAttr("canonicalize_constant_operands", 45*c2e7e759SEric Schweitz mlir::ArrayAttr::get(context, attrs)); 46*c2e7e759SEric Schweitz } 47*c2e7e759SEric Schweitz }); 48*c2e7e759SEric Schweitz } 49*c2e7e759SEric Schweitz }; 50*c2e7e759SEric Schweitz 51*c2e7e759SEric Schweitz } // namespace 52*c2e7e759SEric Schweitz createAnnotateConstantOperandsPass()53*c2e7e759SEric Schweitzstd::unique_ptr<mlir::Pass> fir::createAnnotateConstantOperandsPass() { 54*c2e7e759SEric Schweitz return std::make_unique<AnnotateConstantOperands>(); 55*c2e7e759SEric Schweitz } 56