1 //===-- FIRAttr.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Error.h"
20
21 using namespace fir;
22
23 namespace fir::detail {
24
25 struct RealAttributeStorage : public mlir::AttributeStorage {
26 using KeyTy = std::pair<int, llvm::APFloat>;
27
RealAttributeStoragefir::detail::RealAttributeStorage28 RealAttributeStorage(int kind, const llvm::APFloat &value)
29 : kind(kind), value(value) {}
RealAttributeStoragefir::detail::RealAttributeStorage30 RealAttributeStorage(const KeyTy &key)
31 : RealAttributeStorage(key.first, key.second) {}
32
hashKeyfir::detail::RealAttributeStorage33 static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
34
operator ==fir::detail::RealAttributeStorage35 bool operator==(const KeyTy &key) const {
36 return key.first == kind &&
37 key.second.compare(value) == llvm::APFloatBase::cmpEqual;
38 }
39
40 static RealAttributeStorage *
constructfir::detail::RealAttributeStorage41 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
42 return new (allocator.allocate<RealAttributeStorage>())
43 RealAttributeStorage(key);
44 }
45
getFKindfir::detail::RealAttributeStorage46 KindTy getFKind() const { return kind; }
getValuefir::detail::RealAttributeStorage47 llvm::APFloat getValue() const { return value; }
48
49 private:
50 int kind;
51 llvm::APFloat value;
52 };
53
54 /// An attribute representing a reference to a type.
55 struct TypeAttributeStorage : public mlir::AttributeStorage {
56 using KeyTy = mlir::Type;
57
TypeAttributeStoragefir::detail::TypeAttributeStorage58 TypeAttributeStorage(mlir::Type value) : value(value) {
59 assert(value && "must not be of Type null");
60 }
61
62 /// Key equality function.
operator ==fir::detail::TypeAttributeStorage63 bool operator==(const KeyTy &key) const { return key == value; }
64
65 /// Construct a new storage instance.
66 static TypeAttributeStorage *
constructfir::detail::TypeAttributeStorage67 construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
68 return new (allocator.allocate<TypeAttributeStorage>())
69 TypeAttributeStorage(key);
70 }
71
getTypefir::detail::TypeAttributeStorage72 mlir::Type getType() const { return value; }
73
74 private:
75 mlir::Type value;
76 };
77 } // namespace fir::detail
78
79 //===----------------------------------------------------------------------===//
80 // Attributes for SELECT TYPE
81 //===----------------------------------------------------------------------===//
82
get(mlir::Type value)83 ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
84 return Base::get(value.getContext(), value);
85 }
86
getType() const87 mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
88
get(mlir::Type value)89 SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
90 return Base::get(value.getContext(), value);
91 }
92
getType() const93 mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
94
95 //===----------------------------------------------------------------------===//
96 // Attributes for SELECT CASE
97 //===----------------------------------------------------------------------===//
98
99 using AttributeUniquer = mlir::detail::AttributeUniquer;
100
get(mlir::MLIRContext * ctxt)101 ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
102 return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
103 }
104
get(mlir::MLIRContext * ctxt)105 UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
106 return AttributeUniquer::get<UpperBoundAttr>(ctxt);
107 }
108
get(mlir::MLIRContext * ctxt)109 LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
110 return AttributeUniquer::get<LowerBoundAttr>(ctxt);
111 }
112
get(mlir::MLIRContext * ctxt)113 PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
114 return AttributeUniquer::get<PointIntervalAttr>(ctxt);
115 }
116
117 //===----------------------------------------------------------------------===//
118 // RealAttr
119 //===----------------------------------------------------------------------===//
120
get(mlir::MLIRContext * ctxt,const RealAttr::ValueType & key)121 RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
122 const RealAttr::ValueType &key) {
123 return Base::get(ctxt, key);
124 }
125
getFKind() const126 KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
127
getValue() const128 llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
129
130 //===----------------------------------------------------------------------===//
131 // FIR attribute parsing
132 //===----------------------------------------------------------------------===//
133
parseFirRealAttr(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)134 static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
135 mlir::DialectAsmParser &parser,
136 mlir::Type type) {
137 int kind = 0;
138 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
139 parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
140 return {};
141 }
142 KindMapping kindMap(dialect->getContext());
143 llvm::APFloat value(0.);
144 if (parser.parseOptionalKeyword("i")) {
145 // `i` not present, so literal float must be present
146 double dontCare;
147 if (parser.parseFloat(dontCare) || parser.parseGreater()) {
148 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
149 return {};
150 }
151 auto fltStr = parser.getFullSymbolSpec()
152 .drop_until([](char c) { return c == ','; })
153 .drop_front()
154 .drop_while([](char c) { return c == ' ' || c == '\t'; })
155 .take_until([](char c) {
156 return c == '>' || c == ' ' || c == '\t';
157 });
158 value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
159 } else {
160 // `i` is present, so literal bitstring (hex) must be present
161 llvm::StringRef hex;
162 if (parser.parseKeyword(&hex) || parser.parseGreater()) {
163 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
164 return {};
165 }
166 const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
167 unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
168 auto bits = llvm::APInt(numBits, hex.drop_front(), 16);
169 value = llvm::APFloat(sem, bits);
170 }
171 return RealAttr::get(dialect->getContext(), {kind, value});
172 }
173
parseFirAttribute(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)174 mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
175 mlir::DialectAsmParser &parser,
176 mlir::Type type) {
177 auto loc = parser.getNameLoc();
178 llvm::StringRef attrName;
179 if (parser.parseKeyword(&attrName)) {
180 parser.emitError(loc, "expected an attribute name");
181 return {};
182 }
183
184 if (attrName == ExactTypeAttr::getAttrName()) {
185 mlir::Type type;
186 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
187 parser.emitError(loc, "expected a type");
188 return {};
189 }
190 return ExactTypeAttr::get(type);
191 }
192 if (attrName == SubclassAttr::getAttrName()) {
193 mlir::Type type;
194 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
195 parser.emitError(loc, "expected a subtype");
196 return {};
197 }
198 return SubclassAttr::get(type);
199 }
200 if (attrName == PointIntervalAttr::getAttrName())
201 return PointIntervalAttr::get(dialect->getContext());
202 if (attrName == LowerBoundAttr::getAttrName())
203 return LowerBoundAttr::get(dialect->getContext());
204 if (attrName == UpperBoundAttr::getAttrName())
205 return UpperBoundAttr::get(dialect->getContext());
206 if (attrName == ClosedIntervalAttr::getAttrName())
207 return ClosedIntervalAttr::get(dialect->getContext());
208 if (attrName == RealAttr::getAttrName())
209 return parseFirRealAttr(dialect, parser, type);
210
211 parser.emitError(loc, "unknown FIR attribute: ") << attrName;
212 return {};
213 }
214
215 //===----------------------------------------------------------------------===//
216 // FIR attribute pretty printer
217 //===----------------------------------------------------------------------===//
218
printFirAttribute(FIROpsDialect * dialect,mlir::Attribute attr,mlir::DialectAsmPrinter & p)219 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
220 mlir::DialectAsmPrinter &p) {
221 auto &os = p.getStream();
222 if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
223 os << fir::ExactTypeAttr::getAttrName() << '<';
224 p.printType(exact.getType());
225 os << '>';
226 } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
227 os << fir::SubclassAttr::getAttrName() << '<';
228 p.printType(sub.getType());
229 os << '>';
230 } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
231 os << fir::PointIntervalAttr::getAttrName();
232 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
233 os << fir::ClosedIntervalAttr::getAttrName();
234 } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
235 os << fir::LowerBoundAttr::getAttrName();
236 } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
237 os << fir::UpperBoundAttr::getAttrName();
238 } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
239 os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
240 llvm::SmallString<40> ss;
241 a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
242 os << ss << '>';
243 } else {
244 // don't know how to print the attribute, so use a default
245 os << "<(unknown attribute)>";
246 }
247 }
248
249 //===----------------------------------------------------------------------===//
250 // FIROpsDialect
251 //===----------------------------------------------------------------------===//
252
registerAttributes()253 void FIROpsDialect::registerAttributes() {
254 addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
255 PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
256 }
257