1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
dispatchPrint(AsmPrinter & printer,Type type)26 static void dispatchPrint(AsmPrinter &printer, Type type) {
27   if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28     return mlir::LLVM::detail::printType(type, printer);
29   printer.printType(type);
30 }
31 
32 /// Returns the keyword to use for the given type.
getTypeKeyword(Type type)33 static StringRef getTypeKeyword(Type type) {
34   return TypeSwitch<Type, StringRef>(type)
35       .Case<LLVMVoidType>([&](Type) { return "void"; })
36       .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37       .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38       .Case<LLVMTokenType>([&](Type) { return "token"; })
39       .Case<LLVMLabelType>([&](Type) { return "label"; })
40       .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41       .Case<LLVMFunctionType>([&](Type) { return "func"; })
42       .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43       .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44           [&](Type) { return "vec"; })
45       .Case<LLVMArrayType>([&](Type) { return "array"; })
46       .Case<LLVMStructType>([&](Type) { return "struct"; })
47       .Default([](Type) -> StringRef {
48         llvm_unreachable("unexpected 'llvm' type kind");
49       });
50 }
51 
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
printStructType(AsmPrinter & printer,LLVMStructType type)54 static void printStructType(AsmPrinter &printer, LLVMStructType type) {
55   // This keeps track of the names of identified structure types that are
56   // currently being printed. Since such types can refer themselves, this
57   // tracking is necessary to stop the recursion: the current function may be
58   // called recursively from AsmPrinter::printType after the appropriate
59   // dispatch. We maintain the invariant of this storage being modified
60   // exclusively in this function, and at most one name being added per call.
61   // TODO: consider having such functionality inside AsmPrinter.
62   thread_local SetVector<StringRef> knownStructNames;
63   unsigned stackSize = knownStructNames.size();
64   (void)stackSize;
65   auto guard = llvm::make_scope_exit([&]() {
66     assert(knownStructNames.size() == stackSize &&
67            "malformed identified stack when printing recursive structs");
68   });
69 
70   printer << "<";
71   if (type.isIdentified()) {
72     printer << '"' << type.getName() << '"';
73     // If we are printing a reference to one of the enclosing structs, just
74     // print the name and stop to avoid infinitely long output.
75     if (knownStructNames.count(type.getName())) {
76       printer << '>';
77       return;
78     }
79     printer << ", ";
80   }
81 
82   if (type.isIdentified() && type.isOpaque()) {
83     printer << "opaque>";
84     return;
85   }
86 
87   if (type.isPacked())
88     printer << "packed ";
89 
90   // Put the current type on stack to avoid infinite recursion.
91   printer << '(';
92   if (type.isIdentified())
93     knownStructNames.insert(type.getName());
94   llvm::interleaveComma(type.getBody(), printer.getStream(),
95                         [&](Type subtype) { dispatchPrint(printer, subtype); });
96   if (type.isIdentified())
97     knownStructNames.pop_back();
98   printer << ')';
99   printer << '>';
100 }
101 
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
printArrayOrVectorType(AsmPrinter & printer,TypeTy type)104 static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) {
105   printer << '<' << type.getNumElements() << " x ";
106   dispatchPrint(printer, type.getElementType());
107   printer << '>';
108 }
109 
110 /// Prints a function type.
printFunctionType(AsmPrinter & printer,LLVMFunctionType funcType)111 static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
112   printer << '<';
113   dispatchPrint(printer, funcType.getReturnType());
114   printer << " (";
115   llvm::interleaveComma(
116       funcType.getParams(), printer.getStream(),
117       [&printer](Type subtype) { dispatchPrint(printer, subtype); });
118   if (funcType.isVarArg()) {
119     if (funcType.getNumParams() != 0)
120       printer << ", ";
121     printer << "...";
122   }
123   printer << ")>";
124 }
125 
126 /// Prints the given LLVM dialect type recursively. This leverages closedness of
127 /// the LLVM dialect type system to avoid printing the dialect prefix
128 /// repeatedly. For recursive structures, only prints the name of the structure
129 /// when printing a self-reference. Note that this does not apply to sibling
130 /// references. For example,
131 ///   struct<"a", (ptr<struct<"a">>)>
132 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
133 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
134 /// note that "b" is printed twice.
printType(Type type,AsmPrinter & printer)135 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
136   if (!type) {
137     printer << "<<NULL-TYPE>>";
138     return;
139   }
140 
141   printer << getTypeKeyword(type);
142 
143   if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
144     if (ptrType.isOpaque()) {
145       if (ptrType.getAddressSpace() != 0)
146         printer << '<' << ptrType.getAddressSpace() << '>';
147       return;
148     }
149 
150     printer << '<';
151     dispatchPrint(printer, ptrType.getElementType());
152     if (ptrType.getAddressSpace() != 0)
153       printer << ", " << ptrType.getAddressSpace();
154     printer << '>';
155     return;
156   }
157 
158   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
159     return printArrayOrVectorType(printer, arrayType);
160   if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
161     return printArrayOrVectorType(printer, vectorType);
162 
163   if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
164     printer << "<? x " << vectorType.getMinNumElements() << " x ";
165     dispatchPrint(printer, vectorType.getElementType());
166     printer << '>';
167     return;
168   }
169 
170   if (auto structType = type.dyn_cast<LLVMStructType>())
171     return printStructType(printer, structType);
172 
173   if (auto funcType = type.dyn_cast<LLVMFunctionType>())
174     return printFunctionType(printer, funcType);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Parsing.
179 //===----------------------------------------------------------------------===//
180 
181 static ParseResult dispatchParse(AsmParser &parser, Type &type);
182 
183 /// Parses an LLVM dialect function type.
184 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
parseFunctionType(AsmParser & parser)185 static LLVMFunctionType parseFunctionType(AsmParser &parser) {
186   SMLoc loc = parser.getCurrentLocation();
187   Type returnType;
188   if (parser.parseLess() || dispatchParse(parser, returnType) ||
189       parser.parseLParen())
190     return LLVMFunctionType();
191 
192   // Function type without arguments.
193   if (succeeded(parser.parseOptionalRParen())) {
194     if (succeeded(parser.parseGreater()))
195       return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
196                                                  /*isVarArg=*/false);
197     return LLVMFunctionType();
198   }
199 
200   // Parse arguments.
201   SmallVector<Type, 8> argTypes;
202   do {
203     if (succeeded(parser.parseOptionalEllipsis())) {
204       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
205         return LLVMFunctionType();
206       return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
207                                                  /*isVarArg=*/true);
208     }
209 
210     Type arg;
211     if (dispatchParse(parser, arg))
212       return LLVMFunctionType();
213     argTypes.push_back(arg);
214   } while (succeeded(parser.parseOptionalComma()));
215 
216   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
217     return LLVMFunctionType();
218   return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
219                                              /*isVarArg=*/false);
220 }
221 
222 /// Parses an LLVM dialect pointer type.
223 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
224 ///               | `ptr` (`<` integer `>`)?
parsePointerType(AsmParser & parser)225 static LLVMPointerType parsePointerType(AsmParser &parser) {
226   SMLoc loc = parser.getCurrentLocation();
227   Type elementType;
228   if (parser.parseOptionalLess()) {
229     return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
230                                               /*addressSpace=*/0);
231   }
232 
233   unsigned addressSpace = 0;
234   OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
235   if (opr.hasValue()) {
236     if (failed(*opr) || parser.parseGreater())
237       return LLVMPointerType();
238     return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
239                                               addressSpace);
240   }
241 
242   if (dispatchParse(parser, elementType))
243     return LLVMPointerType();
244 
245   if (succeeded(parser.parseOptionalComma()) &&
246       failed(parser.parseInteger(addressSpace)))
247     return LLVMPointerType();
248   if (failed(parser.parseGreater()))
249     return LLVMPointerType();
250   return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
251 }
252 
253 /// Parses an LLVM dialect vector type.
254 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
255 /// Supports both fixed and scalable vectors.
parseVectorType(AsmParser & parser)256 static Type parseVectorType(AsmParser &parser) {
257   SmallVector<int64_t, 2> dims;
258   SMLoc dimPos, typePos;
259   Type elementType;
260   SMLoc loc = parser.getCurrentLocation();
261   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
262       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
263       parser.getCurrentLocation(&typePos) ||
264       dispatchParse(parser, elementType) || parser.parseGreater())
265     return Type();
266 
267   // We parsed a generic dimension list, but vectors only support two forms:
268   //  - single non-dynamic entry in the list (fixed vector);
269   //  - two elements, the first dynamic (indicated by -1) and the second
270   //    non-dynamic (scalable vector).
271   if (dims.empty() || dims.size() > 2 ||
272       ((dims.size() == 2) ^ (dims[0] == -1)) ||
273       (dims.size() == 2 && dims[1] == -1)) {
274     parser.emitError(dimPos)
275         << "expected '? x <integer> x <type>' or '<integer> x <type>'";
276     return Type();
277   }
278 
279   bool isScalable = dims.size() == 2;
280   if (isScalable)
281     return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
282   if (elementType.isSignlessIntOrFloat()) {
283     parser.emitError(typePos)
284         << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
285     return Type();
286   }
287   return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
288 }
289 
290 /// Parses an LLVM dialect array type.
291 ///   llvm-type ::= `array<` integer `x` llvm-type `>`
parseArrayType(AsmParser & parser)292 static LLVMArrayType parseArrayType(AsmParser &parser) {
293   SmallVector<int64_t, 1> dims;
294   SMLoc sizePos;
295   Type elementType;
296   SMLoc loc = parser.getCurrentLocation();
297   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
298       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
299       dispatchParse(parser, elementType) || parser.parseGreater())
300     return LLVMArrayType();
301 
302   if (dims.size() != 1) {
303     parser.emitError(sizePos) << "expected ? x <type>";
304     return LLVMArrayType();
305   }
306 
307   return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
308 }
309 
310 /// Attempts to set the body of an identified structure type. Reports a parsing
311 /// error at `subtypesLoc` in case of failure.
trySetStructBody(LLVMStructType type,ArrayRef<Type> subtypes,bool isPacked,AsmParser & parser,SMLoc subtypesLoc)312 static LLVMStructType trySetStructBody(LLVMStructType type,
313                                        ArrayRef<Type> subtypes, bool isPacked,
314                                        AsmParser &parser, SMLoc subtypesLoc) {
315   for (Type t : subtypes) {
316     if (!LLVMStructType::isValidElementType(t)) {
317       parser.emitError(subtypesLoc)
318           << "invalid LLVM structure element type: " << t;
319       return LLVMStructType();
320     }
321   }
322 
323   if (succeeded(type.setBody(subtypes, isPacked)))
324     return type;
325 
326   parser.emitError(subtypesLoc)
327       << "identified type already used with a different body";
328   return LLVMStructType();
329 }
330 
331 /// Parses an LLVM dialect structure type.
332 ///   llvm-type ::= `struct<` (string-literal `,`)? `packed`?
333 ///                 `(` llvm-type-list `)` `>`
334 ///               | `struct<` string-literal `>`
335 ///               | `struct<` string-literal `, opaque>`
parseStructType(AsmParser & parser)336 static LLVMStructType parseStructType(AsmParser &parser) {
337   // This keeps track of the names of identified structure types that are
338   // currently being parsed. Since such types can refer themselves, this
339   // tracking is necessary to stop the recursion: the current function may be
340   // called recursively from AsmParser::parseType after the appropriate
341   // dispatch. We maintain the invariant of this storage being modified
342   // exclusively in this function, and at most one name being added per call.
343   // TODO: consider having such functionality inside AsmParser.
344   thread_local SetVector<StringRef> knownStructNames;
345   unsigned stackSize = knownStructNames.size();
346   (void)stackSize;
347   auto guard = llvm::make_scope_exit([&]() {
348     assert(knownStructNames.size() == stackSize &&
349            "malformed identified stack when parsing recursive structs");
350   });
351 
352   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
353 
354   if (failed(parser.parseLess()))
355     return LLVMStructType();
356 
357   // If we are parsing a self-reference to a recursive struct, i.e. the parsing
358   // stack already contains a struct with the same identifier, bail out after
359   // the name.
360   std::string name;
361   bool isIdentified = succeeded(parser.parseOptionalString(&name));
362   if (isIdentified) {
363     if (knownStructNames.count(name)) {
364       if (failed(parser.parseGreater()))
365         return LLVMStructType();
366       return LLVMStructType::getIdentifiedChecked(
367           [loc] { return emitError(loc); }, loc.getContext(), name);
368     }
369     if (failed(parser.parseComma()))
370       return LLVMStructType();
371   }
372 
373   // Handle intentionally opaque structs.
374   SMLoc kwLoc = parser.getCurrentLocation();
375   if (succeeded(parser.parseOptionalKeyword("opaque"))) {
376     if (!isIdentified)
377       return parser.emitError(kwLoc, "only identified structs can be opaque"),
378              LLVMStructType();
379     if (failed(parser.parseGreater()))
380       return LLVMStructType();
381     auto type = LLVMStructType::getOpaqueChecked(
382         [loc] { return emitError(loc); }, loc.getContext(), name);
383     if (!type.isOpaque()) {
384       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
385       return LLVMStructType();
386     }
387     return type;
388   }
389 
390   // Check for packedness.
391   bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
392   if (failed(parser.parseLParen()))
393     return LLVMStructType();
394 
395   // Fast pass for structs with zero subtypes.
396   if (succeeded(parser.parseOptionalRParen())) {
397     if (failed(parser.parseGreater()))
398       return LLVMStructType();
399     if (!isIdentified)
400       return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
401                                                loc.getContext(), {}, isPacked);
402     auto type = LLVMStructType::getIdentifiedChecked(
403         [loc] { return emitError(loc); }, loc.getContext(), name);
404     return trySetStructBody(type, {}, isPacked, parser, kwLoc);
405   }
406 
407   // Parse subtypes. For identified structs, put the identifier of the struct on
408   // the stack to support self-references in the recursive calls.
409   SmallVector<Type, 4> subtypes;
410   SMLoc subtypesLoc = parser.getCurrentLocation();
411   do {
412     if (isIdentified)
413       knownStructNames.insert(name);
414     Type type;
415     if (dispatchParse(parser, type))
416       return LLVMStructType();
417     subtypes.push_back(type);
418     if (isIdentified)
419       knownStructNames.pop_back();
420   } while (succeeded(parser.parseOptionalComma()));
421 
422   if (parser.parseRParen() || parser.parseGreater())
423     return LLVMStructType();
424 
425   // Construct the struct with body.
426   if (!isIdentified)
427     return LLVMStructType::getLiteralChecked(
428         [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
429   auto type = LLVMStructType::getIdentifiedChecked(
430       [loc] { return emitError(loc); }, loc.getContext(), name);
431   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
432 }
433 
434 /// Parses a type appearing inside another LLVM dialect-compatible type. This
435 /// will try to parse any type in full form (including types with the `!llvm`
436 /// prefix), and on failure fall back to parsing the short-hand version of the
437 /// LLVM dialect types without the `!llvm` prefix.
dispatchParse(AsmParser & parser,bool allowAny=true)438 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
439   SMLoc keyLoc = parser.getCurrentLocation();
440 
441   // Try parsing any MLIR type.
442   Type type;
443   OptionalParseResult result = parser.parseOptionalType(type);
444   if (result.hasValue()) {
445     if (failed(result.getValue()))
446       return nullptr;
447     if (!allowAny) {
448       parser.emitError(keyLoc) << "unexpected type, expected keyword";
449       return nullptr;
450     }
451     return type;
452   }
453 
454   // If no type found, fallback to the shorthand form.
455   StringRef key;
456   if (failed(parser.parseKeyword(&key)))
457     return Type();
458 
459   MLIRContext *ctx = parser.getContext();
460   return StringSwitch<function_ref<Type()>>(key)
461       .Case("void", [&] { return LLVMVoidType::get(ctx); })
462       .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
463       .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
464       .Case("token", [&] { return LLVMTokenType::get(ctx); })
465       .Case("label", [&] { return LLVMLabelType::get(ctx); })
466       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
467       .Case("func", [&] { return parseFunctionType(parser); })
468       .Case("ptr", [&] { return parsePointerType(parser); })
469       .Case("vec", [&] { return parseVectorType(parser); })
470       .Case("array", [&] { return parseArrayType(parser); })
471       .Case("struct", [&] { return parseStructType(parser); })
472       .Default([&] {
473         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
474         return Type();
475       })();
476 }
477 
478 /// Helper to use in parse lists.
dispatchParse(AsmParser & parser,Type & type)479 static ParseResult dispatchParse(AsmParser &parser, Type &type) {
480   type = dispatchParse(parser);
481   return success(type != nullptr);
482 }
483 
484 /// Parses one of the LLVM dialect types.
parseType(DialectAsmParser & parser)485 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
486   SMLoc loc = parser.getCurrentLocation();
487   Type type = dispatchParse(parser, /*allowAny=*/false);
488   if (!type)
489     return type;
490   if (!isCompatibleOuterType(type)) {
491     parser.emitError(loc) << "unexpected type, expected keyword";
492     return nullptr;
493   }
494   return type;
495 }
496