1 //===-- FIRAttr.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRAttr.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Support/KindMapping.h"
16 #include "mlir/IR/AttributeSupport.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "llvm/ADT/SmallString.h"
20 
21 using namespace fir;
22 
23 namespace fir::detail {
24 
25 struct RealAttributeStorage : public mlir::AttributeStorage {
26   using KeyTy = std::pair<int, llvm::APFloat>;
27 
RealAttributeStoragefir::detail::RealAttributeStorage28   RealAttributeStorage(int kind, const llvm::APFloat &value)
29       : kind(kind), value(value) {}
RealAttributeStoragefir::detail::RealAttributeStorage30   RealAttributeStorage(const KeyTy &key)
31       : RealAttributeStorage(key.first, key.second) {}
32 
hashKeyfir::detail::RealAttributeStorage33   static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
34 
operator ==fir::detail::RealAttributeStorage35   bool operator==(const KeyTy &key) const {
36     return key.first == kind &&
37            key.second.compare(value) == llvm::APFloatBase::cmpEqual;
38   }
39 
40   static RealAttributeStorage *
constructfir::detail::RealAttributeStorage41   construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
42     return new (allocator.allocate<RealAttributeStorage>())
43         RealAttributeStorage(key);
44   }
45 
getFKindfir::detail::RealAttributeStorage46   KindTy getFKind() const { return kind; }
getValuefir::detail::RealAttributeStorage47   llvm::APFloat getValue() const { return value; }
48 
49 private:
50   int kind;
51   llvm::APFloat value;
52 };
53 
54 /// An attribute representing a reference to a type.
55 struct TypeAttributeStorage : public mlir::AttributeStorage {
56   using KeyTy = mlir::Type;
57 
TypeAttributeStoragefir::detail::TypeAttributeStorage58   TypeAttributeStorage(mlir::Type value) : value(value) {
59     assert(value && "must not be of Type null");
60   }
61 
62   /// Key equality function.
operator ==fir::detail::TypeAttributeStorage63   bool operator==(const KeyTy &key) const { return key == value; }
64 
65   /// Construct a new storage instance.
66   static TypeAttributeStorage *
constructfir::detail::TypeAttributeStorage67   construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
68     return new (allocator.allocate<TypeAttributeStorage>())
69         TypeAttributeStorage(key);
70   }
71 
getTypefir::detail::TypeAttributeStorage72   mlir::Type getType() const { return value; }
73 
74 private:
75   mlir::Type value;
76 };
77 } // namespace fir::detail
78 
79 //===----------------------------------------------------------------------===//
80 // Attributes for SELECT TYPE
81 //===----------------------------------------------------------------------===//
82 
get(mlir::Type value)83 ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
84   return Base::get(value.getContext(), value);
85 }
86 
getType() const87 mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
88 
get(mlir::Type value)89 SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
90   return Base::get(value.getContext(), value);
91 }
92 
getType() const93 mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
94 
95 //===----------------------------------------------------------------------===//
96 // Attributes for SELECT CASE
97 //===----------------------------------------------------------------------===//
98 
99 using AttributeUniquer = mlir::detail::AttributeUniquer;
100 
get(mlir::MLIRContext * ctxt)101 ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
102   return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
103 }
104 
get(mlir::MLIRContext * ctxt)105 UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
106   return AttributeUniquer::get<UpperBoundAttr>(ctxt);
107 }
108 
get(mlir::MLIRContext * ctxt)109 LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
110   return AttributeUniquer::get<LowerBoundAttr>(ctxt);
111 }
112 
get(mlir::MLIRContext * ctxt)113 PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
114   return AttributeUniquer::get<PointIntervalAttr>(ctxt);
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // RealAttr
119 //===----------------------------------------------------------------------===//
120 
get(mlir::MLIRContext * ctxt,const RealAttr::ValueType & key)121 RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
122                             const RealAttr::ValueType &key) {
123   return Base::get(ctxt, key);
124 }
125 
getFKind() const126 KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
127 
getValue() const128 llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
129 
130 //===----------------------------------------------------------------------===//
131 // FIR attribute parsing
132 //===----------------------------------------------------------------------===//
133 
parseFirRealAttr(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)134 static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
135                                         mlir::DialectAsmParser &parser,
136                                         mlir::Type type) {
137   int kind = 0;
138   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
139     parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
140     return {};
141   }
142   KindMapping kindMap(dialect->getContext());
143   llvm::APFloat value(0.);
144   if (parser.parseOptionalKeyword("i")) {
145     // `i` not present, so literal float must be present
146     double dontCare;
147     if (parser.parseFloat(dontCare) || parser.parseGreater()) {
148       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
149       return {};
150     }
151     auto fltStr = parser.getFullSymbolSpec()
152                       .drop_until([](char c) { return c == ','; })
153                       .drop_front()
154                       .drop_while([](char c) { return c == ' ' || c == '\t'; })
155                       .take_until([](char c) {
156                         return c == '>' || c == ' ' || c == '\t';
157                       });
158     value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
159   } else {
160     // `i` is present, so literal bitstring (hex) must be present
161     llvm::StringRef hex;
162     if (parser.parseKeyword(&hex) || parser.parseGreater()) {
163       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
164       return {};
165     }
166     const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
167     unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
168     auto bits = llvm::APInt(numBits, hex.drop_front(), 16);
169     value = llvm::APFloat(sem, bits);
170   }
171   return RealAttr::get(dialect->getContext(), {kind, value});
172 }
173 
parseFirAttribute(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)174 mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
175                                        mlir::DialectAsmParser &parser,
176                                        mlir::Type type) {
177   auto loc = parser.getNameLoc();
178   llvm::StringRef attrName;
179   if (parser.parseKeyword(&attrName)) {
180     parser.emitError(loc, "expected an attribute name");
181     return {};
182   }
183 
184   if (attrName == ExactTypeAttr::getAttrName()) {
185     mlir::Type type;
186     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
187       parser.emitError(loc, "expected a type");
188       return {};
189     }
190     return ExactTypeAttr::get(type);
191   }
192   if (attrName == SubclassAttr::getAttrName()) {
193     mlir::Type type;
194     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
195       parser.emitError(loc, "expected a subtype");
196       return {};
197     }
198     return SubclassAttr::get(type);
199   }
200   if (attrName == PointIntervalAttr::getAttrName())
201     return PointIntervalAttr::get(dialect->getContext());
202   if (attrName == LowerBoundAttr::getAttrName())
203     return LowerBoundAttr::get(dialect->getContext());
204   if (attrName == UpperBoundAttr::getAttrName())
205     return UpperBoundAttr::get(dialect->getContext());
206   if (attrName == ClosedIntervalAttr::getAttrName())
207     return ClosedIntervalAttr::get(dialect->getContext());
208   if (attrName == RealAttr::getAttrName())
209     return parseFirRealAttr(dialect, parser, type);
210 
211   parser.emitError(loc, "unknown FIR attribute: ") << attrName;
212   return {};
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // FIR attribute pretty printer
217 //===----------------------------------------------------------------------===//
218 
printFirAttribute(FIROpsDialect * dialect,mlir::Attribute attr,mlir::DialectAsmPrinter & p)219 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
220                             mlir::DialectAsmPrinter &p) {
221   auto &os = p.getStream();
222   if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
223     os << fir::ExactTypeAttr::getAttrName() << '<';
224     p.printType(exact.getType());
225     os << '>';
226   } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
227     os << fir::SubclassAttr::getAttrName() << '<';
228     p.printType(sub.getType());
229     os << '>';
230   } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
231     os << fir::PointIntervalAttr::getAttrName();
232   } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
233     os << fir::ClosedIntervalAttr::getAttrName();
234   } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
235     os << fir::LowerBoundAttr::getAttrName();
236   } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
237     os << fir::UpperBoundAttr::getAttrName();
238   } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
239     os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
240     llvm::SmallString<40> ss;
241     a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
242     os << ss << '>';
243   } else {
244     // don't know how to print the attribute, so use a default
245     os << "<(unknown attribute)>";
246   }
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // FIROpsDialect
251 //===----------------------------------------------------------------------===//
252 
registerAttributes()253 void FIROpsDialect::registerAttributes() {
254   addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
255                 PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
256 }
257