1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15
16 using namespace mlir;
17 using namespace mlir::LLVM;
18
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
dispatchPrint(AsmPrinter & printer,Type type)26 static void dispatchPrint(AsmPrinter &printer, Type type) {
27 if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28 return mlir::LLVM::detail::printType(type, printer);
29 printer.printType(type);
30 }
31
32 /// Returns the keyword to use for the given type.
getTypeKeyword(Type type)33 static StringRef getTypeKeyword(Type type) {
34 return TypeSwitch<Type, StringRef>(type)
35 .Case<LLVMVoidType>([&](Type) { return "void"; })
36 .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37 .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38 .Case<LLVMTokenType>([&](Type) { return "token"; })
39 .Case<LLVMLabelType>([&](Type) { return "label"; })
40 .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41 .Case<LLVMFunctionType>([&](Type) { return "func"; })
42 .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43 .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44 [&](Type) { return "vec"; })
45 .Case<LLVMArrayType>([&](Type) { return "array"; })
46 .Case<LLVMStructType>([&](Type) { return "struct"; })
47 .Default([](Type) -> StringRef {
48 llvm_unreachable("unexpected 'llvm' type kind");
49 });
50 }
51
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
printStructType(AsmPrinter & printer,LLVMStructType type)54 static void printStructType(AsmPrinter &printer, LLVMStructType type) {
55 // This keeps track of the names of identified structure types that are
56 // currently being printed. Since such types can refer themselves, this
57 // tracking is necessary to stop the recursion: the current function may be
58 // called recursively from AsmPrinter::printType after the appropriate
59 // dispatch. We maintain the invariant of this storage being modified
60 // exclusively in this function, and at most one name being added per call.
61 // TODO: consider having such functionality inside AsmPrinter.
62 thread_local SetVector<StringRef> knownStructNames;
63 unsigned stackSize = knownStructNames.size();
64 (void)stackSize;
65 auto guard = llvm::make_scope_exit([&]() {
66 assert(knownStructNames.size() == stackSize &&
67 "malformed identified stack when printing recursive structs");
68 });
69
70 printer << "<";
71 if (type.isIdentified()) {
72 printer << '"' << type.getName() << '"';
73 // If we are printing a reference to one of the enclosing structs, just
74 // print the name and stop to avoid infinitely long output.
75 if (knownStructNames.count(type.getName())) {
76 printer << '>';
77 return;
78 }
79 printer << ", ";
80 }
81
82 if (type.isIdentified() && type.isOpaque()) {
83 printer << "opaque>";
84 return;
85 }
86
87 if (type.isPacked())
88 printer << "packed ";
89
90 // Put the current type on stack to avoid infinite recursion.
91 printer << '(';
92 if (type.isIdentified())
93 knownStructNames.insert(type.getName());
94 llvm::interleaveComma(type.getBody(), printer.getStream(),
95 [&](Type subtype) { dispatchPrint(printer, subtype); });
96 if (type.isIdentified())
97 knownStructNames.pop_back();
98 printer << ')';
99 printer << '>';
100 }
101
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
printArrayOrVectorType(AsmPrinter & printer,TypeTy type)104 static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) {
105 printer << '<' << type.getNumElements() << " x ";
106 dispatchPrint(printer, type.getElementType());
107 printer << '>';
108 }
109
110 /// Prints a function type.
printFunctionType(AsmPrinter & printer,LLVMFunctionType funcType)111 static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
112 printer << '<';
113 dispatchPrint(printer, funcType.getReturnType());
114 printer << " (";
115 llvm::interleaveComma(
116 funcType.getParams(), printer.getStream(),
117 [&printer](Type subtype) { dispatchPrint(printer, subtype); });
118 if (funcType.isVarArg()) {
119 if (funcType.getNumParams() != 0)
120 printer << ", ";
121 printer << "...";
122 }
123 printer << ")>";
124 }
125
126 /// Prints the given LLVM dialect type recursively. This leverages closedness of
127 /// the LLVM dialect type system to avoid printing the dialect prefix
128 /// repeatedly. For recursive structures, only prints the name of the structure
129 /// when printing a self-reference. Note that this does not apply to sibling
130 /// references. For example,
131 /// struct<"a", (ptr<struct<"a">>)>
132 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
133 /// ptr<struct<"b", (ptr<struct<"c">>)>>)>
134 /// note that "b" is printed twice.
printType(Type type,AsmPrinter & printer)135 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
136 if (!type) {
137 printer << "<<NULL-TYPE>>";
138 return;
139 }
140
141 printer << getTypeKeyword(type);
142
143 if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
144 if (ptrType.isOpaque()) {
145 if (ptrType.getAddressSpace() != 0)
146 printer << '<' << ptrType.getAddressSpace() << '>';
147 return;
148 }
149
150 printer << '<';
151 dispatchPrint(printer, ptrType.getElementType());
152 if (ptrType.getAddressSpace() != 0)
153 printer << ", " << ptrType.getAddressSpace();
154 printer << '>';
155 return;
156 }
157
158 if (auto arrayType = type.dyn_cast<LLVMArrayType>())
159 return printArrayOrVectorType(printer, arrayType);
160 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
161 return printArrayOrVectorType(printer, vectorType);
162
163 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
164 printer << "<? x " << vectorType.getMinNumElements() << " x ";
165 dispatchPrint(printer, vectorType.getElementType());
166 printer << '>';
167 return;
168 }
169
170 if (auto structType = type.dyn_cast<LLVMStructType>())
171 return printStructType(printer, structType);
172
173 if (auto funcType = type.dyn_cast<LLVMFunctionType>())
174 return printFunctionType(printer, funcType);
175 }
176
177 //===----------------------------------------------------------------------===//
178 // Parsing.
179 //===----------------------------------------------------------------------===//
180
181 static ParseResult dispatchParse(AsmParser &parser, Type &type);
182
183 /// Parses an LLVM dialect function type.
184 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
parseFunctionType(AsmParser & parser)185 static LLVMFunctionType parseFunctionType(AsmParser &parser) {
186 SMLoc loc = parser.getCurrentLocation();
187 Type returnType;
188 if (parser.parseLess() || dispatchParse(parser, returnType) ||
189 parser.parseLParen())
190 return LLVMFunctionType();
191
192 // Function type without arguments.
193 if (succeeded(parser.parseOptionalRParen())) {
194 if (succeeded(parser.parseGreater()))
195 return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
196 /*isVarArg=*/false);
197 return LLVMFunctionType();
198 }
199
200 // Parse arguments.
201 SmallVector<Type, 8> argTypes;
202 do {
203 if (succeeded(parser.parseOptionalEllipsis())) {
204 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
205 return LLVMFunctionType();
206 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
207 /*isVarArg=*/true);
208 }
209
210 Type arg;
211 if (dispatchParse(parser, arg))
212 return LLVMFunctionType();
213 argTypes.push_back(arg);
214 } while (succeeded(parser.parseOptionalComma()));
215
216 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
217 return LLVMFunctionType();
218 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
219 /*isVarArg=*/false);
220 }
221
222 /// Parses an LLVM dialect pointer type.
223 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
224 /// | `ptr` (`<` integer `>`)?
parsePointerType(AsmParser & parser)225 static LLVMPointerType parsePointerType(AsmParser &parser) {
226 SMLoc loc = parser.getCurrentLocation();
227 Type elementType;
228 if (parser.parseOptionalLess()) {
229 return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
230 /*addressSpace=*/0);
231 }
232
233 unsigned addressSpace = 0;
234 OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
235 if (opr.hasValue()) {
236 if (failed(*opr) || parser.parseGreater())
237 return LLVMPointerType();
238 return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
239 addressSpace);
240 }
241
242 if (dispatchParse(parser, elementType))
243 return LLVMPointerType();
244
245 if (succeeded(parser.parseOptionalComma()) &&
246 failed(parser.parseInteger(addressSpace)))
247 return LLVMPointerType();
248 if (failed(parser.parseGreater()))
249 return LLVMPointerType();
250 return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
251 }
252
253 /// Parses an LLVM dialect vector type.
254 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
255 /// Supports both fixed and scalable vectors.
parseVectorType(AsmParser & parser)256 static Type parseVectorType(AsmParser &parser) {
257 SmallVector<int64_t, 2> dims;
258 SMLoc dimPos, typePos;
259 Type elementType;
260 SMLoc loc = parser.getCurrentLocation();
261 if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
262 parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
263 parser.getCurrentLocation(&typePos) ||
264 dispatchParse(parser, elementType) || parser.parseGreater())
265 return Type();
266
267 // We parsed a generic dimension list, but vectors only support two forms:
268 // - single non-dynamic entry in the list (fixed vector);
269 // - two elements, the first dynamic (indicated by -1) and the second
270 // non-dynamic (scalable vector).
271 if (dims.empty() || dims.size() > 2 ||
272 ((dims.size() == 2) ^ (dims[0] == -1)) ||
273 (dims.size() == 2 && dims[1] == -1)) {
274 parser.emitError(dimPos)
275 << "expected '? x <integer> x <type>' or '<integer> x <type>'";
276 return Type();
277 }
278
279 bool isScalable = dims.size() == 2;
280 if (isScalable)
281 return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
282 if (elementType.isSignlessIntOrFloat()) {
283 parser.emitError(typePos)
284 << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
285 return Type();
286 }
287 return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
288 }
289
290 /// Parses an LLVM dialect array type.
291 /// llvm-type ::= `array<` integer `x` llvm-type `>`
parseArrayType(AsmParser & parser)292 static LLVMArrayType parseArrayType(AsmParser &parser) {
293 SmallVector<int64_t, 1> dims;
294 SMLoc sizePos;
295 Type elementType;
296 SMLoc loc = parser.getCurrentLocation();
297 if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
298 parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
299 dispatchParse(parser, elementType) || parser.parseGreater())
300 return LLVMArrayType();
301
302 if (dims.size() != 1) {
303 parser.emitError(sizePos) << "expected ? x <type>";
304 return LLVMArrayType();
305 }
306
307 return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
308 }
309
310 /// Attempts to set the body of an identified structure type. Reports a parsing
311 /// error at `subtypesLoc` in case of failure.
trySetStructBody(LLVMStructType type,ArrayRef<Type> subtypes,bool isPacked,AsmParser & parser,SMLoc subtypesLoc)312 static LLVMStructType trySetStructBody(LLVMStructType type,
313 ArrayRef<Type> subtypes, bool isPacked,
314 AsmParser &parser, SMLoc subtypesLoc) {
315 for (Type t : subtypes) {
316 if (!LLVMStructType::isValidElementType(t)) {
317 parser.emitError(subtypesLoc)
318 << "invalid LLVM structure element type: " << t;
319 return LLVMStructType();
320 }
321 }
322
323 if (succeeded(type.setBody(subtypes, isPacked)))
324 return type;
325
326 parser.emitError(subtypesLoc)
327 << "identified type already used with a different body";
328 return LLVMStructType();
329 }
330
331 /// Parses an LLVM dialect structure type.
332 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`?
333 /// `(` llvm-type-list `)` `>`
334 /// | `struct<` string-literal `>`
335 /// | `struct<` string-literal `, opaque>`
parseStructType(AsmParser & parser)336 static LLVMStructType parseStructType(AsmParser &parser) {
337 // This keeps track of the names of identified structure types that are
338 // currently being parsed. Since such types can refer themselves, this
339 // tracking is necessary to stop the recursion: the current function may be
340 // called recursively from AsmParser::parseType after the appropriate
341 // dispatch. We maintain the invariant of this storage being modified
342 // exclusively in this function, and at most one name being added per call.
343 // TODO: consider having such functionality inside AsmParser.
344 thread_local SetVector<StringRef> knownStructNames;
345 unsigned stackSize = knownStructNames.size();
346 (void)stackSize;
347 auto guard = llvm::make_scope_exit([&]() {
348 assert(knownStructNames.size() == stackSize &&
349 "malformed identified stack when parsing recursive structs");
350 });
351
352 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
353
354 if (failed(parser.parseLess()))
355 return LLVMStructType();
356
357 // If we are parsing a self-reference to a recursive struct, i.e. the parsing
358 // stack already contains a struct with the same identifier, bail out after
359 // the name.
360 std::string name;
361 bool isIdentified = succeeded(parser.parseOptionalString(&name));
362 if (isIdentified) {
363 if (knownStructNames.count(name)) {
364 if (failed(parser.parseGreater()))
365 return LLVMStructType();
366 return LLVMStructType::getIdentifiedChecked(
367 [loc] { return emitError(loc); }, loc.getContext(), name);
368 }
369 if (failed(parser.parseComma()))
370 return LLVMStructType();
371 }
372
373 // Handle intentionally opaque structs.
374 SMLoc kwLoc = parser.getCurrentLocation();
375 if (succeeded(parser.parseOptionalKeyword("opaque"))) {
376 if (!isIdentified)
377 return parser.emitError(kwLoc, "only identified structs can be opaque"),
378 LLVMStructType();
379 if (failed(parser.parseGreater()))
380 return LLVMStructType();
381 auto type = LLVMStructType::getOpaqueChecked(
382 [loc] { return emitError(loc); }, loc.getContext(), name);
383 if (!type.isOpaque()) {
384 parser.emitError(kwLoc, "redeclaring defined struct as opaque");
385 return LLVMStructType();
386 }
387 return type;
388 }
389
390 // Check for packedness.
391 bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
392 if (failed(parser.parseLParen()))
393 return LLVMStructType();
394
395 // Fast pass for structs with zero subtypes.
396 if (succeeded(parser.parseOptionalRParen())) {
397 if (failed(parser.parseGreater()))
398 return LLVMStructType();
399 if (!isIdentified)
400 return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
401 loc.getContext(), {}, isPacked);
402 auto type = LLVMStructType::getIdentifiedChecked(
403 [loc] { return emitError(loc); }, loc.getContext(), name);
404 return trySetStructBody(type, {}, isPacked, parser, kwLoc);
405 }
406
407 // Parse subtypes. For identified structs, put the identifier of the struct on
408 // the stack to support self-references in the recursive calls.
409 SmallVector<Type, 4> subtypes;
410 SMLoc subtypesLoc = parser.getCurrentLocation();
411 do {
412 if (isIdentified)
413 knownStructNames.insert(name);
414 Type type;
415 if (dispatchParse(parser, type))
416 return LLVMStructType();
417 subtypes.push_back(type);
418 if (isIdentified)
419 knownStructNames.pop_back();
420 } while (succeeded(parser.parseOptionalComma()));
421
422 if (parser.parseRParen() || parser.parseGreater())
423 return LLVMStructType();
424
425 // Construct the struct with body.
426 if (!isIdentified)
427 return LLVMStructType::getLiteralChecked(
428 [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
429 auto type = LLVMStructType::getIdentifiedChecked(
430 [loc] { return emitError(loc); }, loc.getContext(), name);
431 return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
432 }
433
434 /// Parses a type appearing inside another LLVM dialect-compatible type. This
435 /// will try to parse any type in full form (including types with the `!llvm`
436 /// prefix), and on failure fall back to parsing the short-hand version of the
437 /// LLVM dialect types without the `!llvm` prefix.
dispatchParse(AsmParser & parser,bool allowAny=true)438 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
439 SMLoc keyLoc = parser.getCurrentLocation();
440
441 // Try parsing any MLIR type.
442 Type type;
443 OptionalParseResult result = parser.parseOptionalType(type);
444 if (result.hasValue()) {
445 if (failed(result.getValue()))
446 return nullptr;
447 if (!allowAny) {
448 parser.emitError(keyLoc) << "unexpected type, expected keyword";
449 return nullptr;
450 }
451 return type;
452 }
453
454 // If no type found, fallback to the shorthand form.
455 StringRef key;
456 if (failed(parser.parseKeyword(&key)))
457 return Type();
458
459 MLIRContext *ctx = parser.getContext();
460 return StringSwitch<function_ref<Type()>>(key)
461 .Case("void", [&] { return LLVMVoidType::get(ctx); })
462 .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
463 .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
464 .Case("token", [&] { return LLVMTokenType::get(ctx); })
465 .Case("label", [&] { return LLVMLabelType::get(ctx); })
466 .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
467 .Case("func", [&] { return parseFunctionType(parser); })
468 .Case("ptr", [&] { return parsePointerType(parser); })
469 .Case("vec", [&] { return parseVectorType(parser); })
470 .Case("array", [&] { return parseArrayType(parser); })
471 .Case("struct", [&] { return parseStructType(parser); })
472 .Default([&] {
473 parser.emitError(keyLoc) << "unknown LLVM type: " << key;
474 return Type();
475 })();
476 }
477
478 /// Helper to use in parse lists.
dispatchParse(AsmParser & parser,Type & type)479 static ParseResult dispatchParse(AsmParser &parser, Type &type) {
480 type = dispatchParse(parser);
481 return success(type != nullptr);
482 }
483
484 /// Parses one of the LLVM dialect types.
parseType(DialectAsmParser & parser)485 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
486 SMLoc loc = parser.getCurrentLocation();
487 Type type = dispatchParse(parser, /*allowAny=*/false);
488 if (!type)
489 return type;
490 if (!isCompatibleOuterType(type)) {
491 parser.emitError(loc) << "unexpected type, expected keyword";
492 return nullptr;
493 }
494 return type;
495 }
496