1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/TensorEncoding.h"
18 
19 using namespace mlir;
20 using namespace mlir::detail;
21 
22 /// Optionally parse a type.
parseOptionalType(Type & type)23 OptionalParseResult Parser::parseOptionalType(Type &type) {
24   // There are many different starting tokens for a type, check them here.
25   switch (getToken().getKind()) {
26   case Token::l_paren:
27   case Token::kw_memref:
28   case Token::kw_tensor:
29   case Token::kw_complex:
30   case Token::kw_tuple:
31   case Token::kw_vector:
32   case Token::inttype:
33   case Token::kw_bf16:
34   case Token::kw_f16:
35   case Token::kw_f32:
36   case Token::kw_f64:
37   case Token::kw_f80:
38   case Token::kw_f128:
39   case Token::kw_index:
40   case Token::kw_none:
41   case Token::exclamation_identifier:
42     return failure(!(type = parseType()));
43 
44   default:
45     return llvm::None;
46   }
47 }
48 
49 /// Parse an arbitrary type.
50 ///
51 ///   type ::= function-type
52 ///          | non-function-type
53 ///
parseType()54 Type Parser::parseType() {
55   if (getToken().is(Token::l_paren))
56     return parseFunctionType();
57   return parseNonFunctionType();
58 }
59 
60 /// Parse a function result type.
61 ///
62 ///   function-result-type ::= type-list-parens
63 ///                          | non-function-type
64 ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)65 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
66   if (getToken().is(Token::l_paren))
67     return parseTypeListParens(elements);
68 
69   Type t = parseNonFunctionType();
70   if (!t)
71     return failure();
72   elements.push_back(t);
73   return success();
74 }
75 
76 /// Parse a list of types without an enclosing parenthesis.  The list must have
77 /// at least one member.
78 ///
79 ///   type-list-no-parens ::=  type (`,` type)*
80 ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)81 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
82   auto parseElt = [&]() -> ParseResult {
83     auto elt = parseType();
84     elements.push_back(elt);
85     return elt ? success() : failure();
86   };
87 
88   return parseCommaSeparatedList(parseElt);
89 }
90 
91 /// Parse a parenthesized list of types.
92 ///
93 ///   type-list-parens ::= `(` `)`
94 ///                      | `(` type-list-no-parens `)`
95 ///
parseTypeListParens(SmallVectorImpl<Type> & elements)96 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
97   if (parseToken(Token::l_paren, "expected '('"))
98     return failure();
99 
100   // Handle empty lists.
101   if (getToken().is(Token::r_paren))
102     return consumeToken(), success();
103 
104   if (parseTypeListNoParens(elements) ||
105       parseToken(Token::r_paren, "expected ')'"))
106     return failure();
107   return success();
108 }
109 
110 /// Parse a complex type.
111 ///
112 ///   complex-type ::= `complex` `<` type `>`
113 ///
parseComplexType()114 Type Parser::parseComplexType() {
115   consumeToken(Token::kw_complex);
116 
117   // Parse the '<'.
118   if (parseToken(Token::less, "expected '<' in complex type"))
119     return nullptr;
120 
121   SMLoc elementTypeLoc = getToken().getLoc();
122   auto elementType = parseType();
123   if (!elementType ||
124       parseToken(Token::greater, "expected '>' in complex type"))
125     return nullptr;
126   if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
127     return emitError(elementTypeLoc, "invalid element type for complex"),
128            nullptr;
129 
130   return ComplexType::get(elementType);
131 }
132 
133 /// Parse a function type.
134 ///
135 ///   function-type ::= type-list-parens `->` function-result-type
136 ///
parseFunctionType()137 Type Parser::parseFunctionType() {
138   assert(getToken().is(Token::l_paren));
139 
140   SmallVector<Type, 4> arguments, results;
141   if (parseTypeListParens(arguments) ||
142       parseToken(Token::arrow, "expected '->' in function type") ||
143       parseFunctionResultTypes(results))
144     return nullptr;
145 
146   return builder.getFunctionType(arguments, results);
147 }
148 
149 /// Parse the offset and strides from a strided layout specification.
150 ///
151 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
152 ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)153 ParseResult Parser::parseStridedLayout(int64_t &offset,
154                                        SmallVectorImpl<int64_t> &strides) {
155   // Parse offset.
156   consumeToken(Token::kw_offset);
157   if (parseToken(Token::colon, "expected colon after `offset` keyword"))
158     return failure();
159 
160   auto maybeOffset = getToken().getUnsignedIntegerValue();
161   bool question = getToken().is(Token::question);
162   if (!maybeOffset && !question)
163     return emitWrongTokenError("invalid offset");
164   offset = maybeOffset ? static_cast<int64_t>(*maybeOffset)
165                        : MemRefType::getDynamicStrideOrOffset();
166   consumeToken();
167 
168   // Parse stride list.
169   if (parseToken(Token::comma, "expected comma after offset value") ||
170       parseToken(Token::kw_strides,
171                  "expected `strides` keyword after offset specification") ||
172       parseToken(Token::colon, "expected colon after `strides` keyword") ||
173       parseStrideList(strides))
174     return failure();
175   return success();
176 }
177 
178 /// Parse a memref type.
179 ///
180 ///   memref-type ::= ranked-memref-type | unranked-memref-type
181 ///
182 ///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
183 ///                          (`,` layout-specification)? (`,` memory-space)? `>`
184 ///
185 ///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
186 ///
187 ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
188 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
189 ///   layout-specification ::= semi-affine-map | strided-layout | attribute
190 ///   memory-space ::= integer-literal | attribute
191 ///
parseMemRefType()192 Type Parser::parseMemRefType() {
193   SMLoc loc = getToken().getLoc();
194   consumeToken(Token::kw_memref);
195 
196   if (parseToken(Token::less, "expected '<' in memref type"))
197     return nullptr;
198 
199   bool isUnranked;
200   SmallVector<int64_t, 4> dimensions;
201 
202   if (consumeIf(Token::star)) {
203     // This is an unranked memref type.
204     isUnranked = true;
205     if (parseXInDimensionList())
206       return nullptr;
207 
208   } else {
209     isUnranked = false;
210     if (parseDimensionListRanked(dimensions))
211       return nullptr;
212   }
213 
214   // Parse the element type.
215   auto typeLoc = getToken().getLoc();
216   auto elementType = parseType();
217   if (!elementType)
218     return nullptr;
219 
220   // Check that memref is formed from allowed types.
221   if (!BaseMemRefType::isValidElementType(elementType))
222     return emitError(typeLoc, "invalid memref element type"), nullptr;
223 
224   MemRefLayoutAttrInterface layout;
225   Attribute memorySpace;
226 
227   auto parseElt = [&]() -> ParseResult {
228     // Check for AffineMap as offset/strides.
229     if (getToken().is(Token::kw_offset)) {
230       int64_t offset;
231       SmallVector<int64_t, 4> strides;
232       if (failed(parseStridedLayout(offset, strides)))
233         return failure();
234       // Construct strided affine map.
235       AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext());
236       layout = AffineMapAttr::get(map);
237     } else {
238       // Either it is MemRefLayoutAttrInterface or memory space attribute.
239       Attribute attr = parseAttribute();
240       if (!attr)
241         return failure();
242 
243       if (attr.isa<MemRefLayoutAttrInterface>()) {
244         layout = attr.cast<MemRefLayoutAttrInterface>();
245       } else if (memorySpace) {
246         return emitError("multiple memory spaces specified in memref type");
247       } else {
248         memorySpace = attr;
249         return success();
250       }
251     }
252 
253     if (isUnranked)
254       return emitError("cannot have affine map for unranked memref type");
255     if (memorySpace)
256       return emitError("expected memory space to be last in memref type");
257 
258     return success();
259   };
260 
261   // Parse a list of mappings and address space if present.
262   if (!consumeIf(Token::greater)) {
263     // Parse comma separated list of affine maps, followed by memory space.
264     if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
265         parseCommaSeparatedListUntil(Token::greater, parseElt,
266                                      /*allowEmptyList=*/false)) {
267       return nullptr;
268     }
269   }
270 
271   if (isUnranked)
272     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
273 
274   return getChecked<MemRefType>(loc, dimensions, elementType, layout,
275                                 memorySpace);
276 }
277 
278 /// Parse any type except the function type.
279 ///
280 ///   non-function-type ::= integer-type
281 ///                       | index-type
282 ///                       | float-type
283 ///                       | extended-type
284 ///                       | vector-type
285 ///                       | tensor-type
286 ///                       | memref-type
287 ///                       | complex-type
288 ///                       | tuple-type
289 ///                       | none-type
290 ///
291 ///   index-type ::= `index`
292 ///   float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
293 ///   none-type ::= `none`
294 ///
parseNonFunctionType()295 Type Parser::parseNonFunctionType() {
296   switch (getToken().getKind()) {
297   default:
298     return (emitWrongTokenError("expected non-function type"), nullptr);
299   case Token::kw_memref:
300     return parseMemRefType();
301   case Token::kw_tensor:
302     return parseTensorType();
303   case Token::kw_complex:
304     return parseComplexType();
305   case Token::kw_tuple:
306     return parseTupleType();
307   case Token::kw_vector:
308     return parseVectorType();
309   // integer-type
310   case Token::inttype: {
311     auto width = getToken().getIntTypeBitwidth();
312     if (!width.has_value())
313       return (emitError("invalid integer width"), nullptr);
314     if (width.value() > IntegerType::kMaxWidth) {
315       emitError(getToken().getLoc(), "integer bitwidth is limited to ")
316           << IntegerType::kMaxWidth << " bits";
317       return nullptr;
318     }
319 
320     IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
321     if (Optional<bool> signedness = getToken().getIntTypeSignedness())
322       signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
323 
324     consumeToken(Token::inttype);
325     return IntegerType::get(getContext(), *width, signSemantics);
326   }
327 
328   // float-type
329   case Token::kw_bf16:
330     consumeToken(Token::kw_bf16);
331     return builder.getBF16Type();
332   case Token::kw_f16:
333     consumeToken(Token::kw_f16);
334     return builder.getF16Type();
335   case Token::kw_f32:
336     consumeToken(Token::kw_f32);
337     return builder.getF32Type();
338   case Token::kw_f64:
339     consumeToken(Token::kw_f64);
340     return builder.getF64Type();
341   case Token::kw_f80:
342     consumeToken(Token::kw_f80);
343     return builder.getF80Type();
344   case Token::kw_f128:
345     consumeToken(Token::kw_f128);
346     return builder.getF128Type();
347 
348   // index-type
349   case Token::kw_index:
350     consumeToken(Token::kw_index);
351     return builder.getIndexType();
352 
353   // none-type
354   case Token::kw_none:
355     consumeToken(Token::kw_none);
356     return builder.getNoneType();
357 
358   // extended type
359   case Token::exclamation_identifier:
360     return parseExtendedType();
361 
362   // Handle completion of a dialect type.
363   case Token::code_complete:
364     if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
365       return parseExtendedType();
366     return codeCompleteType();
367   }
368 }
369 
370 /// Parse a tensor type.
371 ///
372 ///   tensor-type ::= `tensor` `<` dimension-list type `>`
373 ///   dimension-list ::= dimension-list-ranked | `*x`
374 ///
parseTensorType()375 Type Parser::parseTensorType() {
376   consumeToken(Token::kw_tensor);
377 
378   if (parseToken(Token::less, "expected '<' in tensor type"))
379     return nullptr;
380 
381   bool isUnranked;
382   SmallVector<int64_t, 4> dimensions;
383 
384   if (consumeIf(Token::star)) {
385     // This is an unranked tensor type.
386     isUnranked = true;
387 
388     if (parseXInDimensionList())
389       return nullptr;
390 
391   } else {
392     isUnranked = false;
393     if (parseDimensionListRanked(dimensions))
394       return nullptr;
395   }
396 
397   // Parse the element type.
398   auto elementTypeLoc = getToken().getLoc();
399   auto elementType = parseType();
400 
401   // Parse an optional encoding attribute.
402   Attribute encoding;
403   if (consumeIf(Token::comma)) {
404     encoding = parseAttribute();
405     if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
406       if (failed(v.verifyEncoding(dimensions, elementType,
407                                   [&] { return emitError(); })))
408         return nullptr;
409     }
410   }
411 
412   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
413     return nullptr;
414   if (!TensorType::isValidElementType(elementType))
415     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
416 
417   if (isUnranked) {
418     if (encoding)
419       return emitError("cannot apply encoding to unranked tensor"), nullptr;
420     return UnrankedTensorType::get(elementType);
421   }
422   return RankedTensorType::get(dimensions, elementType, encoding);
423 }
424 
425 /// Parse a tuple type.
426 ///
427 ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
428 ///
parseTupleType()429 Type Parser::parseTupleType() {
430   consumeToken(Token::kw_tuple);
431 
432   // Parse the '<'.
433   if (parseToken(Token::less, "expected '<' in tuple type"))
434     return nullptr;
435 
436   // Check for an empty tuple by directly parsing '>'.
437   if (consumeIf(Token::greater))
438     return TupleType::get(getContext());
439 
440   // Parse the element types and the '>'.
441   SmallVector<Type, 4> types;
442   if (parseTypeListNoParens(types) ||
443       parseToken(Token::greater, "expected '>' in tuple type"))
444     return nullptr;
445 
446   return TupleType::get(getContext(), types);
447 }
448 
449 /// Parse a vector type.
450 ///
451 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
452 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
453 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
454 ///
parseVectorType()455 VectorType Parser::parseVectorType() {
456   consumeToken(Token::kw_vector);
457 
458   if (parseToken(Token::less, "expected '<' in vector type"))
459     return nullptr;
460 
461   SmallVector<int64_t, 4> dimensions;
462   unsigned numScalableDims;
463   if (parseVectorDimensionList(dimensions, numScalableDims))
464     return nullptr;
465   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
466     return emitError(getToken().getLoc(),
467                      "vector types must have positive constant sizes"),
468            nullptr;
469 
470   // Parse the element type.
471   auto typeLoc = getToken().getLoc();
472   auto elementType = parseType();
473   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
474     return nullptr;
475 
476   if (!VectorType::isValidElementType(elementType))
477     return emitError(typeLoc, "vector elements must be int/index/float type"),
478            nullptr;
479 
480   return VectorType::get(dimensions, elementType, numScalableDims);
481 }
482 
483 /// Parse a dimension list in a vector type. This populates the dimension list,
484 /// and returns the number of scalable dimensions in `numScalableDims`.
485 ///
486 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
487 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
488 ///
489 ParseResult
parseVectorDimensionList(SmallVectorImpl<int64_t> & dimensions,unsigned & numScalableDims)490 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
491                                  unsigned &numScalableDims) {
492   numScalableDims = 0;
493   // If there is a set of fixed-length dimensions, consume it
494   while (getToken().is(Token::integer)) {
495     int64_t value;
496     if (parseIntegerInDimensionList(value))
497       return failure();
498     dimensions.push_back(value);
499     // Make sure we have an 'x' or something like 'xbf32'.
500     if (parseXInDimensionList())
501       return failure();
502   }
503   // If there is a set of scalable dimensions, consume it
504   if (consumeIf(Token::l_square)) {
505     while (getToken().is(Token::integer)) {
506       int64_t value;
507       if (parseIntegerInDimensionList(value))
508         return failure();
509       dimensions.push_back(value);
510       numScalableDims++;
511       // Check if we have reached the end of the scalable dimension list
512       if (consumeIf(Token::r_square)) {
513         // Make sure we have something like 'xbf32'.
514         return parseXInDimensionList();
515       }
516       // Make sure we have an 'x'
517       if (parseXInDimensionList())
518         return failure();
519     }
520     // If we make it here, we've finished parsing the dimension list
521     // without finding ']' closing the set of scalable dimensions
522     return emitWrongTokenError(
523         "missing ']' closing set of scalable dimensions");
524   }
525 
526   return success();
527 }
528 
529 /// Parse a dimension list of a tensor or memref type.  This populates the
530 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
531 /// errors out on `?` otherwise. Parsing the trailing `x` is configurable.
532 ///
533 ///   dimension-list ::= eps | dimension (`x` dimension)*
534 ///   dimension-list-with-trailing-x ::= (dimension `x`)*
535 ///   dimension ::= `?` | decimal-literal
536 ///
537 /// When `allowDynamic` is not set, this is used to parse:
538 ///
539 ///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
540 ///   static-dimension-list-with-trailing-x ::= (dimension `x`)*
541 ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic,bool withTrailingX)542 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
543                                  bool allowDynamic, bool withTrailingX) {
544   auto parseDim = [&]() -> LogicalResult {
545     auto loc = getToken().getLoc();
546     if (consumeIf(Token::question)) {
547       if (!allowDynamic)
548         return emitError(loc, "expected static shape");
549       dimensions.push_back(-1);
550     } else {
551       int64_t value;
552       if (failed(parseIntegerInDimensionList(value)))
553         return failure();
554       dimensions.push_back(value);
555     }
556     return success();
557   };
558 
559   if (withTrailingX) {
560     while (getToken().isAny(Token::integer, Token::question)) {
561       if (failed(parseDim()) || failed(parseXInDimensionList()))
562         return failure();
563     }
564     return success();
565   }
566 
567   if (getToken().isAny(Token::integer, Token::question)) {
568     if (failed(parseDim()))
569       return failure();
570     while (getToken().is(Token::bare_identifier) &&
571            getTokenSpelling()[0] == 'x') {
572       if (failed(parseXInDimensionList()) || failed(parseDim()))
573         return failure();
574     }
575   }
576   return success();
577 }
578 
parseIntegerInDimensionList(int64_t & value)579 ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
580   // Hexadecimal integer literals (starting with `0x`) are not allowed in
581   // aggregate type declarations.  Therefore, `0xf32` should be processed as
582   // a sequence of separate elements `0`, `x`, `f32`.
583   if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
584     // We can get here only if the token is an integer literal.  Hexadecimal
585     // integer literals can only start with `0x` (`1x` wouldn't lex as a
586     // literal, just `1` would, at which point we don't get into this
587     // branch).
588     assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
589     value = 0;
590     state.lex.resetPointer(getTokenSpelling().data() + 1);
591     consumeToken();
592   } else {
593     // Make sure this integer value is in bound and valid.
594     Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
595     if (!dimension ||
596         *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
597       return emitError("invalid dimension");
598     value = (int64_t)*dimension;
599     consumeToken(Token::integer);
600   }
601   return success();
602 }
603 
604 /// Parse an 'x' token in a dimension list, handling the case where the x is
605 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
606 /// token.
parseXInDimensionList()607 ParseResult Parser::parseXInDimensionList() {
608   if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
609     return emitWrongTokenError("expected 'x' in dimension list");
610 
611   // If we had a prefix of 'x', lex the next token immediately after the 'x'.
612   if (getTokenSpelling().size() != 1)
613     state.lex.resetPointer(getTokenSpelling().data() + 1);
614 
615   // Consume the 'x'.
616   consumeToken(Token::bare_identifier);
617 
618   return success();
619 }
620 
621 // Parse a comma-separated list of dimensions, possibly empty:
622 //   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)623 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
624   return parseCommaSeparatedList(
625       Delimiter::Square,
626       [&]() -> ParseResult {
627         if (consumeIf(Token::question)) {
628           dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
629         } else {
630           // This must be an integer value.
631           int64_t val;
632           if (getToken().getSpelling().getAsInteger(10, val))
633             return emitError("invalid integer value: ")
634                    << getToken().getSpelling();
635           // Make sure it is not the one value for `?`.
636           if (ShapedType::isDynamic(val))
637             return emitError("invalid integer value: ")
638                    << getToken().getSpelling()
639                    << ", use `?` to specify a dynamic dimension";
640 
641           if (val == 0)
642             return emitError("invalid memref stride");
643 
644           dimensions.push_back(val);
645           consumeToken(Token::integer);
646         }
647         return success();
648       },
649       " in stride list");
650 }
651