1*c60b897dSRiver Riddle //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2*c60b897dSRiver Riddle //
3*c60b897dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c60b897dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*c60b897dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c60b897dSRiver Riddle //
7*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
8*c60b897dSRiver Riddle //
9*c60b897dSRiver Riddle // This file implements the parser for the MLIR Types.
10*c60b897dSRiver Riddle //
11*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
12*c60b897dSRiver Riddle 
13*c60b897dSRiver Riddle #include "Parser.h"
14*c60b897dSRiver Riddle #include "mlir/IR/AffineMap.h"
15*c60b897dSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
16*c60b897dSRiver Riddle #include "mlir/IR/OpDefinition.h"
17*c60b897dSRiver Riddle #include "mlir/IR/TensorEncoding.h"
18*c60b897dSRiver Riddle 
19*c60b897dSRiver Riddle using namespace mlir;
20*c60b897dSRiver Riddle using namespace mlir::detail;
21*c60b897dSRiver Riddle 
22*c60b897dSRiver Riddle /// Optionally parse a type.
parseOptionalType(Type & type)23*c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalType(Type &type) {
24*c60b897dSRiver Riddle   // There are many different starting tokens for a type, check them here.
25*c60b897dSRiver Riddle   switch (getToken().getKind()) {
26*c60b897dSRiver Riddle   case Token::l_paren:
27*c60b897dSRiver Riddle   case Token::kw_memref:
28*c60b897dSRiver Riddle   case Token::kw_tensor:
29*c60b897dSRiver Riddle   case Token::kw_complex:
30*c60b897dSRiver Riddle   case Token::kw_tuple:
31*c60b897dSRiver Riddle   case Token::kw_vector:
32*c60b897dSRiver Riddle   case Token::inttype:
33*c60b897dSRiver Riddle   case Token::kw_bf16:
34*c60b897dSRiver Riddle   case Token::kw_f16:
35*c60b897dSRiver Riddle   case Token::kw_f32:
36*c60b897dSRiver Riddle   case Token::kw_f64:
37*c60b897dSRiver Riddle   case Token::kw_f80:
38*c60b897dSRiver Riddle   case Token::kw_f128:
39*c60b897dSRiver Riddle   case Token::kw_index:
40*c60b897dSRiver Riddle   case Token::kw_none:
41*c60b897dSRiver Riddle   case Token::exclamation_identifier:
42*c60b897dSRiver Riddle     return failure(!(type = parseType()));
43*c60b897dSRiver Riddle 
44*c60b897dSRiver Riddle   default:
45*c60b897dSRiver Riddle     return llvm::None;
46*c60b897dSRiver Riddle   }
47*c60b897dSRiver Riddle }
48*c60b897dSRiver Riddle 
49*c60b897dSRiver Riddle /// Parse an arbitrary type.
50*c60b897dSRiver Riddle ///
51*c60b897dSRiver Riddle ///   type ::= function-type
52*c60b897dSRiver Riddle ///          | non-function-type
53*c60b897dSRiver Riddle ///
parseType()54*c60b897dSRiver Riddle Type Parser::parseType() {
55*c60b897dSRiver Riddle   if (getToken().is(Token::l_paren))
56*c60b897dSRiver Riddle     return parseFunctionType();
57*c60b897dSRiver Riddle   return parseNonFunctionType();
58*c60b897dSRiver Riddle }
59*c60b897dSRiver Riddle 
60*c60b897dSRiver Riddle /// Parse a function result type.
61*c60b897dSRiver Riddle ///
62*c60b897dSRiver Riddle ///   function-result-type ::= type-list-parens
63*c60b897dSRiver Riddle ///                          | non-function-type
64*c60b897dSRiver Riddle ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)65*c60b897dSRiver Riddle ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
66*c60b897dSRiver Riddle   if (getToken().is(Token::l_paren))
67*c60b897dSRiver Riddle     return parseTypeListParens(elements);
68*c60b897dSRiver Riddle 
69*c60b897dSRiver Riddle   Type t = parseNonFunctionType();
70*c60b897dSRiver Riddle   if (!t)
71*c60b897dSRiver Riddle     return failure();
72*c60b897dSRiver Riddle   elements.push_back(t);
73*c60b897dSRiver Riddle   return success();
74*c60b897dSRiver Riddle }
75*c60b897dSRiver Riddle 
76*c60b897dSRiver Riddle /// Parse a list of types without an enclosing parenthesis.  The list must have
77*c60b897dSRiver Riddle /// at least one member.
78*c60b897dSRiver Riddle ///
79*c60b897dSRiver Riddle ///   type-list-no-parens ::=  type (`,` type)*
80*c60b897dSRiver Riddle ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)81*c60b897dSRiver Riddle ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
82*c60b897dSRiver Riddle   auto parseElt = [&]() -> ParseResult {
83*c60b897dSRiver Riddle     auto elt = parseType();
84*c60b897dSRiver Riddle     elements.push_back(elt);
85*c60b897dSRiver Riddle     return elt ? success() : failure();
86*c60b897dSRiver Riddle   };
87*c60b897dSRiver Riddle 
88*c60b897dSRiver Riddle   return parseCommaSeparatedList(parseElt);
89*c60b897dSRiver Riddle }
90*c60b897dSRiver Riddle 
91*c60b897dSRiver Riddle /// Parse a parenthesized list of types.
92*c60b897dSRiver Riddle ///
93*c60b897dSRiver Riddle ///   type-list-parens ::= `(` `)`
94*c60b897dSRiver Riddle ///                      | `(` type-list-no-parens `)`
95*c60b897dSRiver Riddle ///
parseTypeListParens(SmallVectorImpl<Type> & elements)96*c60b897dSRiver Riddle ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
97*c60b897dSRiver Riddle   if (parseToken(Token::l_paren, "expected '('"))
98*c60b897dSRiver Riddle     return failure();
99*c60b897dSRiver Riddle 
100*c60b897dSRiver Riddle   // Handle empty lists.
101*c60b897dSRiver Riddle   if (getToken().is(Token::r_paren))
102*c60b897dSRiver Riddle     return consumeToken(), success();
103*c60b897dSRiver Riddle 
104*c60b897dSRiver Riddle   if (parseTypeListNoParens(elements) ||
105*c60b897dSRiver Riddle       parseToken(Token::r_paren, "expected ')'"))
106*c60b897dSRiver Riddle     return failure();
107*c60b897dSRiver Riddle   return success();
108*c60b897dSRiver Riddle }
109*c60b897dSRiver Riddle 
110*c60b897dSRiver Riddle /// Parse a complex type.
111*c60b897dSRiver Riddle ///
112*c60b897dSRiver Riddle ///   complex-type ::= `complex` `<` type `>`
113*c60b897dSRiver Riddle ///
parseComplexType()114*c60b897dSRiver Riddle Type Parser::parseComplexType() {
115*c60b897dSRiver Riddle   consumeToken(Token::kw_complex);
116*c60b897dSRiver Riddle 
117*c60b897dSRiver Riddle   // Parse the '<'.
118*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' in complex type"))
119*c60b897dSRiver Riddle     return nullptr;
120*c60b897dSRiver Riddle 
121*c60b897dSRiver Riddle   SMLoc elementTypeLoc = getToken().getLoc();
122*c60b897dSRiver Riddle   auto elementType = parseType();
123*c60b897dSRiver Riddle   if (!elementType ||
124*c60b897dSRiver Riddle       parseToken(Token::greater, "expected '>' in complex type"))
125*c60b897dSRiver Riddle     return nullptr;
126*c60b897dSRiver Riddle   if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
127*c60b897dSRiver Riddle     return emitError(elementTypeLoc, "invalid element type for complex"),
128*c60b897dSRiver Riddle            nullptr;
129*c60b897dSRiver Riddle 
130*c60b897dSRiver Riddle   return ComplexType::get(elementType);
131*c60b897dSRiver Riddle }
132*c60b897dSRiver Riddle 
133*c60b897dSRiver Riddle /// Parse a function type.
134*c60b897dSRiver Riddle ///
135*c60b897dSRiver Riddle ///   function-type ::= type-list-parens `->` function-result-type
136*c60b897dSRiver Riddle ///
parseFunctionType()137*c60b897dSRiver Riddle Type Parser::parseFunctionType() {
138*c60b897dSRiver Riddle   assert(getToken().is(Token::l_paren));
139*c60b897dSRiver Riddle 
140*c60b897dSRiver Riddle   SmallVector<Type, 4> arguments, results;
141*c60b897dSRiver Riddle   if (parseTypeListParens(arguments) ||
142*c60b897dSRiver Riddle       parseToken(Token::arrow, "expected '->' in function type") ||
143*c60b897dSRiver Riddle       parseFunctionResultTypes(results))
144*c60b897dSRiver Riddle     return nullptr;
145*c60b897dSRiver Riddle 
146*c60b897dSRiver Riddle   return builder.getFunctionType(arguments, results);
147*c60b897dSRiver Riddle }
148*c60b897dSRiver Riddle 
149*c60b897dSRiver Riddle /// Parse the offset and strides from a strided layout specification.
150*c60b897dSRiver Riddle ///
151*c60b897dSRiver Riddle ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
152*c60b897dSRiver Riddle ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)153*c60b897dSRiver Riddle ParseResult Parser::parseStridedLayout(int64_t &offset,
154*c60b897dSRiver Riddle                                        SmallVectorImpl<int64_t> &strides) {
155*c60b897dSRiver Riddle   // Parse offset.
156*c60b897dSRiver Riddle   consumeToken(Token::kw_offset);
157*c60b897dSRiver Riddle   if (parseToken(Token::colon, "expected colon after `offset` keyword"))
158*c60b897dSRiver Riddle     return failure();
159*c60b897dSRiver Riddle 
160*c60b897dSRiver Riddle   auto maybeOffset = getToken().getUnsignedIntegerValue();
161*c60b897dSRiver Riddle   bool question = getToken().is(Token::question);
162*c60b897dSRiver Riddle   if (!maybeOffset && !question)
163*c60b897dSRiver Riddle     return emitWrongTokenError("invalid offset");
164*c60b897dSRiver Riddle   offset = maybeOffset ? static_cast<int64_t>(*maybeOffset)
165*c60b897dSRiver Riddle                        : MemRefType::getDynamicStrideOrOffset();
166*c60b897dSRiver Riddle   consumeToken();
167*c60b897dSRiver Riddle 
168*c60b897dSRiver Riddle   // Parse stride list.
169*c60b897dSRiver Riddle   if (parseToken(Token::comma, "expected comma after offset value") ||
170*c60b897dSRiver Riddle       parseToken(Token::kw_strides,
171*c60b897dSRiver Riddle                  "expected `strides` keyword after offset specification") ||
172*c60b897dSRiver Riddle       parseToken(Token::colon, "expected colon after `strides` keyword") ||
173*c60b897dSRiver Riddle       parseStrideList(strides))
174*c60b897dSRiver Riddle     return failure();
175*c60b897dSRiver Riddle   return success();
176*c60b897dSRiver Riddle }
177*c60b897dSRiver Riddle 
178*c60b897dSRiver Riddle /// Parse a memref type.
179*c60b897dSRiver Riddle ///
180*c60b897dSRiver Riddle ///   memref-type ::= ranked-memref-type | unranked-memref-type
181*c60b897dSRiver Riddle ///
182*c60b897dSRiver Riddle ///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
183*c60b897dSRiver Riddle ///                          (`,` layout-specification)? (`,` memory-space)? `>`
184*c60b897dSRiver Riddle ///
185*c60b897dSRiver Riddle ///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
186*c60b897dSRiver Riddle ///
187*c60b897dSRiver Riddle ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
188*c60b897dSRiver Riddle ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
189*c60b897dSRiver Riddle ///   layout-specification ::= semi-affine-map | strided-layout | attribute
190*c60b897dSRiver Riddle ///   memory-space ::= integer-literal | attribute
191*c60b897dSRiver Riddle ///
parseMemRefType()192*c60b897dSRiver Riddle Type Parser::parseMemRefType() {
193*c60b897dSRiver Riddle   SMLoc loc = getToken().getLoc();
194*c60b897dSRiver Riddle   consumeToken(Token::kw_memref);
195*c60b897dSRiver Riddle 
196*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' in memref type"))
197*c60b897dSRiver Riddle     return nullptr;
198*c60b897dSRiver Riddle 
199*c60b897dSRiver Riddle   bool isUnranked;
200*c60b897dSRiver Riddle   SmallVector<int64_t, 4> dimensions;
201*c60b897dSRiver Riddle 
202*c60b897dSRiver Riddle   if (consumeIf(Token::star)) {
203*c60b897dSRiver Riddle     // This is an unranked memref type.
204*c60b897dSRiver Riddle     isUnranked = true;
205*c60b897dSRiver Riddle     if (parseXInDimensionList())
206*c60b897dSRiver Riddle       return nullptr;
207*c60b897dSRiver Riddle 
208*c60b897dSRiver Riddle   } else {
209*c60b897dSRiver Riddle     isUnranked = false;
210*c60b897dSRiver Riddle     if (parseDimensionListRanked(dimensions))
211*c60b897dSRiver Riddle       return nullptr;
212*c60b897dSRiver Riddle   }
213*c60b897dSRiver Riddle 
214*c60b897dSRiver Riddle   // Parse the element type.
215*c60b897dSRiver Riddle   auto typeLoc = getToken().getLoc();
216*c60b897dSRiver Riddle   auto elementType = parseType();
217*c60b897dSRiver Riddle   if (!elementType)
218*c60b897dSRiver Riddle     return nullptr;
219*c60b897dSRiver Riddle 
220*c60b897dSRiver Riddle   // Check that memref is formed from allowed types.
221*c60b897dSRiver Riddle   if (!BaseMemRefType::isValidElementType(elementType))
222*c60b897dSRiver Riddle     return emitError(typeLoc, "invalid memref element type"), nullptr;
223*c60b897dSRiver Riddle 
224*c60b897dSRiver Riddle   MemRefLayoutAttrInterface layout;
225*c60b897dSRiver Riddle   Attribute memorySpace;
226*c60b897dSRiver Riddle 
227*c60b897dSRiver Riddle   auto parseElt = [&]() -> ParseResult {
228*c60b897dSRiver Riddle     // Check for AffineMap as offset/strides.
229*c60b897dSRiver Riddle     if (getToken().is(Token::kw_offset)) {
230*c60b897dSRiver Riddle       int64_t offset;
231*c60b897dSRiver Riddle       SmallVector<int64_t, 4> strides;
232*c60b897dSRiver Riddle       if (failed(parseStridedLayout(offset, strides)))
233*c60b897dSRiver Riddle         return failure();
234*c60b897dSRiver Riddle       // Construct strided affine map.
235*c60b897dSRiver Riddle       AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext());
236*c60b897dSRiver Riddle       layout = AffineMapAttr::get(map);
237*c60b897dSRiver Riddle     } else {
238*c60b897dSRiver Riddle       // Either it is MemRefLayoutAttrInterface or memory space attribute.
239*c60b897dSRiver Riddle       Attribute attr = parseAttribute();
240*c60b897dSRiver Riddle       if (!attr)
241*c60b897dSRiver Riddle         return failure();
242*c60b897dSRiver Riddle 
243*c60b897dSRiver Riddle       if (attr.isa<MemRefLayoutAttrInterface>()) {
244*c60b897dSRiver Riddle         layout = attr.cast<MemRefLayoutAttrInterface>();
245*c60b897dSRiver Riddle       } else if (memorySpace) {
246*c60b897dSRiver Riddle         return emitError("multiple memory spaces specified in memref type");
247*c60b897dSRiver Riddle       } else {
248*c60b897dSRiver Riddle         memorySpace = attr;
249*c60b897dSRiver Riddle         return success();
250*c60b897dSRiver Riddle       }
251*c60b897dSRiver Riddle     }
252*c60b897dSRiver Riddle 
253*c60b897dSRiver Riddle     if (isUnranked)
254*c60b897dSRiver Riddle       return emitError("cannot have affine map for unranked memref type");
255*c60b897dSRiver Riddle     if (memorySpace)
256*c60b897dSRiver Riddle       return emitError("expected memory space to be last in memref type");
257*c60b897dSRiver Riddle 
258*c60b897dSRiver Riddle     return success();
259*c60b897dSRiver Riddle   };
260*c60b897dSRiver Riddle 
261*c60b897dSRiver Riddle   // Parse a list of mappings and address space if present.
262*c60b897dSRiver Riddle   if (!consumeIf(Token::greater)) {
263*c60b897dSRiver Riddle     // Parse comma separated list of affine maps, followed by memory space.
264*c60b897dSRiver Riddle     if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
265*c60b897dSRiver Riddle         parseCommaSeparatedListUntil(Token::greater, parseElt,
266*c60b897dSRiver Riddle                                      /*allowEmptyList=*/false)) {
267*c60b897dSRiver Riddle       return nullptr;
268*c60b897dSRiver Riddle     }
269*c60b897dSRiver Riddle   }
270*c60b897dSRiver Riddle 
271*c60b897dSRiver Riddle   if (isUnranked)
272*c60b897dSRiver Riddle     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
273*c60b897dSRiver Riddle 
274*c60b897dSRiver Riddle   return getChecked<MemRefType>(loc, dimensions, elementType, layout,
275*c60b897dSRiver Riddle                                 memorySpace);
276*c60b897dSRiver Riddle }
277*c60b897dSRiver Riddle 
278*c60b897dSRiver Riddle /// Parse any type except the function type.
279*c60b897dSRiver Riddle ///
280*c60b897dSRiver Riddle ///   non-function-type ::= integer-type
281*c60b897dSRiver Riddle ///                       | index-type
282*c60b897dSRiver Riddle ///                       | float-type
283*c60b897dSRiver Riddle ///                       | extended-type
284*c60b897dSRiver Riddle ///                       | vector-type
285*c60b897dSRiver Riddle ///                       | tensor-type
286*c60b897dSRiver Riddle ///                       | memref-type
287*c60b897dSRiver Riddle ///                       | complex-type
288*c60b897dSRiver Riddle ///                       | tuple-type
289*c60b897dSRiver Riddle ///                       | none-type
290*c60b897dSRiver Riddle ///
291*c60b897dSRiver Riddle ///   index-type ::= `index`
292*c60b897dSRiver Riddle ///   float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
293*c60b897dSRiver Riddle ///   none-type ::= `none`
294*c60b897dSRiver Riddle ///
parseNonFunctionType()295*c60b897dSRiver Riddle Type Parser::parseNonFunctionType() {
296*c60b897dSRiver Riddle   switch (getToken().getKind()) {
297*c60b897dSRiver Riddle   default:
298*c60b897dSRiver Riddle     return (emitWrongTokenError("expected non-function type"), nullptr);
299*c60b897dSRiver Riddle   case Token::kw_memref:
300*c60b897dSRiver Riddle     return parseMemRefType();
301*c60b897dSRiver Riddle   case Token::kw_tensor:
302*c60b897dSRiver Riddle     return parseTensorType();
303*c60b897dSRiver Riddle   case Token::kw_complex:
304*c60b897dSRiver Riddle     return parseComplexType();
305*c60b897dSRiver Riddle   case Token::kw_tuple:
306*c60b897dSRiver Riddle     return parseTupleType();
307*c60b897dSRiver Riddle   case Token::kw_vector:
308*c60b897dSRiver Riddle     return parseVectorType();
309*c60b897dSRiver Riddle   // integer-type
310*c60b897dSRiver Riddle   case Token::inttype: {
311*c60b897dSRiver Riddle     auto width = getToken().getIntTypeBitwidth();
312*c60b897dSRiver Riddle     if (!width.has_value())
313*c60b897dSRiver Riddle       return (emitError("invalid integer width"), nullptr);
314*c60b897dSRiver Riddle     if (width.value() > IntegerType::kMaxWidth) {
315*c60b897dSRiver Riddle       emitError(getToken().getLoc(), "integer bitwidth is limited to ")
316*c60b897dSRiver Riddle           << IntegerType::kMaxWidth << " bits";
317*c60b897dSRiver Riddle       return nullptr;
318*c60b897dSRiver Riddle     }
319*c60b897dSRiver Riddle 
320*c60b897dSRiver Riddle     IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
321*c60b897dSRiver Riddle     if (Optional<bool> signedness = getToken().getIntTypeSignedness())
322*c60b897dSRiver Riddle       signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
323*c60b897dSRiver Riddle 
324*c60b897dSRiver Riddle     consumeToken(Token::inttype);
325*c60b897dSRiver Riddle     return IntegerType::get(getContext(), *width, signSemantics);
326*c60b897dSRiver Riddle   }
327*c60b897dSRiver Riddle 
328*c60b897dSRiver Riddle   // float-type
329*c60b897dSRiver Riddle   case Token::kw_bf16:
330*c60b897dSRiver Riddle     consumeToken(Token::kw_bf16);
331*c60b897dSRiver Riddle     return builder.getBF16Type();
332*c60b897dSRiver Riddle   case Token::kw_f16:
333*c60b897dSRiver Riddle     consumeToken(Token::kw_f16);
334*c60b897dSRiver Riddle     return builder.getF16Type();
335*c60b897dSRiver Riddle   case Token::kw_f32:
336*c60b897dSRiver Riddle     consumeToken(Token::kw_f32);
337*c60b897dSRiver Riddle     return builder.getF32Type();
338*c60b897dSRiver Riddle   case Token::kw_f64:
339*c60b897dSRiver Riddle     consumeToken(Token::kw_f64);
340*c60b897dSRiver Riddle     return builder.getF64Type();
341*c60b897dSRiver Riddle   case Token::kw_f80:
342*c60b897dSRiver Riddle     consumeToken(Token::kw_f80);
343*c60b897dSRiver Riddle     return builder.getF80Type();
344*c60b897dSRiver Riddle   case Token::kw_f128:
345*c60b897dSRiver Riddle     consumeToken(Token::kw_f128);
346*c60b897dSRiver Riddle     return builder.getF128Type();
347*c60b897dSRiver Riddle 
348*c60b897dSRiver Riddle   // index-type
349*c60b897dSRiver Riddle   case Token::kw_index:
350*c60b897dSRiver Riddle     consumeToken(Token::kw_index);
351*c60b897dSRiver Riddle     return builder.getIndexType();
352*c60b897dSRiver Riddle 
353*c60b897dSRiver Riddle   // none-type
354*c60b897dSRiver Riddle   case Token::kw_none:
355*c60b897dSRiver Riddle     consumeToken(Token::kw_none);
356*c60b897dSRiver Riddle     return builder.getNoneType();
357*c60b897dSRiver Riddle 
358*c60b897dSRiver Riddle   // extended type
359*c60b897dSRiver Riddle   case Token::exclamation_identifier:
360*c60b897dSRiver Riddle     return parseExtendedType();
361*c60b897dSRiver Riddle 
362*c60b897dSRiver Riddle   // Handle completion of a dialect type.
363*c60b897dSRiver Riddle   case Token::code_complete:
364*c60b897dSRiver Riddle     if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
365*c60b897dSRiver Riddle       return parseExtendedType();
366*c60b897dSRiver Riddle     return codeCompleteType();
367*c60b897dSRiver Riddle   }
368*c60b897dSRiver Riddle }
369*c60b897dSRiver Riddle 
370*c60b897dSRiver Riddle /// Parse a tensor type.
371*c60b897dSRiver Riddle ///
372*c60b897dSRiver Riddle ///   tensor-type ::= `tensor` `<` dimension-list type `>`
373*c60b897dSRiver Riddle ///   dimension-list ::= dimension-list-ranked | `*x`
374*c60b897dSRiver Riddle ///
parseTensorType()375*c60b897dSRiver Riddle Type Parser::parseTensorType() {
376*c60b897dSRiver Riddle   consumeToken(Token::kw_tensor);
377*c60b897dSRiver Riddle 
378*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' in tensor type"))
379*c60b897dSRiver Riddle     return nullptr;
380*c60b897dSRiver Riddle 
381*c60b897dSRiver Riddle   bool isUnranked;
382*c60b897dSRiver Riddle   SmallVector<int64_t, 4> dimensions;
383*c60b897dSRiver Riddle 
384*c60b897dSRiver Riddle   if (consumeIf(Token::star)) {
385*c60b897dSRiver Riddle     // This is an unranked tensor type.
386*c60b897dSRiver Riddle     isUnranked = true;
387*c60b897dSRiver Riddle 
388*c60b897dSRiver Riddle     if (parseXInDimensionList())
389*c60b897dSRiver Riddle       return nullptr;
390*c60b897dSRiver Riddle 
391*c60b897dSRiver Riddle   } else {
392*c60b897dSRiver Riddle     isUnranked = false;
393*c60b897dSRiver Riddle     if (parseDimensionListRanked(dimensions))
394*c60b897dSRiver Riddle       return nullptr;
395*c60b897dSRiver Riddle   }
396*c60b897dSRiver Riddle 
397*c60b897dSRiver Riddle   // Parse the element type.
398*c60b897dSRiver Riddle   auto elementTypeLoc = getToken().getLoc();
399*c60b897dSRiver Riddle   auto elementType = parseType();
400*c60b897dSRiver Riddle 
401*c60b897dSRiver Riddle   // Parse an optional encoding attribute.
402*c60b897dSRiver Riddle   Attribute encoding;
403*c60b897dSRiver Riddle   if (consumeIf(Token::comma)) {
404*c60b897dSRiver Riddle     encoding = parseAttribute();
405*c60b897dSRiver Riddle     if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
406*c60b897dSRiver Riddle       if (failed(v.verifyEncoding(dimensions, elementType,
407*c60b897dSRiver Riddle                                   [&] { return emitError(); })))
408*c60b897dSRiver Riddle         return nullptr;
409*c60b897dSRiver Riddle     }
410*c60b897dSRiver Riddle   }
411*c60b897dSRiver Riddle 
412*c60b897dSRiver Riddle   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
413*c60b897dSRiver Riddle     return nullptr;
414*c60b897dSRiver Riddle   if (!TensorType::isValidElementType(elementType))
415*c60b897dSRiver Riddle     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
416*c60b897dSRiver Riddle 
417*c60b897dSRiver Riddle   if (isUnranked) {
418*c60b897dSRiver Riddle     if (encoding)
419*c60b897dSRiver Riddle       return emitError("cannot apply encoding to unranked tensor"), nullptr;
420*c60b897dSRiver Riddle     return UnrankedTensorType::get(elementType);
421*c60b897dSRiver Riddle   }
422*c60b897dSRiver Riddle   return RankedTensorType::get(dimensions, elementType, encoding);
423*c60b897dSRiver Riddle }
424*c60b897dSRiver Riddle 
425*c60b897dSRiver Riddle /// Parse a tuple type.
426*c60b897dSRiver Riddle ///
427*c60b897dSRiver Riddle ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
428*c60b897dSRiver Riddle ///
parseTupleType()429*c60b897dSRiver Riddle Type Parser::parseTupleType() {
430*c60b897dSRiver Riddle   consumeToken(Token::kw_tuple);
431*c60b897dSRiver Riddle 
432*c60b897dSRiver Riddle   // Parse the '<'.
433*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' in tuple type"))
434*c60b897dSRiver Riddle     return nullptr;
435*c60b897dSRiver Riddle 
436*c60b897dSRiver Riddle   // Check for an empty tuple by directly parsing '>'.
437*c60b897dSRiver Riddle   if (consumeIf(Token::greater))
438*c60b897dSRiver Riddle     return TupleType::get(getContext());
439*c60b897dSRiver Riddle 
440*c60b897dSRiver Riddle   // Parse the element types and the '>'.
441*c60b897dSRiver Riddle   SmallVector<Type, 4> types;
442*c60b897dSRiver Riddle   if (parseTypeListNoParens(types) ||
443*c60b897dSRiver Riddle       parseToken(Token::greater, "expected '>' in tuple type"))
444*c60b897dSRiver Riddle     return nullptr;
445*c60b897dSRiver Riddle 
446*c60b897dSRiver Riddle   return TupleType::get(getContext(), types);
447*c60b897dSRiver Riddle }
448*c60b897dSRiver Riddle 
449*c60b897dSRiver Riddle /// Parse a vector type.
450*c60b897dSRiver Riddle ///
451*c60b897dSRiver Riddle /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
452*c60b897dSRiver Riddle /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
453*c60b897dSRiver Riddle /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
454*c60b897dSRiver Riddle ///
parseVectorType()455*c60b897dSRiver Riddle VectorType Parser::parseVectorType() {
456*c60b897dSRiver Riddle   consumeToken(Token::kw_vector);
457*c60b897dSRiver Riddle 
458*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' in vector type"))
459*c60b897dSRiver Riddle     return nullptr;
460*c60b897dSRiver Riddle 
461*c60b897dSRiver Riddle   SmallVector<int64_t, 4> dimensions;
462*c60b897dSRiver Riddle   unsigned numScalableDims;
463*c60b897dSRiver Riddle   if (parseVectorDimensionList(dimensions, numScalableDims))
464*c60b897dSRiver Riddle     return nullptr;
465*c60b897dSRiver Riddle   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
466*c60b897dSRiver Riddle     return emitError(getToken().getLoc(),
467*c60b897dSRiver Riddle                      "vector types must have positive constant sizes"),
468*c60b897dSRiver Riddle            nullptr;
469*c60b897dSRiver Riddle 
470*c60b897dSRiver Riddle   // Parse the element type.
471*c60b897dSRiver Riddle   auto typeLoc = getToken().getLoc();
472*c60b897dSRiver Riddle   auto elementType = parseType();
473*c60b897dSRiver Riddle   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
474*c60b897dSRiver Riddle     return nullptr;
475*c60b897dSRiver Riddle 
476*c60b897dSRiver Riddle   if (!VectorType::isValidElementType(elementType))
477*c60b897dSRiver Riddle     return emitError(typeLoc, "vector elements must be int/index/float type"),
478*c60b897dSRiver Riddle            nullptr;
479*c60b897dSRiver Riddle 
480*c60b897dSRiver Riddle   return VectorType::get(dimensions, elementType, numScalableDims);
481*c60b897dSRiver Riddle }
482*c60b897dSRiver Riddle 
483*c60b897dSRiver Riddle /// Parse a dimension list in a vector type. This populates the dimension list,
484*c60b897dSRiver Riddle /// and returns the number of scalable dimensions in `numScalableDims`.
485*c60b897dSRiver Riddle ///
486*c60b897dSRiver Riddle /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
487*c60b897dSRiver Riddle /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
488*c60b897dSRiver Riddle ///
489*c60b897dSRiver Riddle ParseResult
parseVectorDimensionList(SmallVectorImpl<int64_t> & dimensions,unsigned & numScalableDims)490*c60b897dSRiver Riddle Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
491*c60b897dSRiver Riddle                                  unsigned &numScalableDims) {
492*c60b897dSRiver Riddle   numScalableDims = 0;
493*c60b897dSRiver Riddle   // If there is a set of fixed-length dimensions, consume it
494*c60b897dSRiver Riddle   while (getToken().is(Token::integer)) {
495*c60b897dSRiver Riddle     int64_t value;
496*c60b897dSRiver Riddle     if (parseIntegerInDimensionList(value))
497*c60b897dSRiver Riddle       return failure();
498*c60b897dSRiver Riddle     dimensions.push_back(value);
499*c60b897dSRiver Riddle     // Make sure we have an 'x' or something like 'xbf32'.
500*c60b897dSRiver Riddle     if (parseXInDimensionList())
501*c60b897dSRiver Riddle       return failure();
502*c60b897dSRiver Riddle   }
503*c60b897dSRiver Riddle   // If there is a set of scalable dimensions, consume it
504*c60b897dSRiver Riddle   if (consumeIf(Token::l_square)) {
505*c60b897dSRiver Riddle     while (getToken().is(Token::integer)) {
506*c60b897dSRiver Riddle       int64_t value;
507*c60b897dSRiver Riddle       if (parseIntegerInDimensionList(value))
508*c60b897dSRiver Riddle         return failure();
509*c60b897dSRiver Riddle       dimensions.push_back(value);
510*c60b897dSRiver Riddle       numScalableDims++;
511*c60b897dSRiver Riddle       // Check if we have reached the end of the scalable dimension list
512*c60b897dSRiver Riddle       if (consumeIf(Token::r_square)) {
513*c60b897dSRiver Riddle         // Make sure we have something like 'xbf32'.
514*c60b897dSRiver Riddle         return parseXInDimensionList();
515*c60b897dSRiver Riddle       }
516*c60b897dSRiver Riddle       // Make sure we have an 'x'
517*c60b897dSRiver Riddle       if (parseXInDimensionList())
518*c60b897dSRiver Riddle         return failure();
519*c60b897dSRiver Riddle     }
520*c60b897dSRiver Riddle     // If we make it here, we've finished parsing the dimension list
521*c60b897dSRiver Riddle     // without finding ']' closing the set of scalable dimensions
522*c60b897dSRiver Riddle     return emitWrongTokenError(
523*c60b897dSRiver Riddle         "missing ']' closing set of scalable dimensions");
524*c60b897dSRiver Riddle   }
525*c60b897dSRiver Riddle 
526*c60b897dSRiver Riddle   return success();
527*c60b897dSRiver Riddle }
528*c60b897dSRiver Riddle 
529*c60b897dSRiver Riddle /// Parse a dimension list of a tensor or memref type.  This populates the
530*c60b897dSRiver Riddle /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
531*c60b897dSRiver Riddle /// errors out on `?` otherwise. Parsing the trailing `x` is configurable.
532*c60b897dSRiver Riddle ///
533*c60b897dSRiver Riddle ///   dimension-list ::= eps | dimension (`x` dimension)*
534*c60b897dSRiver Riddle ///   dimension-list-with-trailing-x ::= (dimension `x`)*
535*c60b897dSRiver Riddle ///   dimension ::= `?` | decimal-literal
536*c60b897dSRiver Riddle ///
537*c60b897dSRiver Riddle /// When `allowDynamic` is not set, this is used to parse:
538*c60b897dSRiver Riddle ///
539*c60b897dSRiver Riddle ///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
540*c60b897dSRiver Riddle ///   static-dimension-list-with-trailing-x ::= (dimension `x`)*
541*c60b897dSRiver Riddle ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic,bool withTrailingX)542*c60b897dSRiver Riddle Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
543*c60b897dSRiver Riddle                                  bool allowDynamic, bool withTrailingX) {
544*c60b897dSRiver Riddle   auto parseDim = [&]() -> LogicalResult {
545*c60b897dSRiver Riddle     auto loc = getToken().getLoc();
546*c60b897dSRiver Riddle     if (consumeIf(Token::question)) {
547*c60b897dSRiver Riddle       if (!allowDynamic)
548*c60b897dSRiver Riddle         return emitError(loc, "expected static shape");
549*c60b897dSRiver Riddle       dimensions.push_back(-1);
550*c60b897dSRiver Riddle     } else {
551*c60b897dSRiver Riddle       int64_t value;
552*c60b897dSRiver Riddle       if (failed(parseIntegerInDimensionList(value)))
553*c60b897dSRiver Riddle         return failure();
554*c60b897dSRiver Riddle       dimensions.push_back(value);
555*c60b897dSRiver Riddle     }
556*c60b897dSRiver Riddle     return success();
557*c60b897dSRiver Riddle   };
558*c60b897dSRiver Riddle 
559*c60b897dSRiver Riddle   if (withTrailingX) {
560*c60b897dSRiver Riddle     while (getToken().isAny(Token::integer, Token::question)) {
561*c60b897dSRiver Riddle       if (failed(parseDim()) || failed(parseXInDimensionList()))
562*c60b897dSRiver Riddle         return failure();
563*c60b897dSRiver Riddle     }
564*c60b897dSRiver Riddle     return success();
565*c60b897dSRiver Riddle   }
566*c60b897dSRiver Riddle 
567*c60b897dSRiver Riddle   if (getToken().isAny(Token::integer, Token::question)) {
568*c60b897dSRiver Riddle     if (failed(parseDim()))
569*c60b897dSRiver Riddle       return failure();
570*c60b897dSRiver Riddle     while (getToken().is(Token::bare_identifier) &&
571*c60b897dSRiver Riddle            getTokenSpelling()[0] == 'x') {
572*c60b897dSRiver Riddle       if (failed(parseXInDimensionList()) || failed(parseDim()))
573*c60b897dSRiver Riddle         return failure();
574*c60b897dSRiver Riddle     }
575*c60b897dSRiver Riddle   }
576*c60b897dSRiver Riddle   return success();
577*c60b897dSRiver Riddle }
578*c60b897dSRiver Riddle 
parseIntegerInDimensionList(int64_t & value)579*c60b897dSRiver Riddle ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
580*c60b897dSRiver Riddle   // Hexadecimal integer literals (starting with `0x`) are not allowed in
581*c60b897dSRiver Riddle   // aggregate type declarations.  Therefore, `0xf32` should be processed as
582*c60b897dSRiver Riddle   // a sequence of separate elements `0`, `x`, `f32`.
583*c60b897dSRiver Riddle   if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
584*c60b897dSRiver Riddle     // We can get here only if the token is an integer literal.  Hexadecimal
585*c60b897dSRiver Riddle     // integer literals can only start with `0x` (`1x` wouldn't lex as a
586*c60b897dSRiver Riddle     // literal, just `1` would, at which point we don't get into this
587*c60b897dSRiver Riddle     // branch).
588*c60b897dSRiver Riddle     assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
589*c60b897dSRiver Riddle     value = 0;
590*c60b897dSRiver Riddle     state.lex.resetPointer(getTokenSpelling().data() + 1);
591*c60b897dSRiver Riddle     consumeToken();
592*c60b897dSRiver Riddle   } else {
593*c60b897dSRiver Riddle     // Make sure this integer value is in bound and valid.
594*c60b897dSRiver Riddle     Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
595*c60b897dSRiver Riddle     if (!dimension ||
596*c60b897dSRiver Riddle         *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
597*c60b897dSRiver Riddle       return emitError("invalid dimension");
598*c60b897dSRiver Riddle     value = (int64_t)*dimension;
599*c60b897dSRiver Riddle     consumeToken(Token::integer);
600*c60b897dSRiver Riddle   }
601*c60b897dSRiver Riddle   return success();
602*c60b897dSRiver Riddle }
603*c60b897dSRiver Riddle 
604*c60b897dSRiver Riddle /// Parse an 'x' token in a dimension list, handling the case where the x is
605*c60b897dSRiver Riddle /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
606*c60b897dSRiver Riddle /// token.
parseXInDimensionList()607*c60b897dSRiver Riddle ParseResult Parser::parseXInDimensionList() {
608*c60b897dSRiver Riddle   if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
609*c60b897dSRiver Riddle     return emitWrongTokenError("expected 'x' in dimension list");
610*c60b897dSRiver Riddle 
611*c60b897dSRiver Riddle   // If we had a prefix of 'x', lex the next token immediately after the 'x'.
612*c60b897dSRiver Riddle   if (getTokenSpelling().size() != 1)
613*c60b897dSRiver Riddle     state.lex.resetPointer(getTokenSpelling().data() + 1);
614*c60b897dSRiver Riddle 
615*c60b897dSRiver Riddle   // Consume the 'x'.
616*c60b897dSRiver Riddle   consumeToken(Token::bare_identifier);
617*c60b897dSRiver Riddle 
618*c60b897dSRiver Riddle   return success();
619*c60b897dSRiver Riddle }
620*c60b897dSRiver Riddle 
621*c60b897dSRiver Riddle // Parse a comma-separated list of dimensions, possibly empty:
622*c60b897dSRiver Riddle //   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)623*c60b897dSRiver Riddle ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
624*c60b897dSRiver Riddle   return parseCommaSeparatedList(
625*c60b897dSRiver Riddle       Delimiter::Square,
626*c60b897dSRiver Riddle       [&]() -> ParseResult {
627*c60b897dSRiver Riddle         if (consumeIf(Token::question)) {
628*c60b897dSRiver Riddle           dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
629*c60b897dSRiver Riddle         } else {
630*c60b897dSRiver Riddle           // This must be an integer value.
631*c60b897dSRiver Riddle           int64_t val;
632*c60b897dSRiver Riddle           if (getToken().getSpelling().getAsInteger(10, val))
633*c60b897dSRiver Riddle             return emitError("invalid integer value: ")
634*c60b897dSRiver Riddle                    << getToken().getSpelling();
635*c60b897dSRiver Riddle           // Make sure it is not the one value for `?`.
636*c60b897dSRiver Riddle           if (ShapedType::isDynamic(val))
637*c60b897dSRiver Riddle             return emitError("invalid integer value: ")
638*c60b897dSRiver Riddle                    << getToken().getSpelling()
639*c60b897dSRiver Riddle                    << ", use `?` to specify a dynamic dimension";
640*c60b897dSRiver Riddle 
641*c60b897dSRiver Riddle           if (val == 0)
642*c60b897dSRiver Riddle             return emitError("invalid memref stride");
643*c60b897dSRiver Riddle 
644*c60b897dSRiver Riddle           dimensions.push_back(val);
645*c60b897dSRiver Riddle           consumeToken(Token::integer);
646*c60b897dSRiver Riddle         }
647*c60b897dSRiver Riddle         return success();
648*c60b897dSRiver Riddle       },
649*c60b897dSRiver Riddle       " in stride list");
650*c60b897dSRiver Riddle }
651