1*c60b897dSRiver Riddle //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2*c60b897dSRiver Riddle //
3*c60b897dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c60b897dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*c60b897dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c60b897dSRiver Riddle //
7*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
8*c60b897dSRiver Riddle //
// This file implements the parser for the MLIR Attributes.
10*c60b897dSRiver Riddle //
11*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
12*c60b897dSRiver Riddle 
13*c60b897dSRiver Riddle #include "Parser.h"
14*c60b897dSRiver Riddle 
15*c60b897dSRiver Riddle #include "AsmParserImpl.h"
16*c60b897dSRiver Riddle #include "mlir/AsmParser/AsmParserState.h"
17*c60b897dSRiver Riddle #include "mlir/IR/AffineMap.h"
18*c60b897dSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
19*c60b897dSRiver Riddle #include "mlir/IR/Dialect.h"
20*c60b897dSRiver Riddle #include "mlir/IR/DialectImplementation.h"
21*c60b897dSRiver Riddle #include "mlir/IR/IntegerSet.h"
22*c60b897dSRiver Riddle #include "llvm/ADT/StringExtras.h"
23*c60b897dSRiver Riddle #include "llvm/Support/Endian.h"
24*c60b897dSRiver Riddle 
25*c60b897dSRiver Riddle using namespace mlir;
26*c60b897dSRiver Riddle using namespace mlir::detail;
27*c60b897dSRiver Riddle 
28*c60b897dSRiver Riddle /// Parse an arbitrary attribute.
29*c60b897dSRiver Riddle ///
30*c60b897dSRiver Riddle ///  attribute-value ::= `unit`
31*c60b897dSRiver Riddle ///                    | bool-literal
32*c60b897dSRiver Riddle ///                    | integer-literal (`:` (index-type | integer-type))?
33*c60b897dSRiver Riddle ///                    | float-literal (`:` float-type)?
34*c60b897dSRiver Riddle ///                    | string-literal (`:` type)?
35*c60b897dSRiver Riddle ///                    | type
36*c60b897dSRiver Riddle ///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
37*c60b897dSRiver Riddle ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
38*c60b897dSRiver Riddle ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
39*c60b897dSRiver Riddle ///                    | symbol-ref-id (`::` symbol-ref-id)*
40*c60b897dSRiver Riddle ///                    | `dense` `<` tensor-literal `>` `:`
41*c60b897dSRiver Riddle ///                      (tensor-type | vector-type)
42*c60b897dSRiver Riddle ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
43*c60b897dSRiver Riddle ///                      `:` (tensor-type | vector-type)
44*c60b897dSRiver Riddle ///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
45*c60b897dSRiver Riddle ///                      `>` `:` (tensor-type | vector-type)
46*c60b897dSRiver Riddle ///                    | extended-attribute
47*c60b897dSRiver Riddle ///
parseAttribute(Type type)48*c60b897dSRiver Riddle Attribute Parser::parseAttribute(Type type) {
49*c60b897dSRiver Riddle   switch (getToken().getKind()) {
50*c60b897dSRiver Riddle   // Parse an AffineMap or IntegerSet attribute.
51*c60b897dSRiver Riddle   case Token::kw_affine_map: {
52*c60b897dSRiver Riddle     consumeToken(Token::kw_affine_map);
53*c60b897dSRiver Riddle 
54*c60b897dSRiver Riddle     AffineMap map;
55*c60b897dSRiver Riddle     if (parseToken(Token::less, "expected '<' in affine map") ||
56*c60b897dSRiver Riddle         parseAffineMapReference(map) ||
57*c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>' in affine map"))
58*c60b897dSRiver Riddle       return Attribute();
59*c60b897dSRiver Riddle     return AffineMapAttr::get(map);
60*c60b897dSRiver Riddle   }
61*c60b897dSRiver Riddle   case Token::kw_affine_set: {
62*c60b897dSRiver Riddle     consumeToken(Token::kw_affine_set);
63*c60b897dSRiver Riddle 
64*c60b897dSRiver Riddle     IntegerSet set;
65*c60b897dSRiver Riddle     if (parseToken(Token::less, "expected '<' in integer set") ||
66*c60b897dSRiver Riddle         parseIntegerSetReference(set) ||
67*c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>' in integer set"))
68*c60b897dSRiver Riddle       return Attribute();
69*c60b897dSRiver Riddle     return IntegerSetAttr::get(set);
70*c60b897dSRiver Riddle   }
71*c60b897dSRiver Riddle 
72*c60b897dSRiver Riddle   // Parse an array attribute.
73*c60b897dSRiver Riddle   case Token::l_square: {
74*c60b897dSRiver Riddle     consumeToken(Token::l_square);
75*c60b897dSRiver Riddle     if (consumeIf(Token::colon))
76*c60b897dSRiver Riddle       return parseDenseArrayAttr();
77*c60b897dSRiver Riddle     SmallVector<Attribute, 4> elements;
78*c60b897dSRiver Riddle     auto parseElt = [&]() -> ParseResult {
79*c60b897dSRiver Riddle       elements.push_back(parseAttribute());
80*c60b897dSRiver Riddle       return elements.back() ? success() : failure();
81*c60b897dSRiver Riddle     };
82*c60b897dSRiver Riddle 
83*c60b897dSRiver Riddle     if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
84*c60b897dSRiver Riddle       return nullptr;
85*c60b897dSRiver Riddle     return builder.getArrayAttr(elements);
86*c60b897dSRiver Riddle   }
87*c60b897dSRiver Riddle 
88*c60b897dSRiver Riddle   // Parse a boolean attribute.
89*c60b897dSRiver Riddle   case Token::kw_false:
90*c60b897dSRiver Riddle     consumeToken(Token::kw_false);
91*c60b897dSRiver Riddle     return builder.getBoolAttr(false);
92*c60b897dSRiver Riddle   case Token::kw_true:
93*c60b897dSRiver Riddle     consumeToken(Token::kw_true);
94*c60b897dSRiver Riddle     return builder.getBoolAttr(true);
95*c60b897dSRiver Riddle 
96*c60b897dSRiver Riddle   // Parse a dense elements attribute.
97*c60b897dSRiver Riddle   case Token::kw_dense:
98*c60b897dSRiver Riddle     return parseDenseElementsAttr(type);
99*c60b897dSRiver Riddle 
100*c60b897dSRiver Riddle   // Parse a dictionary attribute.
101*c60b897dSRiver Riddle   case Token::l_brace: {
102*c60b897dSRiver Riddle     NamedAttrList elements;
103*c60b897dSRiver Riddle     if (parseAttributeDict(elements))
104*c60b897dSRiver Riddle       return nullptr;
105*c60b897dSRiver Riddle     return elements.getDictionary(getContext());
106*c60b897dSRiver Riddle   }
107*c60b897dSRiver Riddle 
108*c60b897dSRiver Riddle   // Parse an extended attribute, i.e. alias or dialect attribute.
109*c60b897dSRiver Riddle   case Token::hash_identifier:
110*c60b897dSRiver Riddle     return parseExtendedAttr(type);
111*c60b897dSRiver Riddle 
112*c60b897dSRiver Riddle   // Parse floating point and integer attributes.
113*c60b897dSRiver Riddle   case Token::floatliteral:
114*c60b897dSRiver Riddle     return parseFloatAttr(type, /*isNegative=*/false);
115*c60b897dSRiver Riddle   case Token::integer:
116*c60b897dSRiver Riddle     return parseDecOrHexAttr(type, /*isNegative=*/false);
117*c60b897dSRiver Riddle   case Token::minus: {
118*c60b897dSRiver Riddle     consumeToken(Token::minus);
119*c60b897dSRiver Riddle     if (getToken().is(Token::integer))
120*c60b897dSRiver Riddle       return parseDecOrHexAttr(type, /*isNegative=*/true);
121*c60b897dSRiver Riddle     if (getToken().is(Token::floatliteral))
122*c60b897dSRiver Riddle       return parseFloatAttr(type, /*isNegative=*/true);
123*c60b897dSRiver Riddle 
124*c60b897dSRiver Riddle     return (emitWrongTokenError(
125*c60b897dSRiver Riddle                 "expected constant integer or floating point value"),
126*c60b897dSRiver Riddle             nullptr);
127*c60b897dSRiver Riddle   }
128*c60b897dSRiver Riddle 
129*c60b897dSRiver Riddle   // Parse a location attribute.
130*c60b897dSRiver Riddle   case Token::kw_loc: {
131*c60b897dSRiver Riddle     consumeToken(Token::kw_loc);
132*c60b897dSRiver Riddle 
133*c60b897dSRiver Riddle     LocationAttr locAttr;
134*c60b897dSRiver Riddle     if (parseToken(Token::l_paren, "expected '(' in inline location") ||
135*c60b897dSRiver Riddle         parseLocationInstance(locAttr) ||
136*c60b897dSRiver Riddle         parseToken(Token::r_paren, "expected ')' in inline location"))
137*c60b897dSRiver Riddle       return Attribute();
138*c60b897dSRiver Riddle     return locAttr;
139*c60b897dSRiver Riddle   }
140*c60b897dSRiver Riddle 
141*c60b897dSRiver Riddle   // Parse an opaque elements attribute.
142*c60b897dSRiver Riddle   case Token::kw_opaque:
143*c60b897dSRiver Riddle     return parseOpaqueElementsAttr(type);
144*c60b897dSRiver Riddle 
145*c60b897dSRiver Riddle   // Parse a sparse elements attribute.
146*c60b897dSRiver Riddle   case Token::kw_sparse:
147*c60b897dSRiver Riddle     return parseSparseElementsAttr(type);
148*c60b897dSRiver Riddle 
149*c60b897dSRiver Riddle   // Parse a string attribute.
150*c60b897dSRiver Riddle   case Token::string: {
151*c60b897dSRiver Riddle     auto val = getToken().getStringValue();
152*c60b897dSRiver Riddle     consumeToken(Token::string);
153*c60b897dSRiver Riddle     // Parse the optional trailing colon type if one wasn't explicitly provided.
154*c60b897dSRiver Riddle     if (!type && consumeIf(Token::colon) && !(type = parseType()))
155*c60b897dSRiver Riddle       return Attribute();
156*c60b897dSRiver Riddle 
157*c60b897dSRiver Riddle     return type ? StringAttr::get(val, type)
158*c60b897dSRiver Riddle                 : StringAttr::get(getContext(), val);
159*c60b897dSRiver Riddle   }
160*c60b897dSRiver Riddle 
161*c60b897dSRiver Riddle   // Parse a symbol reference attribute.
162*c60b897dSRiver Riddle   case Token::at_identifier: {
163*c60b897dSRiver Riddle     // When populating the parser state, this is a list of locations for all of
164*c60b897dSRiver Riddle     // the nested references.
165*c60b897dSRiver Riddle     SmallVector<SMRange> referenceLocations;
166*c60b897dSRiver Riddle     if (state.asmState)
167*c60b897dSRiver Riddle       referenceLocations.push_back(getToken().getLocRange());
168*c60b897dSRiver Riddle 
169*c60b897dSRiver Riddle     // Parse the top-level reference.
170*c60b897dSRiver Riddle     std::string nameStr = getToken().getSymbolReference();
171*c60b897dSRiver Riddle     consumeToken(Token::at_identifier);
172*c60b897dSRiver Riddle 
173*c60b897dSRiver Riddle     // Parse any nested references.
174*c60b897dSRiver Riddle     std::vector<FlatSymbolRefAttr> nestedRefs;
175*c60b897dSRiver Riddle     while (getToken().is(Token::colon)) {
176*c60b897dSRiver Riddle       // Check for the '::' prefix.
177*c60b897dSRiver Riddle       const char *curPointer = getToken().getLoc().getPointer();
178*c60b897dSRiver Riddle       consumeToken(Token::colon);
179*c60b897dSRiver Riddle       if (!consumeIf(Token::colon)) {
180*c60b897dSRiver Riddle         if (getToken().isNot(Token::eof, Token::error)) {
181*c60b897dSRiver Riddle           state.lex.resetPointer(curPointer);
182*c60b897dSRiver Riddle           consumeToken();
183*c60b897dSRiver Riddle         }
184*c60b897dSRiver Riddle         break;
185*c60b897dSRiver Riddle       }
186*c60b897dSRiver Riddle       // Parse the reference itself.
187*c60b897dSRiver Riddle       auto curLoc = getToken().getLoc();
188*c60b897dSRiver Riddle       if (getToken().isNot(Token::at_identifier)) {
189*c60b897dSRiver Riddle         emitError(curLoc, "expected nested symbol reference identifier");
190*c60b897dSRiver Riddle         return Attribute();
191*c60b897dSRiver Riddle       }
192*c60b897dSRiver Riddle 
193*c60b897dSRiver Riddle       // If we are populating the assembly state, add the location for this
194*c60b897dSRiver Riddle       // reference.
195*c60b897dSRiver Riddle       if (state.asmState)
196*c60b897dSRiver Riddle         referenceLocations.push_back(getToken().getLocRange());
197*c60b897dSRiver Riddle 
198*c60b897dSRiver Riddle       std::string nameStr = getToken().getSymbolReference();
199*c60b897dSRiver Riddle       consumeToken(Token::at_identifier);
200*c60b897dSRiver Riddle       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
201*c60b897dSRiver Riddle     }
202*c60b897dSRiver Riddle     SymbolRefAttr symbolRefAttr =
203*c60b897dSRiver Riddle         SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
204*c60b897dSRiver Riddle 
205*c60b897dSRiver Riddle     // If we are populating the assembly state, record this symbol reference.
206*c60b897dSRiver Riddle     if (state.asmState)
207*c60b897dSRiver Riddle       state.asmState->addUses(symbolRefAttr, referenceLocations);
208*c60b897dSRiver Riddle     return symbolRefAttr;
209*c60b897dSRiver Riddle   }
210*c60b897dSRiver Riddle 
211*c60b897dSRiver Riddle   // Parse a 'unit' attribute.
212*c60b897dSRiver Riddle   case Token::kw_unit:
213*c60b897dSRiver Riddle     consumeToken(Token::kw_unit);
214*c60b897dSRiver Riddle     return builder.getUnitAttr();
215*c60b897dSRiver Riddle 
216*c60b897dSRiver Riddle     // Handle completion of an attribute.
217*c60b897dSRiver Riddle   case Token::code_complete:
218*c60b897dSRiver Riddle     if (getToken().isCodeCompletionFor(Token::hash_identifier))
219*c60b897dSRiver Riddle       return parseExtendedAttr(type);
220*c60b897dSRiver Riddle     return codeCompleteAttribute();
221*c60b897dSRiver Riddle 
222*c60b897dSRiver Riddle   default:
223*c60b897dSRiver Riddle     // Parse a type attribute. We parse `Optional` here to allow for providing a
224*c60b897dSRiver Riddle     // better error message.
225*c60b897dSRiver Riddle     Type type;
226*c60b897dSRiver Riddle     OptionalParseResult result = parseOptionalType(type);
227*c60b897dSRiver Riddle     if (!result.hasValue())
228*c60b897dSRiver Riddle       return emitWrongTokenError("expected attribute value"), Attribute();
229*c60b897dSRiver Riddle     return failed(*result) ? Attribute() : TypeAttr::get(type);
230*c60b897dSRiver Riddle   }
231*c60b897dSRiver Riddle }
232*c60b897dSRiver Riddle 
233*c60b897dSRiver Riddle /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)234*c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
235*c60b897dSRiver Riddle                                                    Type type) {
236*c60b897dSRiver Riddle   switch (getToken().getKind()) {
237*c60b897dSRiver Riddle   case Token::at_identifier:
238*c60b897dSRiver Riddle   case Token::floatliteral:
239*c60b897dSRiver Riddle   case Token::integer:
240*c60b897dSRiver Riddle   case Token::hash_identifier:
241*c60b897dSRiver Riddle   case Token::kw_affine_map:
242*c60b897dSRiver Riddle   case Token::kw_affine_set:
243*c60b897dSRiver Riddle   case Token::kw_dense:
244*c60b897dSRiver Riddle   case Token::kw_false:
245*c60b897dSRiver Riddle   case Token::kw_loc:
246*c60b897dSRiver Riddle   case Token::kw_opaque:
247*c60b897dSRiver Riddle   case Token::kw_sparse:
248*c60b897dSRiver Riddle   case Token::kw_true:
249*c60b897dSRiver Riddle   case Token::kw_unit:
250*c60b897dSRiver Riddle   case Token::l_brace:
251*c60b897dSRiver Riddle   case Token::l_square:
252*c60b897dSRiver Riddle   case Token::minus:
253*c60b897dSRiver Riddle   case Token::string:
254*c60b897dSRiver Riddle     attribute = parseAttribute(type);
255*c60b897dSRiver Riddle     return success(attribute != nullptr);
256*c60b897dSRiver Riddle 
257*c60b897dSRiver Riddle   default:
258*c60b897dSRiver Riddle     // Parse an optional type attribute.
259*c60b897dSRiver Riddle     Type type;
260*c60b897dSRiver Riddle     OptionalParseResult result = parseOptionalType(type);
261*c60b897dSRiver Riddle     if (result.hasValue() && succeeded(*result))
262*c60b897dSRiver Riddle       attribute = TypeAttr::get(type);
263*c60b897dSRiver Riddle     return result;
264*c60b897dSRiver Riddle   }
265*c60b897dSRiver Riddle }
parseOptionalAttribute(ArrayAttr & attribute,Type type)266*c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
267*c60b897dSRiver Riddle                                                    Type type) {
268*c60b897dSRiver Riddle   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
269*c60b897dSRiver Riddle }
parseOptionalAttribute(StringAttr & attribute,Type type)270*c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
271*c60b897dSRiver Riddle                                                    Type type) {
272*c60b897dSRiver Riddle   return parseOptionalAttributeWithToken(Token::string, attribute, type);
273*c60b897dSRiver Riddle }
274*c60b897dSRiver Riddle 
275*c60b897dSRiver Riddle /// Attribute dictionary.
276*c60b897dSRiver Riddle ///
277*c60b897dSRiver Riddle ///   attribute-dict ::= `{` `}`
278*c60b897dSRiver Riddle ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
279*c60b897dSRiver Riddle ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
280*c60b897dSRiver Riddle ///
parseAttributeDict(NamedAttrList & attributes)281*c60b897dSRiver Riddle ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
282*c60b897dSRiver Riddle   llvm::SmallDenseSet<StringAttr> seenKeys;
283*c60b897dSRiver Riddle   auto parseElt = [&]() -> ParseResult {
284*c60b897dSRiver Riddle     // The name of an attribute can either be a bare identifier, or a string.
285*c60b897dSRiver Riddle     Optional<StringAttr> nameId;
286*c60b897dSRiver Riddle     if (getToken().is(Token::string))
287*c60b897dSRiver Riddle       nameId = builder.getStringAttr(getToken().getStringValue());
288*c60b897dSRiver Riddle     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
289*c60b897dSRiver Riddle              getToken().isKeyword())
290*c60b897dSRiver Riddle       nameId = builder.getStringAttr(getTokenSpelling());
291*c60b897dSRiver Riddle     else
292*c60b897dSRiver Riddle       return emitWrongTokenError("expected attribute name");
293*c60b897dSRiver Riddle 
294*c60b897dSRiver Riddle     if (nameId->size() == 0)
295*c60b897dSRiver Riddle       return emitError("expected valid attribute name");
296*c60b897dSRiver Riddle 
297*c60b897dSRiver Riddle     if (!seenKeys.insert(*nameId).second)
298*c60b897dSRiver Riddle       return emitError("duplicate key '")
299*c60b897dSRiver Riddle              << nameId->getValue() << "' in dictionary attribute";
300*c60b897dSRiver Riddle     consumeToken();
301*c60b897dSRiver Riddle 
302*c60b897dSRiver Riddle     // Lazy load a dialect in the context if there is a possible namespace.
303*c60b897dSRiver Riddle     auto splitName = nameId->strref().split('.');
304*c60b897dSRiver Riddle     if (!splitName.second.empty())
305*c60b897dSRiver Riddle       getContext()->getOrLoadDialect(splitName.first);
306*c60b897dSRiver Riddle 
307*c60b897dSRiver Riddle     // Try to parse the '=' for the attribute value.
308*c60b897dSRiver Riddle     if (!consumeIf(Token::equal)) {
309*c60b897dSRiver Riddle       // If there is no '=', we treat this as a unit attribute.
310*c60b897dSRiver Riddle       attributes.push_back({*nameId, builder.getUnitAttr()});
311*c60b897dSRiver Riddle       return success();
312*c60b897dSRiver Riddle     }
313*c60b897dSRiver Riddle 
314*c60b897dSRiver Riddle     auto attr = parseAttribute();
315*c60b897dSRiver Riddle     if (!attr)
316*c60b897dSRiver Riddle       return failure();
317*c60b897dSRiver Riddle     attributes.push_back({*nameId, attr});
318*c60b897dSRiver Riddle     return success();
319*c60b897dSRiver Riddle   };
320*c60b897dSRiver Riddle 
321*c60b897dSRiver Riddle   return parseCommaSeparatedList(Delimiter::Braces, parseElt,
322*c60b897dSRiver Riddle                                  " in attribute dictionary");
323*c60b897dSRiver Riddle }
324*c60b897dSRiver Riddle 
325*c60b897dSRiver Riddle /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)326*c60b897dSRiver Riddle Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
327*c60b897dSRiver Riddle   auto val = getToken().getFloatingPointValue();
328*c60b897dSRiver Riddle   if (!val)
329*c60b897dSRiver Riddle     return (emitError("floating point value too large for attribute"), nullptr);
330*c60b897dSRiver Riddle   consumeToken(Token::floatliteral);
331*c60b897dSRiver Riddle   if (!type) {
332*c60b897dSRiver Riddle     // Default to F64 when no type is specified.
333*c60b897dSRiver Riddle     if (!consumeIf(Token::colon))
334*c60b897dSRiver Riddle       type = builder.getF64Type();
335*c60b897dSRiver Riddle     else if (!(type = parseType()))
336*c60b897dSRiver Riddle       return nullptr;
337*c60b897dSRiver Riddle   }
338*c60b897dSRiver Riddle   if (!type.isa<FloatType>())
339*c60b897dSRiver Riddle     return (emitError("floating point value not valid for specified type"),
340*c60b897dSRiver Riddle             nullptr);
341*c60b897dSRiver Riddle   return FloatAttr::get(type, isNegative ? -*val : *val);
342*c60b897dSRiver Riddle }
343*c60b897dSRiver Riddle 
344*c60b897dSRiver Riddle /// Construct an APint from a parsed value, a known attribute type and
345*c60b897dSRiver Riddle /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)346*c60b897dSRiver Riddle static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
347*c60b897dSRiver Riddle                                            StringRef spelling) {
348*c60b897dSRiver Riddle   // Parse the integer value into an APInt that is big enough to hold the value.
349*c60b897dSRiver Riddle   APInt result;
350*c60b897dSRiver Riddle   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
351*c60b897dSRiver Riddle   if (spelling.getAsInteger(isHex ? 0 : 10, result))
352*c60b897dSRiver Riddle     return llvm::None;
353*c60b897dSRiver Riddle 
354*c60b897dSRiver Riddle   // Extend or truncate the bitwidth to the right size.
355*c60b897dSRiver Riddle   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
356*c60b897dSRiver Riddle                                   : type.getIntOrFloatBitWidth();
357*c60b897dSRiver Riddle 
358*c60b897dSRiver Riddle   if (width > result.getBitWidth()) {
359*c60b897dSRiver Riddle     result = result.zext(width);
360*c60b897dSRiver Riddle   } else if (width < result.getBitWidth()) {
361*c60b897dSRiver Riddle     // The parser can return an unnecessarily wide result with leading zeros.
362*c60b897dSRiver Riddle     // This isn't a problem, but truncating off bits is bad.
363*c60b897dSRiver Riddle     if (result.countLeadingZeros() < result.getBitWidth() - width)
364*c60b897dSRiver Riddle       return llvm::None;
365*c60b897dSRiver Riddle 
366*c60b897dSRiver Riddle     result = result.trunc(width);
367*c60b897dSRiver Riddle   }
368*c60b897dSRiver Riddle 
369*c60b897dSRiver Riddle   if (width == 0) {
370*c60b897dSRiver Riddle     // 0 bit integers cannot be negative and manipulation of their sign bit will
371*c60b897dSRiver Riddle     // assert, so short-cut validation here.
372*c60b897dSRiver Riddle     if (isNegative)
373*c60b897dSRiver Riddle       return llvm::None;
374*c60b897dSRiver Riddle   } else if (isNegative) {
375*c60b897dSRiver Riddle     // The value is negative, we have an overflow if the sign bit is not set
376*c60b897dSRiver Riddle     // in the negated apInt.
377*c60b897dSRiver Riddle     result.negate();
378*c60b897dSRiver Riddle     if (!result.isSignBitSet())
379*c60b897dSRiver Riddle       return llvm::None;
380*c60b897dSRiver Riddle   } else if ((type.isSignedInteger() || type.isIndex()) &&
381*c60b897dSRiver Riddle              result.isSignBitSet()) {
382*c60b897dSRiver Riddle     // The value is a positive signed integer or index,
383*c60b897dSRiver Riddle     // we have an overflow if the sign bit is set.
384*c60b897dSRiver Riddle     return llvm::None;
385*c60b897dSRiver Riddle   }
386*c60b897dSRiver Riddle 
387*c60b897dSRiver Riddle   return result;
388*c60b897dSRiver Riddle }
389*c60b897dSRiver Riddle 
390*c60b897dSRiver Riddle /// Parse a decimal or a hexadecimal literal, which can be either an integer
391*c60b897dSRiver Riddle /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)392*c60b897dSRiver Riddle Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
393*c60b897dSRiver Riddle   Token tok = getToken();
394*c60b897dSRiver Riddle   StringRef spelling = tok.getSpelling();
395*c60b897dSRiver Riddle   SMLoc loc = tok.getLoc();
396*c60b897dSRiver Riddle 
397*c60b897dSRiver Riddle   consumeToken(Token::integer);
398*c60b897dSRiver Riddle   if (!type) {
399*c60b897dSRiver Riddle     // Default to i64 if not type is specified.
400*c60b897dSRiver Riddle     if (!consumeIf(Token::colon))
401*c60b897dSRiver Riddle       type = builder.getIntegerType(64);
402*c60b897dSRiver Riddle     else if (!(type = parseType()))
403*c60b897dSRiver Riddle       return nullptr;
404*c60b897dSRiver Riddle   }
405*c60b897dSRiver Riddle 
406*c60b897dSRiver Riddle   if (auto floatType = type.dyn_cast<FloatType>()) {
407*c60b897dSRiver Riddle     Optional<APFloat> result;
408*c60b897dSRiver Riddle     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
409*c60b897dSRiver Riddle                                             floatType.getFloatSemantics(),
410*c60b897dSRiver Riddle                                             floatType.getWidth())))
411*c60b897dSRiver Riddle       return Attribute();
412*c60b897dSRiver Riddle     return FloatAttr::get(floatType, *result);
413*c60b897dSRiver Riddle   }
414*c60b897dSRiver Riddle 
415*c60b897dSRiver Riddle   if (!type.isa<IntegerType, IndexType>())
416*c60b897dSRiver Riddle     return emitError(loc, "integer literal not valid for specified type"),
417*c60b897dSRiver Riddle            nullptr;
418*c60b897dSRiver Riddle 
419*c60b897dSRiver Riddle   if (isNegative && type.isUnsignedInteger()) {
420*c60b897dSRiver Riddle     emitError(loc,
421*c60b897dSRiver Riddle               "negative integer literal not valid for unsigned integer type");
422*c60b897dSRiver Riddle     return nullptr;
423*c60b897dSRiver Riddle   }
424*c60b897dSRiver Riddle 
425*c60b897dSRiver Riddle   Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
426*c60b897dSRiver Riddle   if (!apInt)
427*c60b897dSRiver Riddle     return emitError(loc, "integer constant out of range for attribute"),
428*c60b897dSRiver Riddle            nullptr;
429*c60b897dSRiver Riddle   return builder.getIntegerAttr(type, *apInt);
430*c60b897dSRiver Riddle }
431*c60b897dSRiver Riddle 
432*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
433*c60b897dSRiver Riddle // TensorLiteralParser
434*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
435*c60b897dSRiver Riddle 
436*c60b897dSRiver Riddle /// Parse elements values stored within a hex string. On success, the values are
437*c60b897dSRiver Riddle /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)438*c60b897dSRiver Riddle static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
439*c60b897dSRiver Riddle                                              std::string &result) {
440*c60b897dSRiver Riddle   if (Optional<std::string> value = tok.getHexStringValue()) {
441*c60b897dSRiver Riddle     result = std::move(*value);
442*c60b897dSRiver Riddle     return success();
443*c60b897dSRiver Riddle   }
444*c60b897dSRiver Riddle   return parser.emitError(
445*c60b897dSRiver Riddle       tok.getLoc(), "expected string containing hex digits starting with `0x`");
446*c60b897dSRiver Riddle }
447*c60b897dSRiver Riddle 
448*c60b897dSRiver Riddle namespace {
449*c60b897dSRiver Riddle /// This class implements a parser for TensorLiterals. A tensor literal is
450*c60b897dSRiver Riddle /// either a single element (e.g, 5) or a multi-dimensional list of elements
451*c60b897dSRiver Riddle /// (e.g., [[5, 5]]).
452*c60b897dSRiver Riddle class TensorLiteralParser {
453*c60b897dSRiver Riddle public:
TensorLiteralParser(Parser & p)454*c60b897dSRiver Riddle   TensorLiteralParser(Parser &p) : p(p) {}
455*c60b897dSRiver Riddle 
456*c60b897dSRiver Riddle   /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
457*c60b897dSRiver Riddle   /// may also parse a tensor literal that is store as a hex string.
458*c60b897dSRiver Riddle   ParseResult parse(bool allowHex);
459*c60b897dSRiver Riddle 
460*c60b897dSRiver Riddle   /// Build a dense attribute instance with the parsed elements and the given
461*c60b897dSRiver Riddle   /// shaped type.
462*c60b897dSRiver Riddle   DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
463*c60b897dSRiver Riddle 
getShape() const464*c60b897dSRiver Riddle   ArrayRef<int64_t> getShape() const { return shape; }
465*c60b897dSRiver Riddle 
466*c60b897dSRiver Riddle private:
467*c60b897dSRiver Riddle   /// Get the parsed elements for an integer attribute.
468*c60b897dSRiver Riddle   ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
469*c60b897dSRiver Riddle                                  std::vector<APInt> &intValues);
470*c60b897dSRiver Riddle 
471*c60b897dSRiver Riddle   /// Get the parsed elements for a float attribute.
472*c60b897dSRiver Riddle   ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
473*c60b897dSRiver Riddle                                    std::vector<APFloat> &floatValues);
474*c60b897dSRiver Riddle 
475*c60b897dSRiver Riddle   /// Build a Dense String attribute for the given type.
476*c60b897dSRiver Riddle   DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
477*c60b897dSRiver Riddle 
478*c60b897dSRiver Riddle   /// Build a Dense attribute with hex data for the given type.
479*c60b897dSRiver Riddle   DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
480*c60b897dSRiver Riddle 
481*c60b897dSRiver Riddle   /// Parse a single element, returning failure if it isn't a valid element
482*c60b897dSRiver Riddle   /// literal. For example:
483*c60b897dSRiver Riddle   /// parseElement(1) -> Success, 1
484*c60b897dSRiver Riddle   /// parseElement([1]) -> Failure
485*c60b897dSRiver Riddle   ParseResult parseElement();
486*c60b897dSRiver Riddle 
487*c60b897dSRiver Riddle   /// Parse a list of either lists or elements, returning the dimensions of the
488*c60b897dSRiver Riddle   /// parsed sub-tensors in dims. For example:
489*c60b897dSRiver Riddle   ///   parseList([1, 2, 3]) -> Success, [3]
490*c60b897dSRiver Riddle   ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
491*c60b897dSRiver Riddle   ///   parseList([[1, 2], 3]) -> Failure
492*c60b897dSRiver Riddle   ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
493*c60b897dSRiver Riddle   ParseResult parseList(SmallVectorImpl<int64_t> &dims);
494*c60b897dSRiver Riddle 
495*c60b897dSRiver Riddle   /// Parse a literal that was printed as a hex string.
496*c60b897dSRiver Riddle   ParseResult parseHexElements();
497*c60b897dSRiver Riddle 
498*c60b897dSRiver Riddle   Parser &p;
499*c60b897dSRiver Riddle 
500*c60b897dSRiver Riddle   /// The shape inferred from the parsed elements.
501*c60b897dSRiver Riddle   SmallVector<int64_t, 4> shape;
502*c60b897dSRiver Riddle 
503*c60b897dSRiver Riddle   /// Storage used when parsing elements, this is a pair of <is_negated, token>.
504*c60b897dSRiver Riddle   std::vector<std::pair<bool, Token>> storage;
505*c60b897dSRiver Riddle 
506*c60b897dSRiver Riddle   /// Storage used when parsing elements that were stored as hex values.
507*c60b897dSRiver Riddle   Optional<Token> hexStorage;
508*c60b897dSRiver Riddle };
509*c60b897dSRiver Riddle } // namespace
510*c60b897dSRiver Riddle 
511*c60b897dSRiver Riddle /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
512*c60b897dSRiver Riddle /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)513*c60b897dSRiver Riddle ParseResult TensorLiteralParser::parse(bool allowHex) {
514*c60b897dSRiver Riddle   // If hex is allowed, check for a string literal.
515*c60b897dSRiver Riddle   if (allowHex && p.getToken().is(Token::string)) {
516*c60b897dSRiver Riddle     hexStorage = p.getToken();
517*c60b897dSRiver Riddle     p.consumeToken(Token::string);
518*c60b897dSRiver Riddle     return success();
519*c60b897dSRiver Riddle   }
520*c60b897dSRiver Riddle   // Otherwise, parse a list or an individual element.
521*c60b897dSRiver Riddle   if (p.getToken().is(Token::l_square))
522*c60b897dSRiver Riddle     return parseList(shape);
523*c60b897dSRiver Riddle   return parseElement();
524*c60b897dSRiver Riddle }
525*c60b897dSRiver Riddle 
526*c60b897dSRiver Riddle /// Build a dense attribute instance with the parsed elements and the given
527*c60b897dSRiver Riddle /// shaped type.
getAttr(SMLoc loc,ShapedType type)528*c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
529*c60b897dSRiver Riddle   Type eltType = type.getElementType();
530*c60b897dSRiver Riddle 
531*c60b897dSRiver Riddle   // Check to see if we parse the literal from a hex string.
532*c60b897dSRiver Riddle   if (hexStorage &&
533*c60b897dSRiver Riddle       (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
534*c60b897dSRiver Riddle     return getHexAttr(loc, type);
535*c60b897dSRiver Riddle 
536*c60b897dSRiver Riddle   // Check that the parsed storage size has the same number of elements to the
537*c60b897dSRiver Riddle   // type, or is a known splat.
538*c60b897dSRiver Riddle   if (!shape.empty() && getShape() != type.getShape()) {
539*c60b897dSRiver Riddle     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
540*c60b897dSRiver Riddle                      << "]) does not match type ([" << type.getShape() << "])";
541*c60b897dSRiver Riddle     return nullptr;
542*c60b897dSRiver Riddle   }
543*c60b897dSRiver Riddle 
544*c60b897dSRiver Riddle   // Handle the case where no elements were parsed.
545*c60b897dSRiver Riddle   if (!hexStorage && storage.empty() && type.getNumElements()) {
546*c60b897dSRiver Riddle     p.emitError(loc) << "parsed zero elements, but type (" << type
547*c60b897dSRiver Riddle                      << ") expected at least 1";
548*c60b897dSRiver Riddle     return nullptr;
549*c60b897dSRiver Riddle   }
550*c60b897dSRiver Riddle 
551*c60b897dSRiver Riddle   // Handle complex types in the specific element type cases below.
552*c60b897dSRiver Riddle   bool isComplex = false;
553*c60b897dSRiver Riddle   if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
554*c60b897dSRiver Riddle     eltType = complexTy.getElementType();
555*c60b897dSRiver Riddle     isComplex = true;
556*c60b897dSRiver Riddle   }
557*c60b897dSRiver Riddle 
558*c60b897dSRiver Riddle   // Handle integer and index types.
559*c60b897dSRiver Riddle   if (eltType.isIntOrIndex()) {
560*c60b897dSRiver Riddle     std::vector<APInt> intValues;
561*c60b897dSRiver Riddle     if (failed(getIntAttrElements(loc, eltType, intValues)))
562*c60b897dSRiver Riddle       return nullptr;
563*c60b897dSRiver Riddle     if (isComplex) {
564*c60b897dSRiver Riddle       // If this is a complex, treat the parsed values as complex values.
565*c60b897dSRiver Riddle       auto complexData = llvm::makeArrayRef(
566*c60b897dSRiver Riddle           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
567*c60b897dSRiver Riddle           intValues.size() / 2);
568*c60b897dSRiver Riddle       return DenseElementsAttr::get(type, complexData);
569*c60b897dSRiver Riddle     }
570*c60b897dSRiver Riddle     return DenseElementsAttr::get(type, intValues);
571*c60b897dSRiver Riddle   }
572*c60b897dSRiver Riddle   // Handle floating point types.
573*c60b897dSRiver Riddle   if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
574*c60b897dSRiver Riddle     std::vector<APFloat> floatValues;
575*c60b897dSRiver Riddle     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
576*c60b897dSRiver Riddle       return nullptr;
577*c60b897dSRiver Riddle     if (isComplex) {
578*c60b897dSRiver Riddle       // If this is a complex, treat the parsed values as complex values.
579*c60b897dSRiver Riddle       auto complexData = llvm::makeArrayRef(
580*c60b897dSRiver Riddle           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
581*c60b897dSRiver Riddle           floatValues.size() / 2);
582*c60b897dSRiver Riddle       return DenseElementsAttr::get(type, complexData);
583*c60b897dSRiver Riddle     }
584*c60b897dSRiver Riddle     return DenseElementsAttr::get(type, floatValues);
585*c60b897dSRiver Riddle   }
586*c60b897dSRiver Riddle 
587*c60b897dSRiver Riddle   // Other types are assumed to be string representations.
588*c60b897dSRiver Riddle   return getStringAttr(loc, type, type.getElementType());
589*c60b897dSRiver Riddle }
590*c60b897dSRiver Riddle 
591*c60b897dSRiver Riddle /// Build a Dense Integer attribute for the given type.
592*c60b897dSRiver Riddle ParseResult
getIntAttrElements(SMLoc loc,Type eltTy,std::vector<APInt> & intValues)593*c60b897dSRiver Riddle TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
594*c60b897dSRiver Riddle                                         std::vector<APInt> &intValues) {
595*c60b897dSRiver Riddle   intValues.reserve(storage.size());
596*c60b897dSRiver Riddle   bool isUintType = eltTy.isUnsignedInteger();
597*c60b897dSRiver Riddle   for (const auto &signAndToken : storage) {
598*c60b897dSRiver Riddle     bool isNegative = signAndToken.first;
599*c60b897dSRiver Riddle     const Token &token = signAndToken.second;
600*c60b897dSRiver Riddle     auto tokenLoc = token.getLoc();
601*c60b897dSRiver Riddle 
602*c60b897dSRiver Riddle     if (isNegative && isUintType) {
603*c60b897dSRiver Riddle       return p.emitError(tokenLoc)
604*c60b897dSRiver Riddle              << "expected unsigned integer elements, but parsed negative value";
605*c60b897dSRiver Riddle     }
606*c60b897dSRiver Riddle 
607*c60b897dSRiver Riddle     // Check to see if floating point values were parsed.
608*c60b897dSRiver Riddle     if (token.is(Token::floatliteral)) {
609*c60b897dSRiver Riddle       return p.emitError(tokenLoc)
610*c60b897dSRiver Riddle              << "expected integer elements, but parsed floating-point";
611*c60b897dSRiver Riddle     }
612*c60b897dSRiver Riddle 
613*c60b897dSRiver Riddle     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
614*c60b897dSRiver Riddle            "unexpected token type");
615*c60b897dSRiver Riddle     if (token.isAny(Token::kw_true, Token::kw_false)) {
616*c60b897dSRiver Riddle       if (!eltTy.isInteger(1)) {
617*c60b897dSRiver Riddle         return p.emitError(tokenLoc)
618*c60b897dSRiver Riddle                << "expected i1 type for 'true' or 'false' values";
619*c60b897dSRiver Riddle       }
620*c60b897dSRiver Riddle       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
621*c60b897dSRiver Riddle       intValues.push_back(apInt);
622*c60b897dSRiver Riddle       continue;
623*c60b897dSRiver Riddle     }
624*c60b897dSRiver Riddle 
625*c60b897dSRiver Riddle     // Create APInt values for each element with the correct bitwidth.
626*c60b897dSRiver Riddle     Optional<APInt> apInt =
627*c60b897dSRiver Riddle         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
628*c60b897dSRiver Riddle     if (!apInt)
629*c60b897dSRiver Riddle       return p.emitError(tokenLoc, "integer constant out of range for type");
630*c60b897dSRiver Riddle     intValues.push_back(*apInt);
631*c60b897dSRiver Riddle   }
632*c60b897dSRiver Riddle   return success();
633*c60b897dSRiver Riddle }
634*c60b897dSRiver Riddle 
635*c60b897dSRiver Riddle /// Build a Dense Float attribute for the given type.
636*c60b897dSRiver Riddle ParseResult
getFloatAttrElements(SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)637*c60b897dSRiver Riddle TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
638*c60b897dSRiver Riddle                                           std::vector<APFloat> &floatValues) {
639*c60b897dSRiver Riddle   floatValues.reserve(storage.size());
640*c60b897dSRiver Riddle   for (const auto &signAndToken : storage) {
641*c60b897dSRiver Riddle     bool isNegative = signAndToken.first;
642*c60b897dSRiver Riddle     const Token &token = signAndToken.second;
643*c60b897dSRiver Riddle 
644*c60b897dSRiver Riddle     // Handle hexadecimal float literals.
645*c60b897dSRiver Riddle     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
646*c60b897dSRiver Riddle       Optional<APFloat> result;
647*c60b897dSRiver Riddle       if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
648*c60b897dSRiver Riddle                                                 eltTy.getFloatSemantics(),
649*c60b897dSRiver Riddle                                                 eltTy.getWidth())))
650*c60b897dSRiver Riddle         return failure();
651*c60b897dSRiver Riddle 
652*c60b897dSRiver Riddle       floatValues.push_back(*result);
653*c60b897dSRiver Riddle       continue;
654*c60b897dSRiver Riddle     }
655*c60b897dSRiver Riddle 
656*c60b897dSRiver Riddle     // Check to see if any decimal integers or booleans were parsed.
657*c60b897dSRiver Riddle     if (!token.is(Token::floatliteral))
658*c60b897dSRiver Riddle       return p.emitError()
659*c60b897dSRiver Riddle              << "expected floating-point elements, but parsed integer";
660*c60b897dSRiver Riddle 
661*c60b897dSRiver Riddle     // Build the float values from tokens.
662*c60b897dSRiver Riddle     auto val = token.getFloatingPointValue();
663*c60b897dSRiver Riddle     if (!val)
664*c60b897dSRiver Riddle       return p.emitError("floating point value too large for attribute");
665*c60b897dSRiver Riddle 
666*c60b897dSRiver Riddle     APFloat apVal(isNegative ? -*val : *val);
667*c60b897dSRiver Riddle     if (!eltTy.isF64()) {
668*c60b897dSRiver Riddle       bool unused;
669*c60b897dSRiver Riddle       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
670*c60b897dSRiver Riddle                     &unused);
671*c60b897dSRiver Riddle     }
672*c60b897dSRiver Riddle     floatValues.push_back(apVal);
673*c60b897dSRiver Riddle   }
674*c60b897dSRiver Riddle   return success();
675*c60b897dSRiver Riddle }
676*c60b897dSRiver Riddle 
677*c60b897dSRiver Riddle /// Build a Dense String attribute for the given type.
getStringAttr(SMLoc loc,ShapedType type,Type eltTy)678*c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
679*c60b897dSRiver Riddle                                                      Type eltTy) {
680*c60b897dSRiver Riddle   if (hexStorage.has_value()) {
681*c60b897dSRiver Riddle     auto stringValue = hexStorage.value().getStringValue();
682*c60b897dSRiver Riddle     return DenseStringElementsAttr::get(type, {stringValue});
683*c60b897dSRiver Riddle   }
684*c60b897dSRiver Riddle 
685*c60b897dSRiver Riddle   std::vector<std::string> stringValues;
686*c60b897dSRiver Riddle   std::vector<StringRef> stringRefValues;
687*c60b897dSRiver Riddle   stringValues.reserve(storage.size());
688*c60b897dSRiver Riddle   stringRefValues.reserve(storage.size());
689*c60b897dSRiver Riddle 
690*c60b897dSRiver Riddle   for (auto val : storage) {
691*c60b897dSRiver Riddle     stringValues.push_back(val.second.getStringValue());
692*c60b897dSRiver Riddle     stringRefValues.emplace_back(stringValues.back());
693*c60b897dSRiver Riddle   }
694*c60b897dSRiver Riddle 
695*c60b897dSRiver Riddle   return DenseStringElementsAttr::get(type, stringRefValues);
696*c60b897dSRiver Riddle }
697*c60b897dSRiver Riddle 
698*c60b897dSRiver Riddle /// Build a Dense attribute with hex data for the given type.
getHexAttr(SMLoc loc,ShapedType type)699*c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
700*c60b897dSRiver Riddle   Type elementType = type.getElementType();
701*c60b897dSRiver Riddle   if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
702*c60b897dSRiver Riddle     p.emitError(loc)
703*c60b897dSRiver Riddle         << "expected floating-point, integer, or complex element type, got "
704*c60b897dSRiver Riddle         << elementType;
705*c60b897dSRiver Riddle     return nullptr;
706*c60b897dSRiver Riddle   }
707*c60b897dSRiver Riddle 
708*c60b897dSRiver Riddle   std::string data;
709*c60b897dSRiver Riddle   if (parseElementAttrHexValues(p, *hexStorage, data))
710*c60b897dSRiver Riddle     return nullptr;
711*c60b897dSRiver Riddle 
712*c60b897dSRiver Riddle   ArrayRef<char> rawData(data.data(), data.size());
713*c60b897dSRiver Riddle   bool detectedSplat = false;
714*c60b897dSRiver Riddle   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
715*c60b897dSRiver Riddle     p.emitError(loc) << "elements hex data size is invalid for provided type: "
716*c60b897dSRiver Riddle                      << type;
717*c60b897dSRiver Riddle     return nullptr;
718*c60b897dSRiver Riddle   }
719*c60b897dSRiver Riddle 
720*c60b897dSRiver Riddle   if (llvm::support::endian::system_endianness() ==
721*c60b897dSRiver Riddle       llvm::support::endianness::big) {
722*c60b897dSRiver Riddle     // Convert endianess in big-endian(BE) machines. `rawData` is
723*c60b897dSRiver Riddle     // little-endian(LE) because HEX in raw data of dense element attribute
724*c60b897dSRiver Riddle     // is always LE format. It is converted into BE here to be used in BE
725*c60b897dSRiver Riddle     // machines.
726*c60b897dSRiver Riddle     SmallVector<char, 64> outDataVec(rawData.size());
727*c60b897dSRiver Riddle     MutableArrayRef<char> convRawData(outDataVec);
728*c60b897dSRiver Riddle     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
729*c60b897dSRiver Riddle         rawData, convRawData, type);
730*c60b897dSRiver Riddle     return DenseElementsAttr::getFromRawBuffer(type, convRawData);
731*c60b897dSRiver Riddle   }
732*c60b897dSRiver Riddle 
733*c60b897dSRiver Riddle   return DenseElementsAttr::getFromRawBuffer(type, rawData);
734*c60b897dSRiver Riddle }
735*c60b897dSRiver Riddle 
parseElement()736*c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseElement() {
737*c60b897dSRiver Riddle   switch (p.getToken().getKind()) {
738*c60b897dSRiver Riddle   // Parse a boolean element.
739*c60b897dSRiver Riddle   case Token::kw_true:
740*c60b897dSRiver Riddle   case Token::kw_false:
741*c60b897dSRiver Riddle   case Token::floatliteral:
742*c60b897dSRiver Riddle   case Token::integer:
743*c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/false, p.getToken());
744*c60b897dSRiver Riddle     p.consumeToken();
745*c60b897dSRiver Riddle     break;
746*c60b897dSRiver Riddle 
747*c60b897dSRiver Riddle   // Parse a signed integer or a negative floating-point element.
748*c60b897dSRiver Riddle   case Token::minus:
749*c60b897dSRiver Riddle     p.consumeToken(Token::minus);
750*c60b897dSRiver Riddle     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
751*c60b897dSRiver Riddle       return p.emitError("expected integer or floating point literal");
752*c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/true, p.getToken());
753*c60b897dSRiver Riddle     p.consumeToken();
754*c60b897dSRiver Riddle     break;
755*c60b897dSRiver Riddle 
756*c60b897dSRiver Riddle   case Token::string:
757*c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/false, p.getToken());
758*c60b897dSRiver Riddle     p.consumeToken();
759*c60b897dSRiver Riddle     break;
760*c60b897dSRiver Riddle 
761*c60b897dSRiver Riddle   // Parse a complex element of the form '(' element ',' element ')'.
762*c60b897dSRiver Riddle   case Token::l_paren:
763*c60b897dSRiver Riddle     p.consumeToken(Token::l_paren);
764*c60b897dSRiver Riddle     if (parseElement() ||
765*c60b897dSRiver Riddle         p.parseToken(Token::comma, "expected ',' between complex elements") ||
766*c60b897dSRiver Riddle         parseElement() ||
767*c60b897dSRiver Riddle         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
768*c60b897dSRiver Riddle       return failure();
769*c60b897dSRiver Riddle     break;
770*c60b897dSRiver Riddle 
771*c60b897dSRiver Riddle   default:
772*c60b897dSRiver Riddle     return p.emitError("expected element literal of primitive type");
773*c60b897dSRiver Riddle   }
774*c60b897dSRiver Riddle 
775*c60b897dSRiver Riddle   return success();
776*c60b897dSRiver Riddle }
777*c60b897dSRiver Riddle 
778*c60b897dSRiver Riddle /// Parse a list of either lists or elements, returning the dimensions of the
779*c60b897dSRiver Riddle /// parsed sub-tensors in dims. For example:
780*c60b897dSRiver Riddle ///   parseList([1, 2, 3]) -> Success, [3]
781*c60b897dSRiver Riddle ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
782*c60b897dSRiver Riddle ///   parseList([[1, 2], 3]) -> Failure
783*c60b897dSRiver Riddle ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)784*c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
785*c60b897dSRiver Riddle   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
786*c60b897dSRiver Riddle                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
787*c60b897dSRiver Riddle     if (prevDims == newDims)
788*c60b897dSRiver Riddle       return success();
789*c60b897dSRiver Riddle     return p.emitError("tensor literal is invalid; ranks are not consistent "
790*c60b897dSRiver Riddle                        "between elements");
791*c60b897dSRiver Riddle   };
792*c60b897dSRiver Riddle 
793*c60b897dSRiver Riddle   bool first = true;
794*c60b897dSRiver Riddle   SmallVector<int64_t, 4> newDims;
795*c60b897dSRiver Riddle   unsigned size = 0;
796*c60b897dSRiver Riddle   auto parseOneElement = [&]() -> ParseResult {
797*c60b897dSRiver Riddle     SmallVector<int64_t, 4> thisDims;
798*c60b897dSRiver Riddle     if (p.getToken().getKind() == Token::l_square) {
799*c60b897dSRiver Riddle       if (parseList(thisDims))
800*c60b897dSRiver Riddle         return failure();
801*c60b897dSRiver Riddle     } else if (parseElement()) {
802*c60b897dSRiver Riddle       return failure();
803*c60b897dSRiver Riddle     }
804*c60b897dSRiver Riddle     ++size;
805*c60b897dSRiver Riddle     if (!first)
806*c60b897dSRiver Riddle       return checkDims(newDims, thisDims);
807*c60b897dSRiver Riddle     newDims = thisDims;
808*c60b897dSRiver Riddle     first = false;
809*c60b897dSRiver Riddle     return success();
810*c60b897dSRiver Riddle   };
811*c60b897dSRiver Riddle   if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
812*c60b897dSRiver Riddle     return failure();
813*c60b897dSRiver Riddle 
814*c60b897dSRiver Riddle   // Return the sublists' dimensions with 'size' prepended.
815*c60b897dSRiver Riddle   dims.clear();
816*c60b897dSRiver Riddle   dims.push_back(size);
817*c60b897dSRiver Riddle   dims.append(newDims.begin(), newDims.end());
818*c60b897dSRiver Riddle   return success();
819*c60b897dSRiver Riddle }
820*c60b897dSRiver Riddle 
821*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
822*c60b897dSRiver Riddle // ElementsAttr Parser
823*c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
824*c60b897dSRiver Riddle 
825*c60b897dSRiver Riddle namespace {
826*c60b897dSRiver Riddle /// This class provides an implementation of AsmParser, allowing to call back
827*c60b897dSRiver Riddle /// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
828*c60b897dSRiver Riddle /// in libMLIRIR.
829*c60b897dSRiver Riddle class CustomAsmParser : public AsmParserImpl<AsmParser> {
830*c60b897dSRiver Riddle public:
CustomAsmParser(Parser & parser)831*c60b897dSRiver Riddle   CustomAsmParser(Parser &parser)
832*c60b897dSRiver Riddle       : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
833*c60b897dSRiver Riddle };
834*c60b897dSRiver Riddle } // namespace
835*c60b897dSRiver Riddle 
836*c60b897dSRiver Riddle /// Parse a dense array attribute.
parseDenseArrayAttr()837*c60b897dSRiver Riddle Attribute Parser::parseDenseArrayAttr() {
838*c60b897dSRiver Riddle   auto typeLoc = getToken().getLoc();
839*c60b897dSRiver Riddle   auto type = parseType();
840*c60b897dSRiver Riddle   if (!type)
841*c60b897dSRiver Riddle     return {};
842*c60b897dSRiver Riddle   CustomAsmParser parser(*this);
843*c60b897dSRiver Riddle   Attribute result;
844*c60b897dSRiver Riddle   // Check for empty list.
845*c60b897dSRiver Riddle   bool isEmptyList = getToken().is(Token::r_square);
846*c60b897dSRiver Riddle 
847*c60b897dSRiver Riddle   if (auto intType = type.dyn_cast<IntegerType>()) {
848*c60b897dSRiver Riddle     switch (type.getIntOrFloatBitWidth()) {
849*c60b897dSRiver Riddle     case 8:
850*c60b897dSRiver Riddle       if (isEmptyList)
851*c60b897dSRiver Riddle         result = DenseI8ArrayAttr::get(parser.getContext(), {});
852*c60b897dSRiver Riddle       else
853*c60b897dSRiver Riddle         result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
854*c60b897dSRiver Riddle       break;
855*c60b897dSRiver Riddle     case 16:
856*c60b897dSRiver Riddle       if (isEmptyList)
857*c60b897dSRiver Riddle         result = DenseI16ArrayAttr::get(parser.getContext(), {});
858*c60b897dSRiver Riddle       else
859*c60b897dSRiver Riddle         result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
860*c60b897dSRiver Riddle       break;
861*c60b897dSRiver Riddle     case 32:
862*c60b897dSRiver Riddle       if (isEmptyList)
863*c60b897dSRiver Riddle         result = DenseI32ArrayAttr::get(parser.getContext(), {});
864*c60b897dSRiver Riddle       else
865*c60b897dSRiver Riddle         result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
866*c60b897dSRiver Riddle       break;
867*c60b897dSRiver Riddle     case 64:
868*c60b897dSRiver Riddle       if (isEmptyList)
869*c60b897dSRiver Riddle         result = DenseI64ArrayAttr::get(parser.getContext(), {});
870*c60b897dSRiver Riddle       else
871*c60b897dSRiver Riddle         result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
872*c60b897dSRiver Riddle       break;
873*c60b897dSRiver Riddle     default:
874*c60b897dSRiver Riddle       emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
875*c60b897dSRiver Riddle       return {};
876*c60b897dSRiver Riddle     }
877*c60b897dSRiver Riddle   } else if (auto floatType = type.dyn_cast<FloatType>()) {
878*c60b897dSRiver Riddle     switch (type.getIntOrFloatBitWidth()) {
879*c60b897dSRiver Riddle     case 32:
880*c60b897dSRiver Riddle       if (isEmptyList)
881*c60b897dSRiver Riddle         result = DenseF32ArrayAttr::get(parser.getContext(), {});
882*c60b897dSRiver Riddle       else
883*c60b897dSRiver Riddle         result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
884*c60b897dSRiver Riddle       break;
885*c60b897dSRiver Riddle     case 64:
886*c60b897dSRiver Riddle       if (isEmptyList)
887*c60b897dSRiver Riddle         result = DenseF64ArrayAttr::get(parser.getContext(), {});
888*c60b897dSRiver Riddle       else
889*c60b897dSRiver Riddle         result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
890*c60b897dSRiver Riddle       break;
891*c60b897dSRiver Riddle     default:
892*c60b897dSRiver Riddle       emitError(typeLoc, "expected f32 or f64 but got: ") << type;
893*c60b897dSRiver Riddle       return {};
894*c60b897dSRiver Riddle     }
895*c60b897dSRiver Riddle   } else {
896*c60b897dSRiver Riddle     emitError(typeLoc, "expected integer or float type, got: ") << type;
897*c60b897dSRiver Riddle     return {};
898*c60b897dSRiver Riddle   }
899*c60b897dSRiver Riddle   if (!consumeIf(Token::r_square)) {
900*c60b897dSRiver Riddle     emitError("expected ']' to close an array attribute");
901*c60b897dSRiver Riddle     return {};
902*c60b897dSRiver Riddle   }
903*c60b897dSRiver Riddle   return result;
904*c60b897dSRiver Riddle }
905*c60b897dSRiver Riddle 
906*c60b897dSRiver Riddle /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)907*c60b897dSRiver Riddle Attribute Parser::parseDenseElementsAttr(Type attrType) {
908*c60b897dSRiver Riddle   auto attribLoc = getToken().getLoc();
909*c60b897dSRiver Riddle   consumeToken(Token::kw_dense);
910*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' after 'dense'"))
911*c60b897dSRiver Riddle     return nullptr;
912*c60b897dSRiver Riddle 
913*c60b897dSRiver Riddle   // Parse the literal data if necessary.
914*c60b897dSRiver Riddle   TensorLiteralParser literalParser(*this);
915*c60b897dSRiver Riddle   if (!consumeIf(Token::greater)) {
916*c60b897dSRiver Riddle     if (literalParser.parse(/*allowHex=*/true) ||
917*c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>'"))
918*c60b897dSRiver Riddle       return nullptr;
919*c60b897dSRiver Riddle   }
920*c60b897dSRiver Riddle 
921*c60b897dSRiver Riddle   // If the type is specified `parseElementsLiteralType` will not parse a type.
922*c60b897dSRiver Riddle   // Use the attribute location as the location for error reporting in that
923*c60b897dSRiver Riddle   // case.
924*c60b897dSRiver Riddle   auto loc = attrType ? attribLoc : getToken().getLoc();
925*c60b897dSRiver Riddle   auto type = parseElementsLiteralType(attrType);
926*c60b897dSRiver Riddle   if (!type)
927*c60b897dSRiver Riddle     return nullptr;
928*c60b897dSRiver Riddle   return literalParser.getAttr(loc, type);
929*c60b897dSRiver Riddle }
930*c60b897dSRiver Riddle 
931*c60b897dSRiver Riddle /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)932*c60b897dSRiver Riddle Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
933*c60b897dSRiver Riddle   SMLoc loc = getToken().getLoc();
934*c60b897dSRiver Riddle   consumeToken(Token::kw_opaque);
935*c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' after 'opaque'"))
936*c60b897dSRiver Riddle     return nullptr;
937*c60b897dSRiver Riddle 
938*c60b897dSRiver Riddle   if (getToken().isNot(Token::string))
939*c60b897dSRiver Riddle     return (emitError("expected dialect namespace"), nullptr);
940*c60b897dSRiver Riddle 
941*c60b897dSRiver Riddle   std::string name = getToken().getStringValue();
942*c60b897dSRiver Riddle   consumeToken(Token::string);
943*c60b897dSRiver Riddle 
944*c60b897dSRiver Riddle   if (parseToken(Token::comma, "expected ','"))
945*c60b897dSRiver Riddle     return nullptr;
946*c60b897dSRiver Riddle 
947*c60b897dSRiver Riddle   Token hexTok = getToken();
948*c60b897dSRiver Riddle   if (parseToken(Token::string, "elements hex string should start with '0x'") ||
949*c60b897dSRiver Riddle       parseToken(Token::greater, "expected '>'"))
950*c60b897dSRiver Riddle     return nullptr;
951*c60b897dSRiver Riddle   auto type = parseElementsLiteralType(attrType);
952*c60b897dSRiver Riddle   if (!type)
953*c60b897dSRiver Riddle     return nullptr;
954*c60b897dSRiver Riddle 
955*c60b897dSRiver Riddle   std::string data;
956*c60b897dSRiver Riddle   if (parseElementAttrHexValues(*this, hexTok, data))
957*c60b897dSRiver Riddle     return nullptr;
958*c60b897dSRiver Riddle   return getChecked<OpaqueElementsAttr>(loc, builder.getStringAttr(name), type,
959*c60b897dSRiver Riddle                                         data);
960*c60b897dSRiver Riddle }
961*c60b897dSRiver Riddle 
962*c60b897dSRiver Riddle /// Shaped type for elements attribute.
963*c60b897dSRiver Riddle ///
964*c60b897dSRiver Riddle ///   elements-literal-type ::= vector-type | ranked-tensor-type
965*c60b897dSRiver Riddle ///
966*c60b897dSRiver Riddle /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)967*c60b897dSRiver Riddle ShapedType Parser::parseElementsLiteralType(Type type) {
968*c60b897dSRiver Riddle   // If the user didn't provide a type, parse the colon type for the literal.
969*c60b897dSRiver Riddle   if (!type) {
970*c60b897dSRiver Riddle     if (parseToken(Token::colon, "expected ':'"))
971*c60b897dSRiver Riddle       return nullptr;
972*c60b897dSRiver Riddle     if (!(type = parseType()))
973*c60b897dSRiver Riddle       return nullptr;
974*c60b897dSRiver Riddle   }
975*c60b897dSRiver Riddle 
976*c60b897dSRiver Riddle   if (!type.isa<RankedTensorType, VectorType>()) {
977*c60b897dSRiver Riddle     emitError("elements literal must be a ranked tensor or vector type");
978*c60b897dSRiver Riddle     return nullptr;
979*c60b897dSRiver Riddle   }
980*c60b897dSRiver Riddle 
981*c60b897dSRiver Riddle   auto sType = type.cast<ShapedType>();
982*c60b897dSRiver Riddle   if (!sType.hasStaticShape())
983*c60b897dSRiver Riddle     return (emitError("elements literal type must have static shape"), nullptr);
984*c60b897dSRiver Riddle 
985*c60b897dSRiver Riddle   return sType;
986*c60b897dSRiver Riddle }
987*c60b897dSRiver Riddle 
988*c60b897dSRiver Riddle /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)989*c60b897dSRiver Riddle Attribute Parser::parseSparseElementsAttr(Type attrType) {
990*c60b897dSRiver Riddle   SMLoc loc = getToken().getLoc();
991*c60b897dSRiver Riddle   consumeToken(Token::kw_sparse);
992*c60b897dSRiver Riddle   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
993*c60b897dSRiver Riddle     return nullptr;
994*c60b897dSRiver Riddle 
995*c60b897dSRiver Riddle   // Check for the case where all elements are sparse. The indices are
996*c60b897dSRiver Riddle   // represented by a 2-dimensional shape where the second dimension is the rank
997*c60b897dSRiver Riddle   // of the type.
998*c60b897dSRiver Riddle   Type indiceEltType = builder.getIntegerType(64);
999*c60b897dSRiver Riddle   if (consumeIf(Token::greater)) {
1000*c60b897dSRiver Riddle     ShapedType type = parseElementsLiteralType(attrType);
1001*c60b897dSRiver Riddle     if (!type)
1002*c60b897dSRiver Riddle       return nullptr;
1003*c60b897dSRiver Riddle 
1004*c60b897dSRiver Riddle     // Construct the sparse elements attr using zero element indice/value
1005*c60b897dSRiver Riddle     // attributes.
1006*c60b897dSRiver Riddle     ShapedType indicesType =
1007*c60b897dSRiver Riddle         RankedTensorType::get({0, type.getRank()}, indiceEltType);
1008*c60b897dSRiver Riddle     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1009*c60b897dSRiver Riddle     return getChecked<SparseElementsAttr>(
1010*c60b897dSRiver Riddle         loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1011*c60b897dSRiver Riddle         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1012*c60b897dSRiver Riddle   }
1013*c60b897dSRiver Riddle 
1014*c60b897dSRiver Riddle   /// Parse the indices. We don't allow hex values here as we may need to use
1015*c60b897dSRiver Riddle   /// the inferred shape.
1016*c60b897dSRiver Riddle   auto indicesLoc = getToken().getLoc();
1017*c60b897dSRiver Riddle   TensorLiteralParser indiceParser(*this);
1018*c60b897dSRiver Riddle   if (indiceParser.parse(/*allowHex=*/false))
1019*c60b897dSRiver Riddle     return nullptr;
1020*c60b897dSRiver Riddle 
1021*c60b897dSRiver Riddle   if (parseToken(Token::comma, "expected ','"))
1022*c60b897dSRiver Riddle     return nullptr;
1023*c60b897dSRiver Riddle 
1024*c60b897dSRiver Riddle   /// Parse the values.
1025*c60b897dSRiver Riddle   auto valuesLoc = getToken().getLoc();
1026*c60b897dSRiver Riddle   TensorLiteralParser valuesParser(*this);
1027*c60b897dSRiver Riddle   if (valuesParser.parse(/*allowHex=*/true))
1028*c60b897dSRiver Riddle     return nullptr;
1029*c60b897dSRiver Riddle 
1030*c60b897dSRiver Riddle   if (parseToken(Token::greater, "expected '>'"))
1031*c60b897dSRiver Riddle     return nullptr;
1032*c60b897dSRiver Riddle 
1033*c60b897dSRiver Riddle   auto type = parseElementsLiteralType(attrType);
1034*c60b897dSRiver Riddle   if (!type)
1035*c60b897dSRiver Riddle     return nullptr;
1036*c60b897dSRiver Riddle 
1037*c60b897dSRiver Riddle   // If the indices are a splat, i.e. the literal parser parsed an element and
1038*c60b897dSRiver Riddle   // not a list, we set the shape explicitly. The indices are represented by a
1039*c60b897dSRiver Riddle   // 2-dimensional shape where the second dimension is the rank of the type.
1040*c60b897dSRiver Riddle   // Given that the parsed indices is a splat, we know that we only have one
1041*c60b897dSRiver Riddle   // indice and thus one for the first dimension.
1042*c60b897dSRiver Riddle   ShapedType indicesType;
1043*c60b897dSRiver Riddle   if (indiceParser.getShape().empty()) {
1044*c60b897dSRiver Riddle     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1045*c60b897dSRiver Riddle   } else {
1046*c60b897dSRiver Riddle     // Otherwise, set the shape to the one parsed by the literal parser.
1047*c60b897dSRiver Riddle     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1048*c60b897dSRiver Riddle   }
1049*c60b897dSRiver Riddle   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1050*c60b897dSRiver Riddle 
1051*c60b897dSRiver Riddle   // If the values are a splat, set the shape explicitly based on the number of
1052*c60b897dSRiver Riddle   // indices. The number of indices is encoded in the first dimension of the
1053*c60b897dSRiver Riddle   // indice shape type.
1054*c60b897dSRiver Riddle   auto valuesEltType = type.getElementType();
1055*c60b897dSRiver Riddle   ShapedType valuesType =
1056*c60b897dSRiver Riddle       valuesParser.getShape().empty()
1057*c60b897dSRiver Riddle           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1058*c60b897dSRiver Riddle           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1059*c60b897dSRiver Riddle   auto values = valuesParser.getAttr(valuesLoc, valuesType);
1060*c60b897dSRiver Riddle 
1061*c60b897dSRiver Riddle   // Build the sparse elements attribute by the indices and values.
1062*c60b897dSRiver Riddle   return getChecked<SparseElementsAttr>(loc, type, indices, values);
1063*c60b897dSRiver Riddle }
1064