1 //===-- AnnotateConstant.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Transforms/Passes.h"
12 #include "mlir/IR/BuiltinAttributes.h"
13 
14 #define DEBUG_TYPE "flang-annotate-constant"
15 
16 using namespace fir;
17 
18 namespace {
19 struct AnnotateConstantOperands
20     : AnnotateConstantOperandsBase<AnnotateConstantOperands> {
runOnOperation__anonbbd3f5210111::AnnotateConstantOperands21   void runOnOperation() override {
22     auto *context = &getContext();
23     mlir::Dialect *firDialect = context->getLoadedDialect("fir");
24     getOperation()->walk([&](mlir::Operation *op) {
25       // We filter out other dialects even though they may undergo merging of
26       // non-equal constant values by the canonicalizer as well.
27       if (op->getDialect() == firDialect) {
28         llvm::SmallVector<mlir::Attribute> attrs;
29         bool hasOneOrMoreConstOpnd = false;
30         for (mlir::Value opnd : op->getOperands()) {
31           if (auto constOp = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
32                   opnd.getDefiningOp())) {
33             attrs.push_back(constOp.getValue());
34             hasOneOrMoreConstOpnd = true;
35           } else if (auto addrOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
36                          opnd.getDefiningOp())) {
37             attrs.push_back(addrOp.getSymbol());
38             hasOneOrMoreConstOpnd = true;
39           } else {
40             attrs.push_back(mlir::UnitAttr::get(context));
41           }
42         }
43         if (hasOneOrMoreConstOpnd)
44           op->setAttr("canonicalize_constant_operands",
45                       mlir::ArrayAttr::get(context, attrs));
46       }
47     });
48   }
49 };
50 
51 } // namespace
52 
createAnnotateConstantOperandsPass()53 std::unique_ptr<mlir::Pass> fir::createAnnotateConstantOperandsPass() {
54   return std::make_unique<AnnotateConstantOperands>();
55 }
56