1*c60b897dSRiver Riddle //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2*c60b897dSRiver Riddle //
3*c60b897dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c60b897dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*c60b897dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c60b897dSRiver Riddle //
7*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
8*c60b897dSRiver Riddle //
9*c60b897dSRiver Riddle // This file implements the parser for the MLIR Types.
10*c60b897dSRiver Riddle //
11*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
12*c60b897dSRiver Riddle
13*c60b897dSRiver Riddle #include "Parser.h"
14*c60b897dSRiver Riddle #include "mlir/IR/AffineMap.h"
15*c60b897dSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
16*c60b897dSRiver Riddle #include "mlir/IR/OpDefinition.h"
17*c60b897dSRiver Riddle #include "mlir/IR/TensorEncoding.h"
18*c60b897dSRiver Riddle
19*c60b897dSRiver Riddle using namespace mlir;
20*c60b897dSRiver Riddle using namespace mlir::detail;
21*c60b897dSRiver Riddle
22*c60b897dSRiver Riddle /// Optionally parse a type.
parseOptionalType(Type & type)23*c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalType(Type &type) {
24*c60b897dSRiver Riddle // There are many different starting tokens for a type, check them here.
25*c60b897dSRiver Riddle switch (getToken().getKind()) {
26*c60b897dSRiver Riddle case Token::l_paren:
27*c60b897dSRiver Riddle case Token::kw_memref:
28*c60b897dSRiver Riddle case Token::kw_tensor:
29*c60b897dSRiver Riddle case Token::kw_complex:
30*c60b897dSRiver Riddle case Token::kw_tuple:
31*c60b897dSRiver Riddle case Token::kw_vector:
32*c60b897dSRiver Riddle case Token::inttype:
33*c60b897dSRiver Riddle case Token::kw_bf16:
34*c60b897dSRiver Riddle case Token::kw_f16:
35*c60b897dSRiver Riddle case Token::kw_f32:
36*c60b897dSRiver Riddle case Token::kw_f64:
37*c60b897dSRiver Riddle case Token::kw_f80:
38*c60b897dSRiver Riddle case Token::kw_f128:
39*c60b897dSRiver Riddle case Token::kw_index:
40*c60b897dSRiver Riddle case Token::kw_none:
41*c60b897dSRiver Riddle case Token::exclamation_identifier:
42*c60b897dSRiver Riddle return failure(!(type = parseType()));
43*c60b897dSRiver Riddle
44*c60b897dSRiver Riddle default:
45*c60b897dSRiver Riddle return llvm::None;
46*c60b897dSRiver Riddle }
47*c60b897dSRiver Riddle }
48*c60b897dSRiver Riddle
49*c60b897dSRiver Riddle /// Parse an arbitrary type.
50*c60b897dSRiver Riddle ///
51*c60b897dSRiver Riddle /// type ::= function-type
52*c60b897dSRiver Riddle /// | non-function-type
53*c60b897dSRiver Riddle ///
parseType()54*c60b897dSRiver Riddle Type Parser::parseType() {
55*c60b897dSRiver Riddle if (getToken().is(Token::l_paren))
56*c60b897dSRiver Riddle return parseFunctionType();
57*c60b897dSRiver Riddle return parseNonFunctionType();
58*c60b897dSRiver Riddle }
59*c60b897dSRiver Riddle
60*c60b897dSRiver Riddle /// Parse a function result type.
61*c60b897dSRiver Riddle ///
62*c60b897dSRiver Riddle /// function-result-type ::= type-list-parens
63*c60b897dSRiver Riddle /// | non-function-type
64*c60b897dSRiver Riddle ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)65*c60b897dSRiver Riddle ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
66*c60b897dSRiver Riddle if (getToken().is(Token::l_paren))
67*c60b897dSRiver Riddle return parseTypeListParens(elements);
68*c60b897dSRiver Riddle
69*c60b897dSRiver Riddle Type t = parseNonFunctionType();
70*c60b897dSRiver Riddle if (!t)
71*c60b897dSRiver Riddle return failure();
72*c60b897dSRiver Riddle elements.push_back(t);
73*c60b897dSRiver Riddle return success();
74*c60b897dSRiver Riddle }
75*c60b897dSRiver Riddle
76*c60b897dSRiver Riddle /// Parse a list of types without an enclosing parenthesis. The list must have
77*c60b897dSRiver Riddle /// at least one member.
78*c60b897dSRiver Riddle ///
79*c60b897dSRiver Riddle /// type-list-no-parens ::= type (`,` type)*
80*c60b897dSRiver Riddle ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)81*c60b897dSRiver Riddle ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
82*c60b897dSRiver Riddle auto parseElt = [&]() -> ParseResult {
83*c60b897dSRiver Riddle auto elt = parseType();
84*c60b897dSRiver Riddle elements.push_back(elt);
85*c60b897dSRiver Riddle return elt ? success() : failure();
86*c60b897dSRiver Riddle };
87*c60b897dSRiver Riddle
88*c60b897dSRiver Riddle return parseCommaSeparatedList(parseElt);
89*c60b897dSRiver Riddle }
90*c60b897dSRiver Riddle
91*c60b897dSRiver Riddle /// Parse a parenthesized list of types.
92*c60b897dSRiver Riddle ///
93*c60b897dSRiver Riddle /// type-list-parens ::= `(` `)`
94*c60b897dSRiver Riddle /// | `(` type-list-no-parens `)`
95*c60b897dSRiver Riddle ///
parseTypeListParens(SmallVectorImpl<Type> & elements)96*c60b897dSRiver Riddle ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
97*c60b897dSRiver Riddle if (parseToken(Token::l_paren, "expected '('"))
98*c60b897dSRiver Riddle return failure();
99*c60b897dSRiver Riddle
100*c60b897dSRiver Riddle // Handle empty lists.
101*c60b897dSRiver Riddle if (getToken().is(Token::r_paren))
102*c60b897dSRiver Riddle return consumeToken(), success();
103*c60b897dSRiver Riddle
104*c60b897dSRiver Riddle if (parseTypeListNoParens(elements) ||
105*c60b897dSRiver Riddle parseToken(Token::r_paren, "expected ')'"))
106*c60b897dSRiver Riddle return failure();
107*c60b897dSRiver Riddle return success();
108*c60b897dSRiver Riddle }
109*c60b897dSRiver Riddle
110*c60b897dSRiver Riddle /// Parse a complex type.
111*c60b897dSRiver Riddle ///
112*c60b897dSRiver Riddle /// complex-type ::= `complex` `<` type `>`
113*c60b897dSRiver Riddle ///
parseComplexType()114*c60b897dSRiver Riddle Type Parser::parseComplexType() {
115*c60b897dSRiver Riddle consumeToken(Token::kw_complex);
116*c60b897dSRiver Riddle
117*c60b897dSRiver Riddle // Parse the '<'.
118*c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in complex type"))
119*c60b897dSRiver Riddle return nullptr;
120*c60b897dSRiver Riddle
121*c60b897dSRiver Riddle SMLoc elementTypeLoc = getToken().getLoc();
122*c60b897dSRiver Riddle auto elementType = parseType();
123*c60b897dSRiver Riddle if (!elementType ||
124*c60b897dSRiver Riddle parseToken(Token::greater, "expected '>' in complex type"))
125*c60b897dSRiver Riddle return nullptr;
126*c60b897dSRiver Riddle if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
127*c60b897dSRiver Riddle return emitError(elementTypeLoc, "invalid element type for complex"),
128*c60b897dSRiver Riddle nullptr;
129*c60b897dSRiver Riddle
130*c60b897dSRiver Riddle return ComplexType::get(elementType);
131*c60b897dSRiver Riddle }
132*c60b897dSRiver Riddle
133*c60b897dSRiver Riddle /// Parse a function type.
134*c60b897dSRiver Riddle ///
135*c60b897dSRiver Riddle /// function-type ::= type-list-parens `->` function-result-type
136*c60b897dSRiver Riddle ///
parseFunctionType()137*c60b897dSRiver Riddle Type Parser::parseFunctionType() {
138*c60b897dSRiver Riddle assert(getToken().is(Token::l_paren));
139*c60b897dSRiver Riddle
140*c60b897dSRiver Riddle SmallVector<Type, 4> arguments, results;
141*c60b897dSRiver Riddle if (parseTypeListParens(arguments) ||
142*c60b897dSRiver Riddle parseToken(Token::arrow, "expected '->' in function type") ||
143*c60b897dSRiver Riddle parseFunctionResultTypes(results))
144*c60b897dSRiver Riddle return nullptr;
145*c60b897dSRiver Riddle
146*c60b897dSRiver Riddle return builder.getFunctionType(arguments, results);
147*c60b897dSRiver Riddle }
148*c60b897dSRiver Riddle
149*c60b897dSRiver Riddle /// Parse the offset and strides from a strided layout specification.
150*c60b897dSRiver Riddle ///
151*c60b897dSRiver Riddle /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
152*c60b897dSRiver Riddle ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)153*c60b897dSRiver Riddle ParseResult Parser::parseStridedLayout(int64_t &offset,
154*c60b897dSRiver Riddle SmallVectorImpl<int64_t> &strides) {
155*c60b897dSRiver Riddle // Parse offset.
156*c60b897dSRiver Riddle consumeToken(Token::kw_offset);
157*c60b897dSRiver Riddle if (parseToken(Token::colon, "expected colon after `offset` keyword"))
158*c60b897dSRiver Riddle return failure();
159*c60b897dSRiver Riddle
160*c60b897dSRiver Riddle auto maybeOffset = getToken().getUnsignedIntegerValue();
161*c60b897dSRiver Riddle bool question = getToken().is(Token::question);
162*c60b897dSRiver Riddle if (!maybeOffset && !question)
163*c60b897dSRiver Riddle return emitWrongTokenError("invalid offset");
164*c60b897dSRiver Riddle offset = maybeOffset ? static_cast<int64_t>(*maybeOffset)
165*c60b897dSRiver Riddle : MemRefType::getDynamicStrideOrOffset();
166*c60b897dSRiver Riddle consumeToken();
167*c60b897dSRiver Riddle
168*c60b897dSRiver Riddle // Parse stride list.
169*c60b897dSRiver Riddle if (parseToken(Token::comma, "expected comma after offset value") ||
170*c60b897dSRiver Riddle parseToken(Token::kw_strides,
171*c60b897dSRiver Riddle "expected `strides` keyword after offset specification") ||
172*c60b897dSRiver Riddle parseToken(Token::colon, "expected colon after `strides` keyword") ||
173*c60b897dSRiver Riddle parseStrideList(strides))
174*c60b897dSRiver Riddle return failure();
175*c60b897dSRiver Riddle return success();
176*c60b897dSRiver Riddle }
177*c60b897dSRiver Riddle
178*c60b897dSRiver Riddle /// Parse a memref type.
179*c60b897dSRiver Riddle ///
180*c60b897dSRiver Riddle /// memref-type ::= ranked-memref-type | unranked-memref-type
181*c60b897dSRiver Riddle ///
182*c60b897dSRiver Riddle /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
183*c60b897dSRiver Riddle /// (`,` layout-specification)? (`,` memory-space)? `>`
184*c60b897dSRiver Riddle ///
185*c60b897dSRiver Riddle /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
186*c60b897dSRiver Riddle ///
187*c60b897dSRiver Riddle /// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
188*c60b897dSRiver Riddle /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
189*c60b897dSRiver Riddle /// layout-specification ::= semi-affine-map | strided-layout | attribute
190*c60b897dSRiver Riddle /// memory-space ::= integer-literal | attribute
191*c60b897dSRiver Riddle ///
parseMemRefType()192*c60b897dSRiver Riddle Type Parser::parseMemRefType() {
193*c60b897dSRiver Riddle SMLoc loc = getToken().getLoc();
194*c60b897dSRiver Riddle consumeToken(Token::kw_memref);
195*c60b897dSRiver Riddle
196*c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in memref type"))
197*c60b897dSRiver Riddle return nullptr;
198*c60b897dSRiver Riddle
199*c60b897dSRiver Riddle bool isUnranked;
200*c60b897dSRiver Riddle SmallVector<int64_t, 4> dimensions;
201*c60b897dSRiver Riddle
202*c60b897dSRiver Riddle if (consumeIf(Token::star)) {
203*c60b897dSRiver Riddle // This is an unranked memref type.
204*c60b897dSRiver Riddle isUnranked = true;
205*c60b897dSRiver Riddle if (parseXInDimensionList())
206*c60b897dSRiver Riddle return nullptr;
207*c60b897dSRiver Riddle
208*c60b897dSRiver Riddle } else {
209*c60b897dSRiver Riddle isUnranked = false;
210*c60b897dSRiver Riddle if (parseDimensionListRanked(dimensions))
211*c60b897dSRiver Riddle return nullptr;
212*c60b897dSRiver Riddle }
213*c60b897dSRiver Riddle
214*c60b897dSRiver Riddle // Parse the element type.
215*c60b897dSRiver Riddle auto typeLoc = getToken().getLoc();
216*c60b897dSRiver Riddle auto elementType = parseType();
217*c60b897dSRiver Riddle if (!elementType)
218*c60b897dSRiver Riddle return nullptr;
219*c60b897dSRiver Riddle
220*c60b897dSRiver Riddle // Check that memref is formed from allowed types.
221*c60b897dSRiver Riddle if (!BaseMemRefType::isValidElementType(elementType))
222*c60b897dSRiver Riddle return emitError(typeLoc, "invalid memref element type"), nullptr;
223*c60b897dSRiver Riddle
224*c60b897dSRiver Riddle MemRefLayoutAttrInterface layout;
225*c60b897dSRiver Riddle Attribute memorySpace;
226*c60b897dSRiver Riddle
227*c60b897dSRiver Riddle auto parseElt = [&]() -> ParseResult {
228*c60b897dSRiver Riddle // Check for AffineMap as offset/strides.
229*c60b897dSRiver Riddle if (getToken().is(Token::kw_offset)) {
230*c60b897dSRiver Riddle int64_t offset;
231*c60b897dSRiver Riddle SmallVector<int64_t, 4> strides;
232*c60b897dSRiver Riddle if (failed(parseStridedLayout(offset, strides)))
233*c60b897dSRiver Riddle return failure();
234*c60b897dSRiver Riddle // Construct strided affine map.
235*c60b897dSRiver Riddle AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext());
236*c60b897dSRiver Riddle layout = AffineMapAttr::get(map);
237*c60b897dSRiver Riddle } else {
238*c60b897dSRiver Riddle // Either it is MemRefLayoutAttrInterface or memory space attribute.
239*c60b897dSRiver Riddle Attribute attr = parseAttribute();
240*c60b897dSRiver Riddle if (!attr)
241*c60b897dSRiver Riddle return failure();
242*c60b897dSRiver Riddle
243*c60b897dSRiver Riddle if (attr.isa<MemRefLayoutAttrInterface>()) {
244*c60b897dSRiver Riddle layout = attr.cast<MemRefLayoutAttrInterface>();
245*c60b897dSRiver Riddle } else if (memorySpace) {
246*c60b897dSRiver Riddle return emitError("multiple memory spaces specified in memref type");
247*c60b897dSRiver Riddle } else {
248*c60b897dSRiver Riddle memorySpace = attr;
249*c60b897dSRiver Riddle return success();
250*c60b897dSRiver Riddle }
251*c60b897dSRiver Riddle }
252*c60b897dSRiver Riddle
253*c60b897dSRiver Riddle if (isUnranked)
254*c60b897dSRiver Riddle return emitError("cannot have affine map for unranked memref type");
255*c60b897dSRiver Riddle if (memorySpace)
256*c60b897dSRiver Riddle return emitError("expected memory space to be last in memref type");
257*c60b897dSRiver Riddle
258*c60b897dSRiver Riddle return success();
259*c60b897dSRiver Riddle };
260*c60b897dSRiver Riddle
261*c60b897dSRiver Riddle // Parse a list of mappings and address space if present.
262*c60b897dSRiver Riddle if (!consumeIf(Token::greater)) {
263*c60b897dSRiver Riddle // Parse comma separated list of affine maps, followed by memory space.
264*c60b897dSRiver Riddle if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
265*c60b897dSRiver Riddle parseCommaSeparatedListUntil(Token::greater, parseElt,
266*c60b897dSRiver Riddle /*allowEmptyList=*/false)) {
267*c60b897dSRiver Riddle return nullptr;
268*c60b897dSRiver Riddle }
269*c60b897dSRiver Riddle }
270*c60b897dSRiver Riddle
271*c60b897dSRiver Riddle if (isUnranked)
272*c60b897dSRiver Riddle return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
273*c60b897dSRiver Riddle
274*c60b897dSRiver Riddle return getChecked<MemRefType>(loc, dimensions, elementType, layout,
275*c60b897dSRiver Riddle memorySpace);
276*c60b897dSRiver Riddle }
277*c60b897dSRiver Riddle
278*c60b897dSRiver Riddle /// Parse any type except the function type.
279*c60b897dSRiver Riddle ///
280*c60b897dSRiver Riddle /// non-function-type ::= integer-type
281*c60b897dSRiver Riddle /// | index-type
282*c60b897dSRiver Riddle /// | float-type
283*c60b897dSRiver Riddle /// | extended-type
284*c60b897dSRiver Riddle /// | vector-type
285*c60b897dSRiver Riddle /// | tensor-type
286*c60b897dSRiver Riddle /// | memref-type
287*c60b897dSRiver Riddle /// | complex-type
288*c60b897dSRiver Riddle /// | tuple-type
289*c60b897dSRiver Riddle /// | none-type
290*c60b897dSRiver Riddle ///
291*c60b897dSRiver Riddle /// index-type ::= `index`
292*c60b897dSRiver Riddle /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
293*c60b897dSRiver Riddle /// none-type ::= `none`
294*c60b897dSRiver Riddle ///
parseNonFunctionType()295*c60b897dSRiver Riddle Type Parser::parseNonFunctionType() {
296*c60b897dSRiver Riddle switch (getToken().getKind()) {
297*c60b897dSRiver Riddle default:
298*c60b897dSRiver Riddle return (emitWrongTokenError("expected non-function type"), nullptr);
299*c60b897dSRiver Riddle case Token::kw_memref:
300*c60b897dSRiver Riddle return parseMemRefType();
301*c60b897dSRiver Riddle case Token::kw_tensor:
302*c60b897dSRiver Riddle return parseTensorType();
303*c60b897dSRiver Riddle case Token::kw_complex:
304*c60b897dSRiver Riddle return parseComplexType();
305*c60b897dSRiver Riddle case Token::kw_tuple:
306*c60b897dSRiver Riddle return parseTupleType();
307*c60b897dSRiver Riddle case Token::kw_vector:
308*c60b897dSRiver Riddle return parseVectorType();
309*c60b897dSRiver Riddle // integer-type
310*c60b897dSRiver Riddle case Token::inttype: {
311*c60b897dSRiver Riddle auto width = getToken().getIntTypeBitwidth();
312*c60b897dSRiver Riddle if (!width.has_value())
313*c60b897dSRiver Riddle return (emitError("invalid integer width"), nullptr);
314*c60b897dSRiver Riddle if (width.value() > IntegerType::kMaxWidth) {
315*c60b897dSRiver Riddle emitError(getToken().getLoc(), "integer bitwidth is limited to ")
316*c60b897dSRiver Riddle << IntegerType::kMaxWidth << " bits";
317*c60b897dSRiver Riddle return nullptr;
318*c60b897dSRiver Riddle }
319*c60b897dSRiver Riddle
320*c60b897dSRiver Riddle IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
321*c60b897dSRiver Riddle if (Optional<bool> signedness = getToken().getIntTypeSignedness())
322*c60b897dSRiver Riddle signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
323*c60b897dSRiver Riddle
324*c60b897dSRiver Riddle consumeToken(Token::inttype);
325*c60b897dSRiver Riddle return IntegerType::get(getContext(), *width, signSemantics);
326*c60b897dSRiver Riddle }
327*c60b897dSRiver Riddle
328*c60b897dSRiver Riddle // float-type
329*c60b897dSRiver Riddle case Token::kw_bf16:
330*c60b897dSRiver Riddle consumeToken(Token::kw_bf16);
331*c60b897dSRiver Riddle return builder.getBF16Type();
332*c60b897dSRiver Riddle case Token::kw_f16:
333*c60b897dSRiver Riddle consumeToken(Token::kw_f16);
334*c60b897dSRiver Riddle return builder.getF16Type();
335*c60b897dSRiver Riddle case Token::kw_f32:
336*c60b897dSRiver Riddle consumeToken(Token::kw_f32);
337*c60b897dSRiver Riddle return builder.getF32Type();
338*c60b897dSRiver Riddle case Token::kw_f64:
339*c60b897dSRiver Riddle consumeToken(Token::kw_f64);
340*c60b897dSRiver Riddle return builder.getF64Type();
341*c60b897dSRiver Riddle case Token::kw_f80:
342*c60b897dSRiver Riddle consumeToken(Token::kw_f80);
343*c60b897dSRiver Riddle return builder.getF80Type();
344*c60b897dSRiver Riddle case Token::kw_f128:
345*c60b897dSRiver Riddle consumeToken(Token::kw_f128);
346*c60b897dSRiver Riddle return builder.getF128Type();
347*c60b897dSRiver Riddle
348*c60b897dSRiver Riddle // index-type
349*c60b897dSRiver Riddle case Token::kw_index:
350*c60b897dSRiver Riddle consumeToken(Token::kw_index);
351*c60b897dSRiver Riddle return builder.getIndexType();
352*c60b897dSRiver Riddle
353*c60b897dSRiver Riddle // none-type
354*c60b897dSRiver Riddle case Token::kw_none:
355*c60b897dSRiver Riddle consumeToken(Token::kw_none);
356*c60b897dSRiver Riddle return builder.getNoneType();
357*c60b897dSRiver Riddle
358*c60b897dSRiver Riddle // extended type
359*c60b897dSRiver Riddle case Token::exclamation_identifier:
360*c60b897dSRiver Riddle return parseExtendedType();
361*c60b897dSRiver Riddle
362*c60b897dSRiver Riddle // Handle completion of a dialect type.
363*c60b897dSRiver Riddle case Token::code_complete:
364*c60b897dSRiver Riddle if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
365*c60b897dSRiver Riddle return parseExtendedType();
366*c60b897dSRiver Riddle return codeCompleteType();
367*c60b897dSRiver Riddle }
368*c60b897dSRiver Riddle }
369*c60b897dSRiver Riddle
370*c60b897dSRiver Riddle /// Parse a tensor type.
371*c60b897dSRiver Riddle ///
372*c60b897dSRiver Riddle /// tensor-type ::= `tensor` `<` dimension-list type `>`
373*c60b897dSRiver Riddle /// dimension-list ::= dimension-list-ranked | `*x`
374*c60b897dSRiver Riddle ///
parseTensorType()375*c60b897dSRiver Riddle Type Parser::parseTensorType() {
376*c60b897dSRiver Riddle consumeToken(Token::kw_tensor);
377*c60b897dSRiver Riddle
378*c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in tensor type"))
379*c60b897dSRiver Riddle return nullptr;
380*c60b897dSRiver Riddle
381*c60b897dSRiver Riddle bool isUnranked;
382*c60b897dSRiver Riddle SmallVector<int64_t, 4> dimensions;
383*c60b897dSRiver Riddle
384*c60b897dSRiver Riddle if (consumeIf(Token::star)) {
385*c60b897dSRiver Riddle // This is an unranked tensor type.
386*c60b897dSRiver Riddle isUnranked = true;
387*c60b897dSRiver Riddle
388*c60b897dSRiver Riddle if (parseXInDimensionList())
389*c60b897dSRiver Riddle return nullptr;
390*c60b897dSRiver Riddle
391*c60b897dSRiver Riddle } else {
392*c60b897dSRiver Riddle isUnranked = false;
393*c60b897dSRiver Riddle if (parseDimensionListRanked(dimensions))
394*c60b897dSRiver Riddle return nullptr;
395*c60b897dSRiver Riddle }
396*c60b897dSRiver Riddle
397*c60b897dSRiver Riddle // Parse the element type.
398*c60b897dSRiver Riddle auto elementTypeLoc = getToken().getLoc();
399*c60b897dSRiver Riddle auto elementType = parseType();
400*c60b897dSRiver Riddle
401*c60b897dSRiver Riddle // Parse an optional encoding attribute.
402*c60b897dSRiver Riddle Attribute encoding;
403*c60b897dSRiver Riddle if (consumeIf(Token::comma)) {
404*c60b897dSRiver Riddle encoding = parseAttribute();
405*c60b897dSRiver Riddle if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
406*c60b897dSRiver Riddle if (failed(v.verifyEncoding(dimensions, elementType,
407*c60b897dSRiver Riddle [&] { return emitError(); })))
408*c60b897dSRiver Riddle return nullptr;
409*c60b897dSRiver Riddle }
410*c60b897dSRiver Riddle }
411*c60b897dSRiver Riddle
412*c60b897dSRiver Riddle if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
413*c60b897dSRiver Riddle return nullptr;
414*c60b897dSRiver Riddle if (!TensorType::isValidElementType(elementType))
415*c60b897dSRiver Riddle return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
416*c60b897dSRiver Riddle
417*c60b897dSRiver Riddle if (isUnranked) {
418*c60b897dSRiver Riddle if (encoding)
419*c60b897dSRiver Riddle return emitError("cannot apply encoding to unranked tensor"), nullptr;
420*c60b897dSRiver Riddle return UnrankedTensorType::get(elementType);
421*c60b897dSRiver Riddle }
422*c60b897dSRiver Riddle return RankedTensorType::get(dimensions, elementType, encoding);
423*c60b897dSRiver Riddle }
424*c60b897dSRiver Riddle
425*c60b897dSRiver Riddle /// Parse a tuple type.
426*c60b897dSRiver Riddle ///
427*c60b897dSRiver Riddle /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
428*c60b897dSRiver Riddle ///
parseTupleType()429*c60b897dSRiver Riddle Type Parser::parseTupleType() {
430*c60b897dSRiver Riddle consumeToken(Token::kw_tuple);
431*c60b897dSRiver Riddle
432*c60b897dSRiver Riddle // Parse the '<'.
433*c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in tuple type"))
434*c60b897dSRiver Riddle return nullptr;
435*c60b897dSRiver Riddle
436*c60b897dSRiver Riddle // Check for an empty tuple by directly parsing '>'.
437*c60b897dSRiver Riddle if (consumeIf(Token::greater))
438*c60b897dSRiver Riddle return TupleType::get(getContext());
439*c60b897dSRiver Riddle
440*c60b897dSRiver Riddle // Parse the element types and the '>'.
441*c60b897dSRiver Riddle SmallVector<Type, 4> types;
442*c60b897dSRiver Riddle if (parseTypeListNoParens(types) ||
443*c60b897dSRiver Riddle parseToken(Token::greater, "expected '>' in tuple type"))
444*c60b897dSRiver Riddle return nullptr;
445*c60b897dSRiver Riddle
446*c60b897dSRiver Riddle return TupleType::get(getContext(), types);
447*c60b897dSRiver Riddle }
448*c60b897dSRiver Riddle
449*c60b897dSRiver Riddle /// Parse a vector type.
450*c60b897dSRiver Riddle ///
451*c60b897dSRiver Riddle /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
452*c60b897dSRiver Riddle /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
453*c60b897dSRiver Riddle /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
454*c60b897dSRiver Riddle ///
parseVectorType()455*c60b897dSRiver Riddle VectorType Parser::parseVectorType() {
456*c60b897dSRiver Riddle consumeToken(Token::kw_vector);
457*c60b897dSRiver Riddle
458*c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in vector type"))
459*c60b897dSRiver Riddle return nullptr;
460*c60b897dSRiver Riddle
461*c60b897dSRiver Riddle SmallVector<int64_t, 4> dimensions;
462*c60b897dSRiver Riddle unsigned numScalableDims;
463*c60b897dSRiver Riddle if (parseVectorDimensionList(dimensions, numScalableDims))
464*c60b897dSRiver Riddle return nullptr;
465*c60b897dSRiver Riddle if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
466*c60b897dSRiver Riddle return emitError(getToken().getLoc(),
467*c60b897dSRiver Riddle "vector types must have positive constant sizes"),
468*c60b897dSRiver Riddle nullptr;
469*c60b897dSRiver Riddle
470*c60b897dSRiver Riddle // Parse the element type.
471*c60b897dSRiver Riddle auto typeLoc = getToken().getLoc();
472*c60b897dSRiver Riddle auto elementType = parseType();
473*c60b897dSRiver Riddle if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
474*c60b897dSRiver Riddle return nullptr;
475*c60b897dSRiver Riddle
476*c60b897dSRiver Riddle if (!VectorType::isValidElementType(elementType))
477*c60b897dSRiver Riddle return emitError(typeLoc, "vector elements must be int/index/float type"),
478*c60b897dSRiver Riddle nullptr;
479*c60b897dSRiver Riddle
480*c60b897dSRiver Riddle return VectorType::get(dimensions, elementType, numScalableDims);
481*c60b897dSRiver Riddle }
482*c60b897dSRiver Riddle
483*c60b897dSRiver Riddle /// Parse a dimension list in a vector type. This populates the dimension list,
484*c60b897dSRiver Riddle /// and returns the number of scalable dimensions in `numScalableDims`.
485*c60b897dSRiver Riddle ///
486*c60b897dSRiver Riddle /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
487*c60b897dSRiver Riddle /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
488*c60b897dSRiver Riddle ///
489*c60b897dSRiver Riddle ParseResult
parseVectorDimensionList(SmallVectorImpl<int64_t> & dimensions,unsigned & numScalableDims)490*c60b897dSRiver Riddle Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
491*c60b897dSRiver Riddle unsigned &numScalableDims) {
492*c60b897dSRiver Riddle numScalableDims = 0;
493*c60b897dSRiver Riddle // If there is a set of fixed-length dimensions, consume it
494*c60b897dSRiver Riddle while (getToken().is(Token::integer)) {
495*c60b897dSRiver Riddle int64_t value;
496*c60b897dSRiver Riddle if (parseIntegerInDimensionList(value))
497*c60b897dSRiver Riddle return failure();
498*c60b897dSRiver Riddle dimensions.push_back(value);
499*c60b897dSRiver Riddle // Make sure we have an 'x' or something like 'xbf32'.
500*c60b897dSRiver Riddle if (parseXInDimensionList())
501*c60b897dSRiver Riddle return failure();
502*c60b897dSRiver Riddle }
503*c60b897dSRiver Riddle // If there is a set of scalable dimensions, consume it
504*c60b897dSRiver Riddle if (consumeIf(Token::l_square)) {
505*c60b897dSRiver Riddle while (getToken().is(Token::integer)) {
506*c60b897dSRiver Riddle int64_t value;
507*c60b897dSRiver Riddle if (parseIntegerInDimensionList(value))
508*c60b897dSRiver Riddle return failure();
509*c60b897dSRiver Riddle dimensions.push_back(value);
510*c60b897dSRiver Riddle numScalableDims++;
511*c60b897dSRiver Riddle // Check if we have reached the end of the scalable dimension list
512*c60b897dSRiver Riddle if (consumeIf(Token::r_square)) {
513*c60b897dSRiver Riddle // Make sure we have something like 'xbf32'.
514*c60b897dSRiver Riddle return parseXInDimensionList();
515*c60b897dSRiver Riddle }
516*c60b897dSRiver Riddle // Make sure we have an 'x'
517*c60b897dSRiver Riddle if (parseXInDimensionList())
518*c60b897dSRiver Riddle return failure();
519*c60b897dSRiver Riddle }
520*c60b897dSRiver Riddle // If we make it here, we've finished parsing the dimension list
521*c60b897dSRiver Riddle // without finding ']' closing the set of scalable dimensions
522*c60b897dSRiver Riddle return emitWrongTokenError(
523*c60b897dSRiver Riddle "missing ']' closing set of scalable dimensions");
524*c60b897dSRiver Riddle }
525*c60b897dSRiver Riddle
526*c60b897dSRiver Riddle return success();
527*c60b897dSRiver Riddle }
528*c60b897dSRiver Riddle
529*c60b897dSRiver Riddle /// Parse a dimension list of a tensor or memref type. This populates the
530*c60b897dSRiver Riddle /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
531*c60b897dSRiver Riddle /// errors out on `?` otherwise. Parsing the trailing `x` is configurable.
532*c60b897dSRiver Riddle ///
533*c60b897dSRiver Riddle /// dimension-list ::= eps | dimension (`x` dimension)*
534*c60b897dSRiver Riddle /// dimension-list-with-trailing-x ::= (dimension `x`)*
535*c60b897dSRiver Riddle /// dimension ::= `?` | decimal-literal
536*c60b897dSRiver Riddle ///
537*c60b897dSRiver Riddle /// When `allowDynamic` is not set, this is used to parse:
538*c60b897dSRiver Riddle ///
539*c60b897dSRiver Riddle /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
540*c60b897dSRiver Riddle /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
541*c60b897dSRiver Riddle ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic,bool withTrailingX)542*c60b897dSRiver Riddle Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
543*c60b897dSRiver Riddle bool allowDynamic, bool withTrailingX) {
544*c60b897dSRiver Riddle auto parseDim = [&]() -> LogicalResult {
545*c60b897dSRiver Riddle auto loc = getToken().getLoc();
546*c60b897dSRiver Riddle if (consumeIf(Token::question)) {
547*c60b897dSRiver Riddle if (!allowDynamic)
548*c60b897dSRiver Riddle return emitError(loc, "expected static shape");
549*c60b897dSRiver Riddle dimensions.push_back(-1);
550*c60b897dSRiver Riddle } else {
551*c60b897dSRiver Riddle int64_t value;
552*c60b897dSRiver Riddle if (failed(parseIntegerInDimensionList(value)))
553*c60b897dSRiver Riddle return failure();
554*c60b897dSRiver Riddle dimensions.push_back(value);
555*c60b897dSRiver Riddle }
556*c60b897dSRiver Riddle return success();
557*c60b897dSRiver Riddle };
558*c60b897dSRiver Riddle
559*c60b897dSRiver Riddle if (withTrailingX) {
560*c60b897dSRiver Riddle while (getToken().isAny(Token::integer, Token::question)) {
561*c60b897dSRiver Riddle if (failed(parseDim()) || failed(parseXInDimensionList()))
562*c60b897dSRiver Riddle return failure();
563*c60b897dSRiver Riddle }
564*c60b897dSRiver Riddle return success();
565*c60b897dSRiver Riddle }
566*c60b897dSRiver Riddle
567*c60b897dSRiver Riddle if (getToken().isAny(Token::integer, Token::question)) {
568*c60b897dSRiver Riddle if (failed(parseDim()))
569*c60b897dSRiver Riddle return failure();
570*c60b897dSRiver Riddle while (getToken().is(Token::bare_identifier) &&
571*c60b897dSRiver Riddle getTokenSpelling()[0] == 'x') {
572*c60b897dSRiver Riddle if (failed(parseXInDimensionList()) || failed(parseDim()))
573*c60b897dSRiver Riddle return failure();
574*c60b897dSRiver Riddle }
575*c60b897dSRiver Riddle }
576*c60b897dSRiver Riddle return success();
577*c60b897dSRiver Riddle }
578*c60b897dSRiver Riddle
parseIntegerInDimensionList(int64_t & value)579*c60b897dSRiver Riddle ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
580*c60b897dSRiver Riddle // Hexadecimal integer literals (starting with `0x`) are not allowed in
581*c60b897dSRiver Riddle // aggregate type declarations. Therefore, `0xf32` should be processed as
582*c60b897dSRiver Riddle // a sequence of separate elements `0`, `x`, `f32`.
583*c60b897dSRiver Riddle if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
584*c60b897dSRiver Riddle // We can get here only if the token is an integer literal. Hexadecimal
585*c60b897dSRiver Riddle // integer literals can only start with `0x` (`1x` wouldn't lex as a
586*c60b897dSRiver Riddle // literal, just `1` would, at which point we don't get into this
587*c60b897dSRiver Riddle // branch).
588*c60b897dSRiver Riddle assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
589*c60b897dSRiver Riddle value = 0;
590*c60b897dSRiver Riddle state.lex.resetPointer(getTokenSpelling().data() + 1);
591*c60b897dSRiver Riddle consumeToken();
592*c60b897dSRiver Riddle } else {
593*c60b897dSRiver Riddle // Make sure this integer value is in bound and valid.
594*c60b897dSRiver Riddle Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
595*c60b897dSRiver Riddle if (!dimension ||
596*c60b897dSRiver Riddle *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
597*c60b897dSRiver Riddle return emitError("invalid dimension");
598*c60b897dSRiver Riddle value = (int64_t)*dimension;
599*c60b897dSRiver Riddle consumeToken(Token::integer);
600*c60b897dSRiver Riddle }
601*c60b897dSRiver Riddle return success();
602*c60b897dSRiver Riddle }
603*c60b897dSRiver Riddle
604*c60b897dSRiver Riddle /// Parse an 'x' token in a dimension list, handling the case where the x is
605*c60b897dSRiver Riddle /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
606*c60b897dSRiver Riddle /// token.
parseXInDimensionList()607*c60b897dSRiver Riddle ParseResult Parser::parseXInDimensionList() {
608*c60b897dSRiver Riddle if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
609*c60b897dSRiver Riddle return emitWrongTokenError("expected 'x' in dimension list");
610*c60b897dSRiver Riddle
611*c60b897dSRiver Riddle // If we had a prefix of 'x', lex the next token immediately after the 'x'.
612*c60b897dSRiver Riddle if (getTokenSpelling().size() != 1)
613*c60b897dSRiver Riddle state.lex.resetPointer(getTokenSpelling().data() + 1);
614*c60b897dSRiver Riddle
615*c60b897dSRiver Riddle // Consume the 'x'.
616*c60b897dSRiver Riddle consumeToken(Token::bare_identifier);
617*c60b897dSRiver Riddle
618*c60b897dSRiver Riddle return success();
619*c60b897dSRiver Riddle }
620*c60b897dSRiver Riddle
621*c60b897dSRiver Riddle // Parse a comma-separated list of dimensions, possibly empty:
622*c60b897dSRiver Riddle // stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)623*c60b897dSRiver Riddle ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
624*c60b897dSRiver Riddle return parseCommaSeparatedList(
625*c60b897dSRiver Riddle Delimiter::Square,
626*c60b897dSRiver Riddle [&]() -> ParseResult {
627*c60b897dSRiver Riddle if (consumeIf(Token::question)) {
628*c60b897dSRiver Riddle dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
629*c60b897dSRiver Riddle } else {
630*c60b897dSRiver Riddle // This must be an integer value.
631*c60b897dSRiver Riddle int64_t val;
632*c60b897dSRiver Riddle if (getToken().getSpelling().getAsInteger(10, val))
633*c60b897dSRiver Riddle return emitError("invalid integer value: ")
634*c60b897dSRiver Riddle << getToken().getSpelling();
635*c60b897dSRiver Riddle // Make sure it is not the one value for `?`.
636*c60b897dSRiver Riddle if (ShapedType::isDynamic(val))
637*c60b897dSRiver Riddle return emitError("invalid integer value: ")
638*c60b897dSRiver Riddle << getToken().getSpelling()
639*c60b897dSRiver Riddle << ", use `?` to specify a dynamic dimension";
640*c60b897dSRiver Riddle
641*c60b897dSRiver Riddle if (val == 0)
642*c60b897dSRiver Riddle return emitError("invalid memref stride");
643*c60b897dSRiver Riddle
644*c60b897dSRiver Riddle dimensions.push_back(val);
645*c60b897dSRiver Riddle consumeToken(Token::integer);
646*c60b897dSRiver Riddle }
647*c60b897dSRiver Riddle return success();
648*c60b897dSRiver Riddle },
649*c60b897dSRiver Riddle " in stride list");
650*c60b897dSRiver Riddle }
651