1 //===-- FIRType.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Support/KindMapping.h"
16 #include "flang/Tools/PointerModels.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/ErrorHandling.h"
25 
26 #define GET_TYPEDEF_CLASSES
27 #include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
28 
29 using namespace fir;
30 
31 namespace {
32 
33 template <typename TYPE>
parseIntSingleton(mlir::AsmParser & parser)34 TYPE parseIntSingleton(mlir::AsmParser &parser) {
35   int kind = 0;
36   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
37     return {};
38   return TYPE::get(parser.getContext(), kind);
39 }
40 
41 template <typename TYPE>
parseKindSingleton(mlir::AsmParser & parser)42 TYPE parseKindSingleton(mlir::AsmParser &parser) {
43   return parseIntSingleton<TYPE>(parser);
44 }
45 
46 template <typename TYPE>
parseRankSingleton(mlir::AsmParser & parser)47 TYPE parseRankSingleton(mlir::AsmParser &parser) {
48   return parseIntSingleton<TYPE>(parser);
49 }
50 
51 template <typename TYPE>
parseTypeSingleton(mlir::AsmParser & parser)52 TYPE parseTypeSingleton(mlir::AsmParser &parser) {
53   mlir::Type ty;
54   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
55     return {};
56   return TYPE::get(ty);
57 }
58 
59 /// Is `ty` a standard or FIR integer type?
isaIntegerType(mlir::Type ty)60 static bool isaIntegerType(mlir::Type ty) {
61   // TODO: why aren't we using isa_integer? investigatation required.
62   return ty.isa<mlir::IntegerType>() || ty.isa<fir::IntegerType>();
63 }
64 
verifyRecordMemberType(mlir::Type ty)65 bool verifyRecordMemberType(mlir::Type ty) {
66   return !(ty.isa<BoxCharType>() || ty.isa<BoxProcType>() ||
67            ty.isa<ShapeType>() || ty.isa<ShapeShiftType>() ||
68            ty.isa<ShiftType>() || ty.isa<SliceType>() || ty.isa<FieldType>() ||
69            ty.isa<LenType>() || ty.isa<ReferenceType>() ||
70            ty.isa<TypeDescType>());
71 }
72 
verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,llvm::ArrayRef<RecordType::TypePair> a2)73 bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
74                      llvm::ArrayRef<RecordType::TypePair> a2) {
75   // FIXME: do we need to allow for any variance here?
76   return a1 == a2;
77 }
78 
verifyDerived(mlir::AsmParser & parser,RecordType derivedTy,llvm::ArrayRef<RecordType::TypePair> lenPList,llvm::ArrayRef<RecordType::TypePair> typeList)79 RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
80                          llvm::ArrayRef<RecordType::TypePair> lenPList,
81                          llvm::ArrayRef<RecordType::TypePair> typeList) {
82   auto loc = parser.getNameLoc();
83   if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
84       !verifySameLists(derivedTy.getTypeList(), typeList)) {
85     parser.emitError(loc, "cannot redefine record type members");
86     return {};
87   }
88   for (auto &p : lenPList)
89     if (!isaIntegerType(p.second)) {
90       parser.emitError(loc, "LEN parameter must be integral type");
91       return {};
92     }
93   for (auto &p : typeList)
94     if (!verifyRecordMemberType(p.second)) {
95       parser.emitError(loc, "field parameter has invalid type");
96       return {};
97     }
98   llvm::StringSet<> uniq;
99   for (auto &p : lenPList)
100     if (!uniq.insert(p.first).second) {
101       parser.emitError(loc, "LEN parameter cannot have duplicate name");
102       return {};
103     }
104   for (auto &p : typeList)
105     if (!uniq.insert(p.first).second) {
106       parser.emitError(loc, "field cannot have duplicate name");
107       return {};
108     }
109   return derivedTy;
110 }
111 
112 } // namespace
113 
114 // Implementation of the thin interface from dialect to type parser
115 
parseFirType(FIROpsDialect * dialect,mlir::DialectAsmParser & parser)116 mlir::Type fir::parseFirType(FIROpsDialect *dialect,
117                              mlir::DialectAsmParser &parser) {
118   mlir::StringRef typeTag;
119   mlir::Type genType;
120   auto parseResult = generatedTypeParser(parser, &typeTag, genType);
121   if (parseResult.hasValue())
122     return genType;
123   parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
124   return {};
125 }
126 
127 namespace fir {
128 namespace detail {
129 
130 // Type storage classes
131 
132 /// Derived type storage
133 struct RecordTypeStorage : public mlir::TypeStorage {
134   using KeyTy = llvm::StringRef;
135 
hashKeyfir::detail::RecordTypeStorage136   static unsigned hashKey(const KeyTy &key) {
137     return llvm::hash_combine(key.str());
138   }
139 
operator ==fir::detail::RecordTypeStorage140   bool operator==(const KeyTy &key) const { return key == getName(); }
141 
constructfir::detail::RecordTypeStorage142   static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
143                                       const KeyTy &key) {
144     auto *storage = allocator.allocate<RecordTypeStorage>();
145     return new (storage) RecordTypeStorage{key};
146   }
147 
getNamefir::detail::RecordTypeStorage148   llvm::StringRef getName() const { return name; }
149 
setLenParamListfir::detail::RecordTypeStorage150   void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
151     lens = list;
152   }
getLenParamListfir::detail::RecordTypeStorage153   llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
154 
setTypeListfir::detail::RecordTypeStorage155   void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
getTypeListfir::detail::RecordTypeStorage156   llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
157 
finalizefir::detail::RecordTypeStorage158   void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
159                 llvm::ArrayRef<RecordType::TypePair> typeList) {
160     if (finalized)
161       return;
162     finalized = true;
163     setLenParamList(lenParamList);
164     setTypeList(typeList);
165   }
166 
167 protected:
168   std::string name;
169   bool finalized;
170   std::vector<RecordType::TypePair> lens;
171   std::vector<RecordType::TypePair> types;
172 
173 private:
174   RecordTypeStorage() = delete;
RecordTypeStoragefir::detail::RecordTypeStorage175   explicit RecordTypeStorage(llvm::StringRef name)
176       : name{name}, finalized{false} {}
177 };
178 
179 } // namespace detail
180 
/// Half-open range test: is `v` within [lb, ub)?
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  return (v >= lb) ? (v < ub) : false;
}
185 
isa_fir_type(mlir::Type t)186 bool isa_fir_type(mlir::Type t) {
187   return llvm::isa<FIROpsDialect>(t.getDialect());
188 }
189 
isa_std_type(mlir::Type t)190 bool isa_std_type(mlir::Type t) {
191   return llvm::isa<mlir::BuiltinDialect>(t.getDialect());
192 }
193 
isa_fir_or_std_type(mlir::Type t)194 bool isa_fir_or_std_type(mlir::Type t) {
195   if (auto funcType = t.dyn_cast<mlir::FunctionType>())
196     return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) &&
197            llvm::all_of(funcType.getResults(), isa_fir_or_std_type);
198   return isa_fir_type(t) || isa_std_type(t);
199 }
200 
dyn_cast_ptrEleTy(mlir::Type t)201 mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
202   return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
203       .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
204             fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
205       .Default([](mlir::Type) { return mlir::Type{}; });
206 }
207 
dyn_cast_ptrOrBoxEleTy(mlir::Type t)208 mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
209   return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
210       .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
211             fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
212       .Case<fir::BoxType>([](auto p) {
213         auto eleTy = p.getEleTy();
214         if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
215           return ty;
216         return eleTy;
217       })
218       .Default([](mlir::Type) { return mlir::Type{}; });
219 }
220 
hasDynamicSize(fir::RecordType recTy)221 static bool hasDynamicSize(fir::RecordType recTy) {
222   for (auto field : recTy.getTypeList()) {
223     if (auto arr = field.second.dyn_cast<fir::SequenceType>()) {
224       if (sequenceWithNonConstantShape(arr))
225         return true;
226     } else if (characterWithDynamicLen(field.second)) {
227       return true;
228     } else if (auto rec = field.second.dyn_cast<fir::RecordType>()) {
229       if (hasDynamicSize(rec))
230         return true;
231     }
232   }
233   return false;
234 }
235 
hasDynamicSize(mlir::Type t)236 bool hasDynamicSize(mlir::Type t) {
237   if (auto arr = t.dyn_cast<fir::SequenceType>()) {
238     if (sequenceWithNonConstantShape(arr))
239       return true;
240     t = arr.getEleTy();
241   }
242   if (characterWithDynamicLen(t))
243     return true;
244   if (auto rec = t.dyn_cast<fir::RecordType>())
245     return hasDynamicSize(rec);
246   return false;
247 }
248 
isPointerType(mlir::Type ty)249 bool isPointerType(mlir::Type ty) {
250   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
251     ty = refTy;
252   if (auto boxTy = ty.dyn_cast<fir::BoxType>())
253     return boxTy.getEleTy().isa<fir::PointerType>();
254   return false;
255 }
256 
isAllocatableType(mlir::Type ty)257 bool isAllocatableType(mlir::Type ty) {
258   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
259     ty = refTy;
260   if (auto boxTy = ty.dyn_cast<fir::BoxType>())
261     return boxTy.getEleTy().isa<fir::HeapType>();
262   return false;
263 }
264 
isUnlimitedPolymorphicType(mlir::Type ty)265 bool isUnlimitedPolymorphicType(mlir::Type ty) {
266   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
267     ty = refTy;
268   if (auto boxTy = ty.dyn_cast<fir::BoxType>())
269     return boxTy.getEleTy().isa<mlir::NoneType>();
270   return false;
271 }
272 
isRecordWithAllocatableMember(mlir::Type ty)273 bool isRecordWithAllocatableMember(mlir::Type ty) {
274   if (auto recTy = ty.dyn_cast<fir::RecordType>())
275     for (auto [field, memTy] : recTy.getTypeList()) {
276       if (fir::isAllocatableType(memTy))
277         return true;
278       // A record type cannot recursively include itself as a direct member.
279       // There must be an intervening `ptr` type, so recursion is safe here.
280       if (memTy.isa<fir::RecordType>() && isRecordWithAllocatableMember(memTy))
281         return true;
282     }
283   return false;
284 }
285 
unwrapAllRefAndSeqType(mlir::Type ty)286 mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
287   while (true) {
288     mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
289     if (auto vecTy = nt.dyn_cast<fir::VectorType>())
290       nt = vecTy.getEleTy();
291     if (nt == ty)
292       return ty;
293     ty = nt;
294   }
295 }
296 
unwrapSeqOrBoxedSeqType(mlir::Type ty)297 mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
298   if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
299     return seqTy.getEleTy();
300   if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
301     auto eleTy = unwrapRefType(boxTy.getEleTy());
302     if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
303       return seqTy.getEleTy();
304   }
305   return ty;
306 }
307 
308 } // namespace fir
309 
310 namespace {
311 
312 static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
313     recordTypeVisited;
314 
315 } // namespace
316 
verifyIntegralType(mlir::Type type)317 void fir::verifyIntegralType(mlir::Type type) {
318   if (isaIntegerType(type) || type.isa<mlir::IndexType>())
319     return;
320   llvm::report_fatal_error("expected integral type");
321 }
322 
printFirType(FIROpsDialect *,mlir::Type ty,mlir::DialectAsmPrinter & p)323 void fir::printFirType(FIROpsDialect *, mlir::Type ty,
324                        mlir::DialectAsmPrinter &p) {
325   if (mlir::failed(generatedTypePrinter(ty, p)))
326     llvm::report_fatal_error("unknown type to print");
327 }
328 
isa_unknown_size_box(mlir::Type t)329 bool fir::isa_unknown_size_box(mlir::Type t) {
330   if (auto boxTy = t.dyn_cast<fir::BoxType>()) {
331     auto eleTy = boxTy.getEleTy();
332     if (auto actualEleTy = fir::dyn_cast_ptrEleTy(eleTy))
333       eleTy = actualEleTy;
334     if (eleTy.isa<mlir::NoneType>())
335       return true;
336     if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
337       if (seqTy.hasUnknownShape())
338         return true;
339   }
340   return false;
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // BoxProcType
345 //===----------------------------------------------------------------------===//
346 
347 // `boxproc` `<` return-type `>`
parse(mlir::AsmParser & parser)348 mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
349   mlir::Type ty;
350   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
351     return {};
352   return get(parser.getContext(), ty);
353 }
354 
print(mlir::AsmPrinter & printer) const355 void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
356   printer << "<" << getEleTy() << '>';
357 }
358 
359 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)360 BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
361                     mlir::Type eleTy) {
362   if (eleTy.isa<mlir::FunctionType>())
363     return mlir::success();
364   if (auto refTy = eleTy.dyn_cast<ReferenceType>())
365     if (refTy.isa<mlir::FunctionType>())
366       return mlir::success();
367   return emitError() << "invalid type for boxproc" << eleTy << '\n';
368 }
369 
cannotBePointerOrHeapElementType(mlir::Type eleTy)370 static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
371   return eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
372                    SliceType, FieldType, LenType, HeapType, PointerType,
373                    ReferenceType, TypeDescType>();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // BoxType
378 //===----------------------------------------------------------------------===//
379 
380 // `box` `<` type (',' affine-map)? `>`
parse(mlir::AsmParser & parser)381 mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) {
382   mlir::Type ofTy;
383   if (parser.parseLess() || parser.parseType(ofTy))
384     return {};
385 
386   mlir::AffineMapAttr map;
387   if (!parser.parseOptionalComma()) {
388     if (parser.parseAttribute(map)) {
389       parser.emitError(parser.getCurrentLocation(), "expected affine map");
390       return {};
391     }
392   }
393   if (parser.parseGreater())
394     return {};
395   return get(ofTy, map);
396 }
397 
print(mlir::AsmPrinter & printer) const398 void fir::BoxType::print(mlir::AsmPrinter &printer) const {
399   printer << "<" << getEleTy();
400   if (auto map = getLayoutMap()) {
401     printer << ", " << map;
402   }
403   printer << '>';
404 }
405 
406 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy,mlir::AffineMapAttr map)407 fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
408                      mlir::Type eleTy, mlir::AffineMapAttr map) {
409   // TODO
410   return mlir::success();
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // BoxCharType
415 //===----------------------------------------------------------------------===//
416 
parse(mlir::AsmParser & parser)417 mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
418   return parseKindSingleton<fir::BoxCharType>(parser);
419 }
420 
print(mlir::AsmPrinter & printer) const421 void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
422   printer << "<" << getKind() << ">";
423 }
424 
425 CharacterType
getElementType(mlir::MLIRContext * context) const426 fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
427   return CharacterType::getUnknownLen(context, getKind());
428 }
429 
getEleTy() const430 CharacterType fir::BoxCharType::getEleTy() const {
431   return getElementType(getContext());
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // CharacterType
436 //===----------------------------------------------------------------------===//
437 
438 // `char` `<` kind [`,` `len`] `>`
parse(mlir::AsmParser & parser)439 mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
440   int kind = 0;
441   if (parser.parseLess() || parser.parseInteger(kind))
442     return {};
443   CharacterType::LenType len = 1;
444   if (mlir::succeeded(parser.parseOptionalComma())) {
445     if (mlir::succeeded(parser.parseOptionalQuestion())) {
446       len = fir::CharacterType::unknownLen();
447     } else if (!mlir::succeeded(parser.parseInteger(len))) {
448       return {};
449     }
450   }
451   if (parser.parseGreater())
452     return {};
453   return get(parser.getContext(), kind, len);
454 }
455 
print(mlir::AsmPrinter & printer) const456 void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
457   printer << "<" << getFKind();
458   auto len = getLen();
459   if (len != fir::CharacterType::singleton()) {
460     printer << ',';
461     if (len == fir::CharacterType::unknownLen())
462       printer << '?';
463     else
464       printer << len;
465   }
466   printer << '>';
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // ComplexType
471 //===----------------------------------------------------------------------===//
472 
parse(mlir::AsmParser & parser)473 mlir::Type fir::ComplexType::parse(mlir::AsmParser &parser) {
474   return parseKindSingleton<fir::ComplexType>(parser);
475 }
476 
print(mlir::AsmPrinter & printer) const477 void fir::ComplexType::print(mlir::AsmPrinter &printer) const {
478   printer << "<" << getFKind() << '>';
479 }
480 
getElementType() const481 mlir::Type fir::ComplexType::getElementType() const {
482   return fir::RealType::get(getContext(), getFKind());
483 }
484 
485 // Return the MLIR float type of the complex element type.
getEleType(const fir::KindMapping & kindMap) const486 mlir::Type fir::ComplexType::getEleType(const fir::KindMapping &kindMap) const {
487   auto fkind = getFKind();
488   auto realTypeID = kindMap.getRealTypeID(fkind);
489   return fir::fromRealTypeID(getContext(), realTypeID, fkind);
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // HeapType
494 //===----------------------------------------------------------------------===//
495 
496 // `heap` `<` type `>`
parse(mlir::AsmParser & parser)497 mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
498   return parseTypeSingleton<HeapType>(parser);
499 }
500 
print(mlir::AsmPrinter & printer) const501 void fir::HeapType::print(mlir::AsmPrinter &printer) const {
502   printer << "<" << getEleTy() << '>';
503 }
504 
505 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)506 fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
507                       mlir::Type eleTy) {
508   if (cannotBePointerOrHeapElementType(eleTy))
509     return emitError() << "cannot build a heap pointer to type: " << eleTy
510                        << '\n';
511   return mlir::success();
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // IntegerType
516 //===----------------------------------------------------------------------===//
517 
518 // `int` `<` kind `>`
parse(mlir::AsmParser & parser)519 mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
520   return parseKindSingleton<fir::IntegerType>(parser);
521 }
522 
print(mlir::AsmPrinter & printer) const523 void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
524   printer << "<" << getFKind() << '>';
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // LogicalType
529 //===----------------------------------------------------------------------===//
530 
531 // `logical` `<` kind `>`
parse(mlir::AsmParser & parser)532 mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
533   return parseKindSingleton<fir::LogicalType>(parser);
534 }
535 
print(mlir::AsmPrinter & printer) const536 void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
537   printer << "<" << getFKind() << '>';
538 }
539 
540 //===----------------------------------------------------------------------===//
541 // PointerType
542 //===----------------------------------------------------------------------===//
543 
544 // `ptr` `<` type `>`
parse(mlir::AsmParser & parser)545 mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
546   return parseTypeSingleton<fir::PointerType>(parser);
547 }
548 
print(mlir::AsmPrinter & printer) const549 void fir::PointerType::print(mlir::AsmPrinter &printer) const {
550   printer << "<" << getEleTy() << '>';
551 }
552 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)553 mlir::LogicalResult fir::PointerType::verify(
554     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
555     mlir::Type eleTy) {
556   if (cannotBePointerOrHeapElementType(eleTy))
557     return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
558   return mlir::success();
559 }
560 
561 //===----------------------------------------------------------------------===//
562 // RealType
563 //===----------------------------------------------------------------------===//
564 
565 // `real` `<` kind `>`
parse(mlir::AsmParser & parser)566 mlir::Type fir::RealType::parse(mlir::AsmParser &parser) {
567   return parseKindSingleton<fir::RealType>(parser);
568 }
569 
print(mlir::AsmPrinter & printer) const570 void fir::RealType::print(mlir::AsmPrinter &printer) const {
571   printer << "<" << getFKind() << '>';
572 }
573 
574 mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,KindTy fKind)575 fir::RealType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
576                       KindTy fKind) {
577   // TODO
578   return mlir::success();
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // RecordType
583 //===----------------------------------------------------------------------===//
584 
585 // Fortran derived type
586 // `type` `<` name
587 //           (`(` id `:` type (`,` id `:` type)* `)`)?
588 //           (`{` id `:` type (`,` id `:` type)* `}`)? '>'
parse(mlir::AsmParser & parser)589 mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
590   llvm::StringRef name;
591   if (parser.parseLess() || parser.parseKeyword(&name))
592     return {};
593   RecordType result = RecordType::get(parser.getContext(), name);
594 
595   RecordType::TypeList lenParamList;
596   if (!parser.parseOptionalLParen()) {
597     while (true) {
598       llvm::StringRef lenparam;
599       mlir::Type intTy;
600       if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
601           parser.parseType(intTy)) {
602         parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
603         return {};
604       }
605       lenParamList.emplace_back(lenparam, intTy);
606       if (parser.parseOptionalComma())
607         break;
608     }
609     if (parser.parseRParen())
610       return {};
611   }
612 
613   RecordType::TypeList typeList;
614   if (!parser.parseOptionalLBrace()) {
615     while (true) {
616       llvm::StringRef field;
617       mlir::Type fldTy;
618       if (parser.parseKeyword(&field) || parser.parseColon() ||
619           parser.parseType(fldTy)) {
620         parser.emitError(parser.getNameLoc(), "expected field type list");
621         return {};
622       }
623       typeList.emplace_back(field, fldTy);
624       if (parser.parseOptionalComma())
625         break;
626     }
627     if (parser.parseRBrace())
628       return {};
629   }
630 
631   if (parser.parseGreater())
632     return {};
633 
634   if (lenParamList.empty() && typeList.empty())
635     return result;
636 
637   result.finalize(lenParamList, typeList);
638   return verifyDerived(parser, result, lenParamList, typeList);
639 }
640 
print(mlir::AsmPrinter & printer) const641 void fir::RecordType::print(mlir::AsmPrinter &printer) const {
642   printer << "<" << getName();
643   if (!recordTypeVisited.count(uniqueKey())) {
644     recordTypeVisited.insert(uniqueKey());
645     if (getLenParamList().size()) {
646       char ch = '(';
647       for (auto p : getLenParamList()) {
648         printer << ch << p.first << ':';
649         p.second.print(printer.getStream());
650         ch = ',';
651       }
652       printer << ')';
653     }
654     if (getTypeList().size()) {
655       char ch = '{';
656       for (auto p : getTypeList()) {
657         printer << ch << p.first << ':';
658         p.second.print(printer.getStream());
659         ch = ',';
660       }
661       printer << '}';
662     }
663     recordTypeVisited.erase(uniqueKey());
664   }
665   printer << '>';
666 }
667 
finalize(llvm::ArrayRef<TypePair> lenPList,llvm::ArrayRef<TypePair> typeList)668 void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
669                                llvm::ArrayRef<TypePair> typeList) {
670   getImpl()->finalize(lenPList, typeList);
671 }
672 
getName() const673 llvm::StringRef fir::RecordType::getName() const {
674   return getImpl()->getName();
675 }
676 
getTypeList() const677 RecordType::TypeList fir::RecordType::getTypeList() const {
678   return getImpl()->getTypeList();
679 }
680 
getLenParamList() const681 RecordType::TypeList fir::RecordType::getLenParamList() const {
682   return getImpl()->getLenParamList();
683 }
684 
uniqueKey() const685 detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
686   return getImpl();
687 }
688 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::StringRef name)689 mlir::LogicalResult fir::RecordType::verify(
690     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
691     llvm::StringRef name) {
692   if (name.size() == 0)
693     return emitError() << "record types must have a name";
694   return mlir::success();
695 }
696 
getType(llvm::StringRef ident)697 mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
698   for (auto f : getTypeList())
699     if (ident == f.first)
700       return f.second;
701   return {};
702 }
703 
getFieldIndex(llvm::StringRef ident)704 unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
705   for (auto f : llvm::enumerate(getTypeList()))
706     if (ident == f.value().first)
707       return f.index();
708   return std::numeric_limits<unsigned>::max();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // ReferenceType
713 //===----------------------------------------------------------------------===//
714 
715 // `ref` `<` type `>`
parse(mlir::AsmParser & parser)716 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
717   return parseTypeSingleton<fir::ReferenceType>(parser);
718 }
719 
print(mlir::AsmPrinter & printer) const720 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
721   printer << "<" << getEleTy() << '>';
722 }
723 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)724 mlir::LogicalResult fir::ReferenceType::verify(
725     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
726     mlir::Type eleTy) {
727   if (eleTy.isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
728                 ReferenceType, TypeDescType>())
729     return emitError() << "cannot build a reference to type: " << eleTy << '\n';
730   return mlir::success();
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // SequenceType
735 //===----------------------------------------------------------------------===//
736 
737 // `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
738 // bounds ::= `?` | int-lit
parse(mlir::AsmParser & parser)739 mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
740   if (parser.parseLess())
741     return {};
742   SequenceType::Shape shape;
743   if (parser.parseOptionalStar()) {
744     if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
745       return {};
746   } else if (parser.parseColon()) {
747     return {};
748   }
749   mlir::Type eleTy;
750   if (parser.parseType(eleTy))
751     return {};
752   mlir::AffineMapAttr map;
753   if (!parser.parseOptionalComma()) {
754     if (parser.parseAttribute(map)) {
755       parser.emitError(parser.getNameLoc(), "expecting affine map");
756       return {};
757     }
758   }
759   if (parser.parseGreater())
760     return {};
761   return SequenceType::get(parser.getContext(), shape, eleTy, map);
762 }
763 
print(mlir::AsmPrinter & printer) const764 void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
765   auto shape = getShape();
766   if (shape.size()) {
767     printer << '<';
768     for (const auto &b : shape) {
769       if (b >= 0)
770         printer << b << 'x';
771       else
772         printer << "?x";
773     }
774   } else {
775     printer << "<*:";
776   }
777   printer << getEleTy();
778   if (auto map = getLayoutMap()) {
779     printer << ", ";
780     map.print(printer.getStream());
781   }
782   printer << '>';
783 }
784 
getConstantRows() const785 unsigned fir::SequenceType::getConstantRows() const {
786   auto shape = getShape();
787   unsigned count = 0;
788   for (auto d : shape) {
789     if (d < 0)
790       break;
791     ++count;
792   }
793   return count;
794 }
795 
796 // This test helps us determine if we can degenerate an array to a
797 // pointer to some interior section (possibly a single element) of the
798 // sequence. This is used to determine if we can lower to the LLVM IR.
hasConstantInterior() const799 bool fir::SequenceType::hasConstantInterior() const {
800   if (hasUnknownShape())
801     return true;
802   auto rows = getConstantRows();
803   auto dim = getDimension();
804   if (rows == dim)
805     return true;
806   auto shape = getShape();
807   for (unsigned i = rows, size = dim; i < size; ++i)
808     if (shape[i] != getUnknownExtent())
809       return false;
810   return true;
811 }
812 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,llvm::ArrayRef<int64_t> shape,mlir::Type eleTy,mlir::AffineMapAttr layoutMap)813 mlir::LogicalResult fir::SequenceType::verify(
814     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
815     llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
816     mlir::AffineMapAttr layoutMap) {
817   // DIMENSION attribute can only be applied to an intrinsic or record type
818   if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
819                 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
820                 ReferenceType, TypeDescType, fir::VectorType, SequenceType>())
821     return emitError() << "cannot build an array of this element type: "
822                        << eleTy << '\n';
823   return mlir::success();
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // ShapeType
828 //===----------------------------------------------------------------------===//
829 
parse(mlir::AsmParser & parser)830 mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
831   return parseRankSingleton<fir::ShapeType>(parser);
832 }
833 
print(mlir::AsmPrinter & printer) const834 void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
835   printer << "<" << getImpl()->rank << ">";
836 }
837 
838 //===----------------------------------------------------------------------===//
839 // ShapeShiftType
840 //===----------------------------------------------------------------------===//
841 
parse(mlir::AsmParser & parser)842 mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
843   return parseRankSingleton<fir::ShapeShiftType>(parser);
844 }
845 
print(mlir::AsmPrinter & printer) const846 void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
847   printer << "<" << getRank() << ">";
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // ShiftType
852 //===----------------------------------------------------------------------===//
853 
parse(mlir::AsmParser & parser)854 mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
855   return parseRankSingleton<fir::ShiftType>(parser);
856 }
857 
print(mlir::AsmPrinter & printer) const858 void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
859   printer << "<" << getRank() << ">";
860 }
861 
862 //===----------------------------------------------------------------------===//
863 // SliceType
864 //===----------------------------------------------------------------------===//
865 
866 // `slice` `<` rank `>`
parse(mlir::AsmParser & parser)867 mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
868   return parseRankSingleton<fir::SliceType>(parser);
869 }
870 
print(mlir::AsmPrinter & printer) const871 void fir::SliceType::print(mlir::AsmPrinter &printer) const {
872   printer << "<" << getRank() << '>';
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // TypeDescType
877 //===----------------------------------------------------------------------===//
878 
879 // `tdesc` `<` type `>`
parse(mlir::AsmParser & parser)880 mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
881   return parseTypeSingleton<fir::TypeDescType>(parser);
882 }
883 
print(mlir::AsmPrinter & printer) const884 void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
885   printer << "<" << getOfTy() << '>';
886 }
887 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,mlir::Type eleTy)888 mlir::LogicalResult fir::TypeDescType::verify(
889     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
890     mlir::Type eleTy) {
891   if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
892                 ShiftType, SliceType, FieldType, LenType, ReferenceType,
893                 TypeDescType>())
894     return emitError() << "cannot build a type descriptor of type: " << eleTy
895                        << '\n';
896   return mlir::success();
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // VectorType
901 //===----------------------------------------------------------------------===//
902 
903 // `vector` `<` len `:` type `>`
parse(mlir::AsmParser & parser)904 mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
905   int64_t len = 0;
906   mlir::Type eleTy;
907   if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
908       parser.parseType(eleTy) || parser.parseGreater())
909     return {};
910   return fir::VectorType::get(len, eleTy);
911 }
912 
print(mlir::AsmPrinter & printer) const913 void fir::VectorType::print(mlir::AsmPrinter &printer) const {
914   printer << "<" << getLen() << ':' << getEleTy() << '>';
915 }
916 
verify(llvm::function_ref<mlir::InFlightDiagnostic ()> emitError,uint64_t len,mlir::Type eleTy)917 mlir::LogicalResult fir::VectorType::verify(
918     llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
919     mlir::Type eleTy) {
920   if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
921     return emitError() << "cannot build a vector of type " << eleTy << '\n';
922   return mlir::success();
923 }
924 
isValidElementType(mlir::Type t)925 bool fir::VectorType::isValidElementType(mlir::Type t) {
926   return isa_real(t) || isa_integer(t);
927 }
928 
isCharacterProcedureTuple(mlir::Type ty,bool acceptRawFunc)929 bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
930   mlir::TupleType tuple = ty.dyn_cast<mlir::TupleType>();
931   return tuple && tuple.size() == 2 &&
932          (tuple.getType(0).isa<fir::BoxProcType>() ||
933           (acceptRawFunc && tuple.getType(0).isa<mlir::FunctionType>())) &&
934          fir::isa_integer(tuple.getType(1));
935 }
936 
hasAbstractResult(mlir::FunctionType ty)937 bool fir::hasAbstractResult(mlir::FunctionType ty) {
938   if (ty.getNumResults() == 0)
939     return false;
940   auto resultType = ty.getResult(0);
941   return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
942 }
943 
944 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
945 /// messages only.
fromRealTypeID(mlir::MLIRContext * context,llvm::Type::TypeID typeID,fir::KindTy kind)946 mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
947                                llvm::Type::TypeID typeID, fir::KindTy kind) {
948   switch (typeID) {
949   case llvm::Type::TypeID::HalfTyID:
950     return mlir::FloatType::getF16(context);
951   case llvm::Type::TypeID::BFloatTyID:
952     return mlir::FloatType::getBF16(context);
953   case llvm::Type::TypeID::FloatTyID:
954     return mlir::FloatType::getF32(context);
955   case llvm::Type::TypeID::DoubleTyID:
956     return mlir::FloatType::getF64(context);
957   case llvm::Type::TypeID::X86_FP80TyID:
958     return mlir::FloatType::getF80(context);
959   case llvm::Type::TypeID::FP128TyID:
960     return mlir::FloatType::getF128(context);
961   default:
962     mlir::emitError(mlir::UnknownLoc::get(context))
963         << "unsupported type: !fir.real<" << kind << ">";
964     return {};
965   }
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // FIROpsDialect
970 //===----------------------------------------------------------------------===//
971 
registerTypes()972 void FIROpsDialect::registerTypes() {
973   addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
974            FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
975            LLVMPointerType, PointerType, RealType, RecordType, ReferenceType,
976            SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
977            TypeDescType, fir::VectorType>();
978   fir::ReferenceType::attachInterface<PointerLikeModel<fir::ReferenceType>>(
979       *getContext());
980 
981   fir::PointerType::attachInterface<PointerLikeModel<fir::PointerType>>(
982       *getContext());
983 
984   fir::HeapType::attachInterface<AlternativePointerLikeModel<fir::HeapType>>(
985       *getContext());
986 
987   fir::LLVMPointerType::attachInterface<
988       AlternativePointerLikeModel<fir::LLVMPointerType>>(*getContext());
989 }
990