1 //===-- FIRType.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Support/KindMapping.h"
16 #include "flang/Tools/PointerModels.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/ErrorHandling.h"
25
26 #define GET_TYPEDEF_CLASSES
27 #include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
28
29 using namespace fir;
30
31 namespace {
32
33 template <typename TYPE>
parseIntSingleton(mlir::AsmParser & parser)34 TYPE parseIntSingleton(mlir::AsmParser &parser) {
35 int kind = 0;
36 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
37 return {};
38 return TYPE::get(parser.getContext(), kind);
39 }
40
41 template <typename TYPE>
parseKindSingleton(mlir::AsmParser & parser)42 TYPE parseKindSingleton(mlir::AsmParser &parser) {
43 return parseIntSingleton<TYPE>(parser);
44 }
45
46 template <typename TYPE>
parseRankSingleton(mlir::AsmParser & parser)47 TYPE parseRankSingleton(mlir::AsmParser &parser) {
48 return parseIntSingleton<TYPE>(parser);
49 }
50
51 template <typename TYPE>
parseTypeSingleton(mlir::AsmParser & parser)52 TYPE parseTypeSingleton(mlir::AsmParser &parser) {
53 mlir::Type ty;
54 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
55 return {};
56 return TYPE::get(ty);
57 }
58
59 /// Is `ty` a standard or FIR integer type?
isaIntegerType(mlir::Type ty)60 static bool isaIntegerType(mlir::Type ty) {
61 // TODO: why aren't we using isa_integer? investigatation required.
62 return ty.isa<mlir::IntegerType>() || ty.isa<fir::IntegerType>();
63 }
64
verifyRecordMemberType(mlir::Type ty)65 bool verifyRecordMemberType(mlir::Type ty) {
66 return !(ty.isa<BoxCharType>() || ty.isa<BoxProcType>() ||
67 ty.isa<ShapeType>() || ty.isa<ShapeShiftType>() ||
68 ty.isa<ShiftType>() || ty.isa<SliceType>() || ty.isa<FieldType>() ||
69 ty.isa<LenType>() || ty.isa<ReferenceType>() ||
70 ty.isa<TypeDescType>());
71 }
72
verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,llvm::ArrayRef<RecordType::TypePair> a2)73 bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
74 llvm::ArrayRef<RecordType::TypePair> a2) {
75 // FIXME: do we need to allow for any variance here?
76 return a1 == a2;
77 }
78
verifyDerived(mlir::AsmParser & parser,RecordType derivedTy,llvm::ArrayRef<RecordType::TypePair> lenPList,llvm::ArrayRef<RecordType::TypePair> typeList)79 RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
80 llvm::ArrayRef<RecordType::TypePair> lenPList,
81 llvm::ArrayRef<RecordType::TypePair> typeList) {
82 auto loc = parser.getNameLoc();
83 if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
84 !verifySameLists(derivedTy.getTypeList(), typeList)) {
85 parser.emitError(loc, "cannot redefine record type members");
86 return {};
87 }
88 for (auto &p : lenPList)
89 if (!isaIntegerType(p.second)) {
90 parser.emitError(loc, "LEN parameter must be integral type");
91 return {};
92 }
93 for (auto &p : typeList)
94 if (!verifyRecordMemberType(p.second)) {
95 parser.emitError(loc, "field parameter has invalid type");
96 return {};
97 }
98 llvm::StringSet<> uniq;
99 for (auto &p : lenPList)
100 if (!uniq.insert(p.first).second) {
101 parser.emitError(loc, "LEN parameter cannot have duplicate name");
102 return {};
103 }
104 for (auto &p : typeList)
105 if (!uniq.insert(p.first).second) {
106 parser.emitError(loc, "field cannot have duplicate name");
107 return {};
108 }
109 return derivedTy;
110 }
111
112 } // namespace
113
114 // Implementation of the thin interface from dialect to type parser
115
parseFirType(FIROpsDialect * dialect,mlir::DialectAsmParser & parser)116 mlir::Type fir::parseFirType(FIROpsDialect *dialect,
117 mlir::DialectAsmParser &parser) {
118 mlir::StringRef typeTag;
119 mlir::Type genType;
120 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
121 if (parseResult.hasValue())
122 return genType;
123 parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
124 return {};
125 }
126
127 namespace fir {
128 namespace detail {
129
130 // Type storage classes
131
132 /// Derived type storage
133 struct RecordTypeStorage : public mlir::TypeStorage {
134 using KeyTy = llvm::StringRef;
135
hashKeyfir::detail::RecordTypeStorage136 static unsigned hashKey(const KeyTy &key) {
137 return llvm::hash_combine(key.str());
138 }
139
operator ==fir::detail::RecordTypeStorage140 bool operator==(const KeyTy &key) const { return key == getName(); }
141
constructfir::detail::RecordTypeStorage142 static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
143 const KeyTy &key) {
144 auto *storage = allocator.allocate<RecordTypeStorage>();
145 return new (storage) RecordTypeStorage{key};
146 }
147
getNamefir::detail::RecordTypeStorage148 llvm::StringRef getName() const { return name; }
149
setLenParamListfir::detail::RecordTypeStorage150 void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
151 lens = list;
152 }
getLenParamListfir::detail::RecordTypeStorage153 llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
154
setTypeListfir::detail::RecordTypeStorage155 void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
getTypeListfir::detail::RecordTypeStorage156 llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
157
finalizefir::detail::RecordTypeStorage158 void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
159 llvm::ArrayRef<RecordType::TypePair> typeList) {
160 if (finalized)
161 return;
162 finalized = true;
163 setLenParamList(lenParamList);
164 setTypeList(typeList);
165 }
166
167 protected:
168 std::string name;
169 bool finalized;
170 std::vector<RecordType::TypePair> lens;
171 std::vector<RecordType::TypePair> types;
172
173 private:
174 RecordTypeStorage() = delete;
RecordTypeStoragefir::detail::RecordTypeStorage175 explicit RecordTypeStorage(llvm::StringRef name)
176 : name{name}, finalized{false} {}
177 };
178
179 } // namespace detail
180
/// Half-open range check: returns true iff lb <= v < ub.
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  if (!(v >= lb))
    return false;
  return v < ub;
}
185
isa_fir_type(mlir::Type t)186 bool isa_fir_type(mlir::Type t) {
187 return llvm::isa<FIROpsDialect>(t.getDialect());
188 }
189
isa_std_type(mlir::Type t)190 bool isa_std_type(mlir::Type t) {
191 return llvm::isa<mlir::BuiltinDialect>(t.getDialect());
192 }
193
isa_fir_or_std_type(mlir::Type t)194 bool isa_fir_or_std_type(mlir::Type t) {
195 if (auto funcType = t.dyn_cast<mlir::FunctionType>())
196 return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) &&
197 llvm::all_of(funcType.getResults(), isa_fir_or_std_type);
198 return isa_fir_type(t) || isa_std_type(t);
199 }
200
dyn_cast_ptrEleTy(mlir::Type t)201 mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
202 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
203 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
204 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
205 .Default([](mlir::Type) { return mlir::Type{}; });
206 }
207
dyn_cast_ptrOrBoxEleTy(mlir::Type t)208 mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
209 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
210 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
211 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
212 .Case<fir::BoxType>([](auto p) {
213 auto eleTy = p.getEleTy();
214 if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
215 return ty;
216 return eleTy;
217 })
218 .Default([](mlir::Type) { return mlir::Type{}; });
219 }
220
hasDynamicSize(fir::RecordType recTy)221 static bool hasDynamicSize(fir::RecordType recTy) {
222 for (auto field : recTy.getTypeList()) {
223 if (auto arr = field.second.dyn_cast<fir::SequenceType>()) {
224 if (sequenceWithNonConstantShape(arr))
225 return true;
226 } else if (characterWithDynamicLen(field.second)) {
227 return true;
228 } else if (auto rec = field.second.dyn_cast<fir::RecordType>()) {
229 if (hasDynamicSize(rec))
230 return true;
231 }
232 }
233 return false;
234 }
235
hasDynamicSize(mlir::Type t)236 bool hasDynamicSize(mlir::Type t) {
237 if (auto arr = t.dyn_cast<fir::SequenceType>()) {
238 if (sequenceWithNonConstantShape(arr))
239 return true;
240 t = arr.getEleTy();
241 }
242 if (characterWithDynamicLen(t))
243 return true;
244 if (auto rec = t.dyn_cast<fir::RecordType>())
245 return hasDynamicSize(rec);
246 return false;
247 }
248
isPointerType(mlir::Type ty)249 bool isPointerType(mlir::Type ty) {
250 if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
251 ty = refTy;
252 if (auto boxTy = ty.dyn_cast<fir::BoxType>())
253 return boxTy.getEleTy().isa<fir::PointerType>();
254 return false;
255 }
256
isAllocatableType(mlir::Type ty)257 bool isAllocatableType(mlir::Type ty) {
258 if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
259 ty = refTy;
260 if (auto boxTy = ty.dyn_cast<fir::BoxType>())
261 return boxTy.getEleTy().isa<fir::HeapType>();
262 return false;
263 }
264
isUnlimitedPolymorphicType(mlir::Type ty)265 bool isUnlimitedPolymorphicType(mlir::Type ty) {
266 if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
267 ty = refTy;
268 if (auto boxTy = ty.dyn_cast<fir::BoxType>())
269 return boxTy.getEleTy().isa<mlir::NoneType>();
270 return false;
271 }
272
isRecordWithAllocatableMember(mlir::Type ty)273 bool isRecordWithAllocatableMember(mlir::Type ty) {
274 if (auto recTy = ty.dyn_cast<fir::RecordType>())
275 for (auto [field, memTy] : recTy.getTypeList()) {
276 if (fir::isAllocatableType(memTy))
277 return true;
278 // A record type cannot recursively include itself as a direct member.
279 // There must be an intervening `ptr` type, so recursion is safe here.
280 if (memTy.isa<fir::RecordType>() && isRecordWithAllocatableMember(memTy))
281 return true;
282 }
283 return false;
284 }
285
unwrapAllRefAndSeqType(mlir::Type ty)286 mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
287 while (true) {
288 mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
289 if (auto vecTy = nt.dyn_cast<fir::VectorType>())
290 nt = vecTy.getEleTy();
291 if (nt == ty)
292 return ty;
293 ty = nt;
294 }
295 }
296
unwrapSeqOrBoxedSeqType(mlir::Type ty)297 mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
298 if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
299 return seqTy.getEleTy();
300 if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
301 auto eleTy = unwrapRefType(boxTy.getEleTy());
302 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
303 return seqTy.getEleTy();
304 }
305 return ty;
306 }
307
308 } // namespace fir
309
310 namespace {
311
312 static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
313 recordTypeVisited;
314
315 } // namespace
316
verifyIntegralType(mlir::Type type)317 void fir::verifyIntegralType(mlir::Type type) {
318 if (isaIntegerType(type) || type.isa<mlir::IndexType>())
319 return;
320 llvm::report_fatal_error("expected integral type");
321 }
322
printFirType(FIROpsDialect *,mlir::Type ty,mlir::DialectAsmPrinter & p)323 void fir::printFirType(FIROpsDialect *, mlir::Type ty,
324 mlir::DialectAsmPrinter &p) {
325 if (mlir::failed(generatedTypePrinter(ty, p)))
326 llvm::report_fatal_error("unknown type to print");
327 }
328
isa_unknown_size_box(mlir::Type t)329 bool fir::isa_unknown_size_box(mlir::Type t) {
330 if (auto boxTy = t.dyn_cast<fir::BoxType>()) {
331 auto eleTy = boxTy.getEleTy();
332 if (auto actualEleTy = fir::dyn_cast_ptrEleTy(eleTy))
333 eleTy = actualEleTy;
334 if (eleTy.isa<mlir::NoneType>())
335 return true;
336 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
337 if (seqTy.hasUnknownShape())
338 return true;
339 }
340 return false;
341 }
342
343 //===----------------------------------------------------------------------===//
344 // BoxProcType
345 //===----------------------------------------------------------------------===//
346
347 // `boxproc` `<` return-type `>`
parse(mlir::AsmParser & parser)348 mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
349 mlir::Type ty;
350 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
351 return {};
352 return get(parser.getContext(), ty);
353 }
354
print(mlir::AsmPrinter & printer) const355 void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
356 printer << "<" << getEleTy() << '>';
357 }
358
359 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)360 BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
361 mlir::Type eleTy) {
362 if (eleTy.isa<mlir::FunctionType>())
363 return mlir::success();
364 if (auto refTy = eleTy.dyn_cast<ReferenceType>())
365 if (refTy.isa<mlir::FunctionType>())
366 return mlir::success();
367 return emitError() << "invalid type for boxproc" << eleTy << '\n';
368 }
369
cannotBePointerOrHeapElementType(mlir::Type eleTy)370 static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
371 return eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
372 SliceType, FieldType, LenType, HeapType, PointerType,
373 ReferenceType, TypeDescType>();
374 }
375
376 //===----------------------------------------------------------------------===//
377 // BoxType
378 //===----------------------------------------------------------------------===//
379
380 // `box` `<` type (',' affine-map)? `>`
parse(mlir::AsmParser & parser)381 mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) {
382 mlir::Type ofTy;
383 if (parser.parseLess() || parser.parseType(ofTy))
384 return {};
385
386 mlir::AffineMapAttr map;
387 if (!parser.parseOptionalComma()) {
388 if (parser.parseAttribute(map)) {
389 parser.emitError(parser.getCurrentLocation(), "expected affine map");
390 return {};
391 }
392 }
393 if (parser.parseGreater())
394 return {};
395 return get(ofTy, map);
396 }
397
print(mlir::AsmPrinter & printer) const398 void fir::BoxType::print(mlir::AsmPrinter &printer) const {
399 printer << "<" << getEleTy();
400 if (auto map = getLayoutMap()) {
401 printer << ", " << map;
402 }
403 printer << '>';
404 }
405
406 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy,mlir::AffineMapAttr map)407 fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
408 mlir::Type eleTy, mlir::AffineMapAttr map) {
409 // TODO
410 return mlir::success();
411 }
412
413 //===----------------------------------------------------------------------===//
414 // BoxCharType
415 //===----------------------------------------------------------------------===//
416
parse(mlir::AsmParser & parser)417 mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
418 return parseKindSingleton<fir::BoxCharType>(parser);
419 }
420
print(mlir::AsmPrinter & printer) const421 void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
422 printer << "<" << getKind() << ">";
423 }
424
425 CharacterType
getElementType(mlir::MLIRContext * context) const426 fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
427 return CharacterType::getUnknownLen(context, getKind());
428 }
429
getEleTy() const430 CharacterType fir::BoxCharType::getEleTy() const {
431 return getElementType(getContext());
432 }
433
434 //===----------------------------------------------------------------------===//
435 // CharacterType
436 //===----------------------------------------------------------------------===//
437
438 // `char` `<` kind [`,` `len`] `>`
parse(mlir::AsmParser & parser)439 mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
440 int kind = 0;
441 if (parser.parseLess() || parser.parseInteger(kind))
442 return {};
443 CharacterType::LenType len = 1;
444 if (mlir::succeeded(parser.parseOptionalComma())) {
445 if (mlir::succeeded(parser.parseOptionalQuestion())) {
446 len = fir::CharacterType::unknownLen();
447 } else if (!mlir::succeeded(parser.parseInteger(len))) {
448 return {};
449 }
450 }
451 if (parser.parseGreater())
452 return {};
453 return get(parser.getContext(), kind, len);
454 }
455
print(mlir::AsmPrinter & printer) const456 void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
457 printer << "<" << getFKind();
458 auto len = getLen();
459 if (len != fir::CharacterType::singleton()) {
460 printer << ',';
461 if (len == fir::CharacterType::unknownLen())
462 printer << '?';
463 else
464 printer << len;
465 }
466 printer << '>';
467 }
468
469 //===----------------------------------------------------------------------===//
470 // ComplexType
471 //===----------------------------------------------------------------------===//
472
parse(mlir::AsmParser & parser)473 mlir::Type fir::ComplexType::parse(mlir::AsmParser &parser) {
474 return parseKindSingleton<fir::ComplexType>(parser);
475 }
476
print(mlir::AsmPrinter & printer) const477 void fir::ComplexType::print(mlir::AsmPrinter &printer) const {
478 printer << "<" << getFKind() << '>';
479 }
480
getElementType() const481 mlir::Type fir::ComplexType::getElementType() const {
482 return fir::RealType::get(getContext(), getFKind());
483 }
484
485 // Return the MLIR float type of the complex element type.
getEleType(const fir::KindMapping & kindMap) const486 mlir::Type fir::ComplexType::getEleType(const fir::KindMapping &kindMap) const {
487 auto fkind = getFKind();
488 auto realTypeID = kindMap.getRealTypeID(fkind);
489 return fir::fromRealTypeID(getContext(), realTypeID, fkind);
490 }
491
492 //===----------------------------------------------------------------------===//
493 // HeapType
494 //===----------------------------------------------------------------------===//
495
496 // `heap` `<` type `>`
parse(mlir::AsmParser & parser)497 mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
498 return parseTypeSingleton<HeapType>(parser);
499 }
500
print(mlir::AsmPrinter & printer) const501 void fir::HeapType::print(mlir::AsmPrinter &printer) const {
502 printer << "<" << getEleTy() << '>';
503 }
504
505 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)506 fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
507 mlir::Type eleTy) {
508 if (cannotBePointerOrHeapElementType(eleTy))
509 return emitError() << "cannot build a heap pointer to type: " << eleTy
510 << '\n';
511 return mlir::success();
512 }
513
514 //===----------------------------------------------------------------------===//
515 // IntegerType
516 //===----------------------------------------------------------------------===//
517
518 // `int` `<` kind `>`
parse(mlir::AsmParser & parser)519 mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
520 return parseKindSingleton<fir::IntegerType>(parser);
521 }
522
print(mlir::AsmPrinter & printer) const523 void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
524 printer << "<" << getFKind() << '>';
525 }
526
527 //===----------------------------------------------------------------------===//
528 // LogicalType
529 //===----------------------------------------------------------------------===//
530
531 // `logical` `<` kind `>`
parse(mlir::AsmParser & parser)532 mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
533 return parseKindSingleton<fir::LogicalType>(parser);
534 }
535
print(mlir::AsmPrinter & printer) const536 void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
537 printer << "<" << getFKind() << '>';
538 }
539
540 //===----------------------------------------------------------------------===//
541 // PointerType
542 //===----------------------------------------------------------------------===//
543
544 // `ptr` `<` type `>`
parse(mlir::AsmParser & parser)545 mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
546 return parseTypeSingleton<fir::PointerType>(parser);
547 }
548
print(mlir::AsmPrinter & printer) const549 void fir::PointerType::print(mlir::AsmPrinter &printer) const {
550 printer << "<" << getEleTy() << '>';
551 }
552
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)553 mlir::LogicalResult fir::PointerType::verify(
554 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
555 mlir::Type eleTy) {
556 if (cannotBePointerOrHeapElementType(eleTy))
557 return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
558 return mlir::success();
559 }
560
561 //===----------------------------------------------------------------------===//
562 // RealType
563 //===----------------------------------------------------------------------===//
564
565 // `real` `<` kind `>`
parse(mlir::AsmParser & parser)566 mlir::Type fir::RealType::parse(mlir::AsmParser &parser) {
567 return parseKindSingleton<fir::RealType>(parser);
568 }
569
print(mlir::AsmPrinter & printer) const570 void fir::RealType::print(mlir::AsmPrinter &printer) const {
571 printer << "<" << getFKind() << '>';
572 }
573
574 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,KindTy fKind)575 fir::RealType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
576 KindTy fKind) {
577 // TODO
578 return mlir::success();
579 }
580
581 //===----------------------------------------------------------------------===//
582 // RecordType
583 //===----------------------------------------------------------------------===//
584
585 // Fortran derived type
586 // `type` `<` name
587 // (`(` id `:` type (`,` id `:` type)* `)`)?
588 // (`{` id `:` type (`,` id `:` type)* `}`)? '>'
parse(mlir::AsmParser & parser)589 mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
590 llvm::StringRef name;
591 if (parser.parseLess() || parser.parseKeyword(&name))
592 return {};
593 RecordType result = RecordType::get(parser.getContext(), name);
594
595 RecordType::TypeList lenParamList;
596 if (!parser.parseOptionalLParen()) {
597 while (true) {
598 llvm::StringRef lenparam;
599 mlir::Type intTy;
600 if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
601 parser.parseType(intTy)) {
602 parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
603 return {};
604 }
605 lenParamList.emplace_back(lenparam, intTy);
606 if (parser.parseOptionalComma())
607 break;
608 }
609 if (parser.parseRParen())
610 return {};
611 }
612
613 RecordType::TypeList typeList;
614 if (!parser.parseOptionalLBrace()) {
615 while (true) {
616 llvm::StringRef field;
617 mlir::Type fldTy;
618 if (parser.parseKeyword(&field) || parser.parseColon() ||
619 parser.parseType(fldTy)) {
620 parser.emitError(parser.getNameLoc(), "expected field type list");
621 return {};
622 }
623 typeList.emplace_back(field, fldTy);
624 if (parser.parseOptionalComma())
625 break;
626 }
627 if (parser.parseRBrace())
628 return {};
629 }
630
631 if (parser.parseGreater())
632 return {};
633
634 if (lenParamList.empty() && typeList.empty())
635 return result;
636
637 result.finalize(lenParamList, typeList);
638 return verifyDerived(parser, result, lenParamList, typeList);
639 }
640
print(mlir::AsmPrinter & printer) const641 void fir::RecordType::print(mlir::AsmPrinter &printer) const {
642 printer << "<" << getName();
643 if (!recordTypeVisited.count(uniqueKey())) {
644 recordTypeVisited.insert(uniqueKey());
645 if (getLenParamList().size()) {
646 char ch = '(';
647 for (auto p : getLenParamList()) {
648 printer << ch << p.first << ':';
649 p.second.print(printer.getStream());
650 ch = ',';
651 }
652 printer << ')';
653 }
654 if (getTypeList().size()) {
655 char ch = '{';
656 for (auto p : getTypeList()) {
657 printer << ch << p.first << ':';
658 p.second.print(printer.getStream());
659 ch = ',';
660 }
661 printer << '}';
662 }
663 recordTypeVisited.erase(uniqueKey());
664 }
665 printer << '>';
666 }
667
finalize(llvm::ArrayRef<TypePair> lenPList,llvm::ArrayRef<TypePair> typeList)668 void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
669 llvm::ArrayRef<TypePair> typeList) {
670 getImpl()->finalize(lenPList, typeList);
671 }
672
getName() const673 llvm::StringRef fir::RecordType::getName() const {
674 return getImpl()->getName();
675 }
676
getTypeList() const677 RecordType::TypeList fir::RecordType::getTypeList() const {
678 return getImpl()->getTypeList();
679 }
680
getLenParamList() const681 RecordType::TypeList fir::RecordType::getLenParamList() const {
682 return getImpl()->getLenParamList();
683 }
684
uniqueKey() const685 detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
686 return getImpl();
687 }
688
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::StringRef name)689 mlir::LogicalResult fir::RecordType::verify(
690 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
691 llvm::StringRef name) {
692 if (name.size() == 0)
693 return emitError() << "record types must have a name";
694 return mlir::success();
695 }
696
getType(llvm::StringRef ident)697 mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
698 for (auto f : getTypeList())
699 if (ident == f.first)
700 return f.second;
701 return {};
702 }
703
getFieldIndex(llvm::StringRef ident)704 unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
705 for (auto f : llvm::enumerate(getTypeList()))
706 if (ident == f.value().first)
707 return f.index();
708 return std::numeric_limits<unsigned>::max();
709 }
710
711 //===----------------------------------------------------------------------===//
712 // ReferenceType
713 //===----------------------------------------------------------------------===//
714
715 // `ref` `<` type `>`
parse(mlir::AsmParser & parser)716 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
717 return parseTypeSingleton<fir::ReferenceType>(parser);
718 }
719
print(mlir::AsmPrinter & printer) const720 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
721 printer << "<" << getEleTy() << '>';
722 }
723
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)724 mlir::LogicalResult fir::ReferenceType::verify(
725 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
726 mlir::Type eleTy) {
727 if (eleTy.isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
728 ReferenceType, TypeDescType>())
729 return emitError() << "cannot build a reference to type: " << eleTy << '\n';
730 return mlir::success();
731 }
732
733 //===----------------------------------------------------------------------===//
734 // SequenceType
735 //===----------------------------------------------------------------------===//
736
737 // `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
738 // bounds ::= `?` | int-lit
parse(mlir::AsmParser & parser)739 mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
740 if (parser.parseLess())
741 return {};
742 SequenceType::Shape shape;
743 if (parser.parseOptionalStar()) {
744 if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
745 return {};
746 } else if (parser.parseColon()) {
747 return {};
748 }
749 mlir::Type eleTy;
750 if (parser.parseType(eleTy))
751 return {};
752 mlir::AffineMapAttr map;
753 if (!parser.parseOptionalComma()) {
754 if (parser.parseAttribute(map)) {
755 parser.emitError(parser.getNameLoc(), "expecting affine map");
756 return {};
757 }
758 }
759 if (parser.parseGreater())
760 return {};
761 return SequenceType::get(parser.getContext(), shape, eleTy, map);
762 }
763
print(mlir::AsmPrinter & printer) const764 void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
765 auto shape = getShape();
766 if (shape.size()) {
767 printer << '<';
768 for (const auto &b : shape) {
769 if (b >= 0)
770 printer << b << 'x';
771 else
772 printer << "?x";
773 }
774 } else {
775 printer << "<*:";
776 }
777 printer << getEleTy();
778 if (auto map = getLayoutMap()) {
779 printer << ", ";
780 map.print(printer.getStream());
781 }
782 printer << '>';
783 }
784
getConstantRows() const785 unsigned fir::SequenceType::getConstantRows() const {
786 auto shape = getShape();
787 unsigned count = 0;
788 for (auto d : shape) {
789 if (d < 0)
790 break;
791 ++count;
792 }
793 return count;
794 }
795
796 // This test helps us determine if we can degenerate an array to a
797 // pointer to some interior section (possibly a single element) of the
798 // sequence. This is used to determine if we can lower to the LLVM IR.
hasConstantInterior() const799 bool fir::SequenceType::hasConstantInterior() const {
800 if (hasUnknownShape())
801 return true;
802 auto rows = getConstantRows();
803 auto dim = getDimension();
804 if (rows == dim)
805 return true;
806 auto shape = getShape();
807 for (unsigned i = rows, size = dim; i < size; ++i)
808 if (shape[i] != getUnknownExtent())
809 return false;
810 return true;
811 }
812
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::ArrayRef<int64_t> shape,mlir::Type eleTy,mlir::AffineMapAttr layoutMap)813 mlir::LogicalResult fir::SequenceType::verify(
814 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
815 llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
816 mlir::AffineMapAttr layoutMap) {
817 // DIMENSION attribute can only be applied to an intrinsic or record type
818 if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
819 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
820 ReferenceType, TypeDescType, fir::VectorType, SequenceType>())
821 return emitError() << "cannot build an array of this element type: "
822 << eleTy << '\n';
823 return mlir::success();
824 }
825
826 //===----------------------------------------------------------------------===//
827 // ShapeType
828 //===----------------------------------------------------------------------===//
829
parse(mlir::AsmParser & parser)830 mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
831 return parseRankSingleton<fir::ShapeType>(parser);
832 }
833
print(mlir::AsmPrinter & printer) const834 void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
835 printer << "<" << getImpl()->rank << ">";
836 }
837
838 //===----------------------------------------------------------------------===//
839 // ShapeShiftType
840 //===----------------------------------------------------------------------===//
841
parse(mlir::AsmParser & parser)842 mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
843 return parseRankSingleton<fir::ShapeShiftType>(parser);
844 }
845
print(mlir::AsmPrinter & printer) const846 void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
847 printer << "<" << getRank() << ">";
848 }
849
850 //===----------------------------------------------------------------------===//
851 // ShiftType
852 //===----------------------------------------------------------------------===//
853
parse(mlir::AsmParser & parser)854 mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
855 return parseRankSingleton<fir::ShiftType>(parser);
856 }
857
print(mlir::AsmPrinter & printer) const858 void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
859 printer << "<" << getRank() << ">";
860 }
861
862 //===----------------------------------------------------------------------===//
863 // SliceType
864 //===----------------------------------------------------------------------===//
865
866 // `slice` `<` rank `>`
parse(mlir::AsmParser & parser)867 mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
868 return parseRankSingleton<fir::SliceType>(parser);
869 }
870
print(mlir::AsmPrinter & printer) const871 void fir::SliceType::print(mlir::AsmPrinter &printer) const {
872 printer << "<" << getRank() << '>';
873 }
874
875 //===----------------------------------------------------------------------===//
876 // TypeDescType
877 //===----------------------------------------------------------------------===//
878
879 // `tdesc` `<` type `>`
parse(mlir::AsmParser & parser)880 mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
881 return parseTypeSingleton<fir::TypeDescType>(parser);
882 }
883
print(mlir::AsmPrinter & printer) const884 void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
885 printer << "<" << getOfTy() << '>';
886 }
887
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)888 mlir::LogicalResult fir::TypeDescType::verify(
889 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
890 mlir::Type eleTy) {
891 if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
892 ShiftType, SliceType, FieldType, LenType, ReferenceType,
893 TypeDescType>())
894 return emitError() << "cannot build a type descriptor of type: " << eleTy
895 << '\n';
896 return mlir::success();
897 }
898
899 //===----------------------------------------------------------------------===//
900 // VectorType
901 //===----------------------------------------------------------------------===//
902
903 // `vector` `<` len `:` type `>`
parse(mlir::AsmParser & parser)904 mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
905 int64_t len = 0;
906 mlir::Type eleTy;
907 if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
908 parser.parseType(eleTy) || parser.parseGreater())
909 return {};
910 return fir::VectorType::get(len, eleTy);
911 }
912
print(mlir::AsmPrinter & printer) const913 void fir::VectorType::print(mlir::AsmPrinter &printer) const {
914 printer << "<" << getLen() << ':' << getEleTy() << '>';
915 }
916
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,uint64_t len,mlir::Type eleTy)917 mlir::LogicalResult fir::VectorType::verify(
918 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
919 mlir::Type eleTy) {
920 if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
921 return emitError() << "cannot build a vector of type " << eleTy << '\n';
922 return mlir::success();
923 }
924
isValidElementType(mlir::Type t)925 bool fir::VectorType::isValidElementType(mlir::Type t) {
926 return isa_real(t) || isa_integer(t);
927 }
928
isCharacterProcedureTuple(mlir::Type ty,bool acceptRawFunc)929 bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
930 mlir::TupleType tuple = ty.dyn_cast<mlir::TupleType>();
931 return tuple && tuple.size() == 2 &&
932 (tuple.getType(0).isa<fir::BoxProcType>() ||
933 (acceptRawFunc && tuple.getType(0).isa<mlir::FunctionType>())) &&
934 fir::isa_integer(tuple.getType(1));
935 }
936
hasAbstractResult(mlir::FunctionType ty)937 bool fir::hasAbstractResult(mlir::FunctionType ty) {
938 if (ty.getNumResults() == 0)
939 return false;
940 auto resultType = ty.getResult(0);
941 return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
942 }
943
944 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
945 /// messages only.
fromRealTypeID(mlir::MLIRContext * context,llvm::Type::TypeID typeID,fir::KindTy kind)946 mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
947 llvm::Type::TypeID typeID, fir::KindTy kind) {
948 switch (typeID) {
949 case llvm::Type::TypeID::HalfTyID:
950 return mlir::FloatType::getF16(context);
951 case llvm::Type::TypeID::BFloatTyID:
952 return mlir::FloatType::getBF16(context);
953 case llvm::Type::TypeID::FloatTyID:
954 return mlir::FloatType::getF32(context);
955 case llvm::Type::TypeID::DoubleTyID:
956 return mlir::FloatType::getF64(context);
957 case llvm::Type::TypeID::X86_FP80TyID:
958 return mlir::FloatType::getF80(context);
959 case llvm::Type::TypeID::FP128TyID:
960 return mlir::FloatType::getF128(context);
961 default:
962 mlir::emitError(mlir::UnknownLoc::get(context))
963 << "unsupported type: !fir.real<" << kind << ">";
964 return {};
965 }
966 }
967
968 //===----------------------------------------------------------------------===//
969 // FIROpsDialect
970 //===----------------------------------------------------------------------===//
971
registerTypes()972 void FIROpsDialect::registerTypes() {
973 addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
974 FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
975 LLVMPointerType, PointerType, RealType, RecordType, ReferenceType,
976 SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
977 TypeDescType, fir::VectorType>();
978 fir::ReferenceType::attachInterface<PointerLikeModel<fir::ReferenceType>>(
979 *getContext());
980
981 fir::PointerType::attachInterface<PointerLikeModel<fir::PointerType>>(
982 *getContext());
983
984 fir::HeapType::attachInterface<AlternativePointerLikeModel<fir::HeapType>>(
985 *getContext());
986
987 fir::LLVMPointerType::attachInterface<
988 AlternativePointerLikeModel<fir::LLVMPointerType>>(*getContext());
989 }
990