1*c2e7e759SEric Schweitz //===-- AnnotateConstant.cpp ----------------------------------------------===//
2*c2e7e759SEric Schweitz //
3*c2e7e759SEric Schweitz // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c2e7e759SEric Schweitz // See https://llvm.org/LICENSE.txt for license information.
5*c2e7e759SEric Schweitz // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c2e7e759SEric Schweitz //
7*c2e7e759SEric Schweitz //===----------------------------------------------------------------------===//
8*c2e7e759SEric Schweitz 
9*c2e7e759SEric Schweitz #include "PassDetail.h"
10*c2e7e759SEric Schweitz #include "flang/Optimizer/Dialect/FIROps.h"
11*c2e7e759SEric Schweitz #include "flang/Optimizer/Transforms/Passes.h"
12*c2e7e759SEric Schweitz #include "mlir/IR/BuiltinAttributes.h"
13*c2e7e759SEric Schweitz 
14*c2e7e759SEric Schweitz #define DEBUG_TYPE "flang-annotate-constant"
15*c2e7e759SEric Schweitz 
16*c2e7e759SEric Schweitz using namespace fir;
17*c2e7e759SEric Schweitz 
18*c2e7e759SEric Schweitz namespace {
19*c2e7e759SEric Schweitz struct AnnotateConstantOperands
20*c2e7e759SEric Schweitz     : AnnotateConstantOperandsBase<AnnotateConstantOperands> {
runOnOperation__anonbbd3f5210111::AnnotateConstantOperands21*c2e7e759SEric Schweitz   void runOnOperation() override {
22*c2e7e759SEric Schweitz     auto *context = &getContext();
23*c2e7e759SEric Schweitz     mlir::Dialect *firDialect = context->getLoadedDialect("fir");
24*c2e7e759SEric Schweitz     getOperation()->walk([&](mlir::Operation *op) {
25*c2e7e759SEric Schweitz       // We filter out other dialects even though they may undergo merging of
26*c2e7e759SEric Schweitz       // non-equal constant values by the canonicalizer as well.
27*c2e7e759SEric Schweitz       if (op->getDialect() == firDialect) {
28*c2e7e759SEric Schweitz         llvm::SmallVector<mlir::Attribute> attrs;
29*c2e7e759SEric Schweitz         bool hasOneOrMoreConstOpnd = false;
30*c2e7e759SEric Schweitz         for (mlir::Value opnd : op->getOperands()) {
31*c2e7e759SEric Schweitz           if (auto constOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
32*c2e7e759SEric Schweitz                   opnd.getDefiningOp())) {
33*c2e7e759SEric Schweitz             attrs.push_back(constOp.getValue());
34*c2e7e759SEric Schweitz             hasOneOrMoreConstOpnd = true;
35*c2e7e759SEric Schweitz           } else if (auto addrOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
36*c2e7e759SEric Schweitz                          opnd.getDefiningOp())) {
37*c2e7e759SEric Schweitz             attrs.push_back(addrOp.getSymbol());
38*c2e7e759SEric Schweitz             hasOneOrMoreConstOpnd = true;
39*c2e7e759SEric Schweitz           } else {
40*c2e7e759SEric Schweitz             attrs.push_back(mlir::UnitAttr::get(context));
41*c2e7e759SEric Schweitz           }
42*c2e7e759SEric Schweitz         }
43*c2e7e759SEric Schweitz         if (hasOneOrMoreConstOpnd)
44*c2e7e759SEric Schweitz           op->setAttr("canonicalize_constant_operands",
45*c2e7e759SEric Schweitz                       mlir::ArrayAttr::get(context, attrs));
46*c2e7e759SEric Schweitz       }
47*c2e7e759SEric Schweitz     });
48*c2e7e759SEric Schweitz   }
49*c2e7e759SEric Schweitz };
50*c2e7e759SEric Schweitz 
51*c2e7e759SEric Schweitz } // namespace
52*c2e7e759SEric Schweitz 
createAnnotateConstantOperandsPass()53*c2e7e759SEric Schweitz std::unique_ptr<mlir::Pass> fir::createAnnotateConstantOperandsPass() {
54*c2e7e759SEric Schweitz   return std::make_unique<AnnotateConstantOperands>();
55*c2e7e759SEric Schweitz }
56