1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for MLIR attributes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Parser.h"
14
15 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/Endian.h"
24
25 using namespace mlir;
26 using namespace mlir::detail;
27
28 /// Parse an arbitrary attribute.
29 ///
30 /// attribute-value ::= `unit`
31 /// | bool-literal
32 /// | integer-literal (`:` (index-type | integer-type))?
33 /// | float-literal (`:` float-type)?
34 /// | string-literal (`:` type)?
35 /// | type
36 /// | `[` `:` (integer-type | float-type) tensor-literal `]`
37 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
38 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
39 /// | symbol-ref-id (`::` symbol-ref-id)*
40 /// | `dense` `<` tensor-literal `>` `:`
41 /// (tensor-type | vector-type)
42 /// | `sparse` `<` attribute-value `,` attribute-value `>`
43 /// `:` (tensor-type | vector-type)
44 /// | `opaque` `<` dialect-namespace `,` hex-string-literal
45 /// `>` `:` (tensor-type | vector-type)
46 /// | extended-attribute
47 ///
parseAttribute(Type type)48 Attribute Parser::parseAttribute(Type type) {
49 switch (getToken().getKind()) {
50 // Parse an AffineMap or IntegerSet attribute.
51 case Token::kw_affine_map: {
52 consumeToken(Token::kw_affine_map);
53
54 AffineMap map;
55 if (parseToken(Token::less, "expected '<' in affine map") ||
56 parseAffineMapReference(map) ||
57 parseToken(Token::greater, "expected '>' in affine map"))
58 return Attribute();
59 return AffineMapAttr::get(map);
60 }
61 case Token::kw_affine_set: {
62 consumeToken(Token::kw_affine_set);
63
64 IntegerSet set;
65 if (parseToken(Token::less, "expected '<' in integer set") ||
66 parseIntegerSetReference(set) ||
67 parseToken(Token::greater, "expected '>' in integer set"))
68 return Attribute();
69 return IntegerSetAttr::get(set);
70 }
71
72 // Parse an array attribute.
73 case Token::l_square: {
74 consumeToken(Token::l_square);
75 if (consumeIf(Token::colon))
76 return parseDenseArrayAttr();
77 SmallVector<Attribute, 4> elements;
78 auto parseElt = [&]() -> ParseResult {
79 elements.push_back(parseAttribute());
80 return elements.back() ? success() : failure();
81 };
82
83 if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
84 return nullptr;
85 return builder.getArrayAttr(elements);
86 }
87
88 // Parse a boolean attribute.
89 case Token::kw_false:
90 consumeToken(Token::kw_false);
91 return builder.getBoolAttr(false);
92 case Token::kw_true:
93 consumeToken(Token::kw_true);
94 return builder.getBoolAttr(true);
95
96 // Parse a dense elements attribute.
97 case Token::kw_dense:
98 return parseDenseElementsAttr(type);
99
100 // Parse a dictionary attribute.
101 case Token::l_brace: {
102 NamedAttrList elements;
103 if (parseAttributeDict(elements))
104 return nullptr;
105 return elements.getDictionary(getContext());
106 }
107
108 // Parse an extended attribute, i.e. alias or dialect attribute.
109 case Token::hash_identifier:
110 return parseExtendedAttr(type);
111
112 // Parse floating point and integer attributes.
113 case Token::floatliteral:
114 return parseFloatAttr(type, /*isNegative=*/false);
115 case Token::integer:
116 return parseDecOrHexAttr(type, /*isNegative=*/false);
117 case Token::minus: {
118 consumeToken(Token::minus);
119 if (getToken().is(Token::integer))
120 return parseDecOrHexAttr(type, /*isNegative=*/true);
121 if (getToken().is(Token::floatliteral))
122 return parseFloatAttr(type, /*isNegative=*/true);
123
124 return (emitWrongTokenError(
125 "expected constant integer or floating point value"),
126 nullptr);
127 }
128
129 // Parse a location attribute.
130 case Token::kw_loc: {
131 consumeToken(Token::kw_loc);
132
133 LocationAttr locAttr;
134 if (parseToken(Token::l_paren, "expected '(' in inline location") ||
135 parseLocationInstance(locAttr) ||
136 parseToken(Token::r_paren, "expected ')' in inline location"))
137 return Attribute();
138 return locAttr;
139 }
140
141 // Parse an opaque elements attribute.
142 case Token::kw_opaque:
143 return parseOpaqueElementsAttr(type);
144
145 // Parse a sparse elements attribute.
146 case Token::kw_sparse:
147 return parseSparseElementsAttr(type);
148
149 // Parse a string attribute.
150 case Token::string: {
151 auto val = getToken().getStringValue();
152 consumeToken(Token::string);
153 // Parse the optional trailing colon type if one wasn't explicitly provided.
154 if (!type && consumeIf(Token::colon) && !(type = parseType()))
155 return Attribute();
156
157 return type ? StringAttr::get(val, type)
158 : StringAttr::get(getContext(), val);
159 }
160
161 // Parse a symbol reference attribute.
162 case Token::at_identifier: {
163 // When populating the parser state, this is a list of locations for all of
164 // the nested references.
165 SmallVector<SMRange> referenceLocations;
166 if (state.asmState)
167 referenceLocations.push_back(getToken().getLocRange());
168
169 // Parse the top-level reference.
170 std::string nameStr = getToken().getSymbolReference();
171 consumeToken(Token::at_identifier);
172
173 // Parse any nested references.
174 std::vector<FlatSymbolRefAttr> nestedRefs;
175 while (getToken().is(Token::colon)) {
176 // Check for the '::' prefix.
177 const char *curPointer = getToken().getLoc().getPointer();
178 consumeToken(Token::colon);
179 if (!consumeIf(Token::colon)) {
180 if (getToken().isNot(Token::eof, Token::error)) {
181 state.lex.resetPointer(curPointer);
182 consumeToken();
183 }
184 break;
185 }
186 // Parse the reference itself.
187 auto curLoc = getToken().getLoc();
188 if (getToken().isNot(Token::at_identifier)) {
189 emitError(curLoc, "expected nested symbol reference identifier");
190 return Attribute();
191 }
192
193 // If we are populating the assembly state, add the location for this
194 // reference.
195 if (state.asmState)
196 referenceLocations.push_back(getToken().getLocRange());
197
198 std::string nameStr = getToken().getSymbolReference();
199 consumeToken(Token::at_identifier);
200 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
201 }
202 SymbolRefAttr symbolRefAttr =
203 SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
204
205 // If we are populating the assembly state, record this symbol reference.
206 if (state.asmState)
207 state.asmState->addUses(symbolRefAttr, referenceLocations);
208 return symbolRefAttr;
209 }
210
211 // Parse a 'unit' attribute.
212 case Token::kw_unit:
213 consumeToken(Token::kw_unit);
214 return builder.getUnitAttr();
215
216 // Handle completion of an attribute.
217 case Token::code_complete:
218 if (getToken().isCodeCompletionFor(Token::hash_identifier))
219 return parseExtendedAttr(type);
220 return codeCompleteAttribute();
221
222 default:
223 // Parse a type attribute. We parse `Optional` here to allow for providing a
224 // better error message.
225 Type type;
226 OptionalParseResult result = parseOptionalType(type);
227 if (!result.hasValue())
228 return emitWrongTokenError("expected attribute value"), Attribute();
229 return failed(*result) ? Attribute() : TypeAttr::get(type);
230 }
231 }
232
233 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)234 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
235 Type type) {
236 switch (getToken().getKind()) {
237 case Token::at_identifier:
238 case Token::floatliteral:
239 case Token::integer:
240 case Token::hash_identifier:
241 case Token::kw_affine_map:
242 case Token::kw_affine_set:
243 case Token::kw_dense:
244 case Token::kw_false:
245 case Token::kw_loc:
246 case Token::kw_opaque:
247 case Token::kw_sparse:
248 case Token::kw_true:
249 case Token::kw_unit:
250 case Token::l_brace:
251 case Token::l_square:
252 case Token::minus:
253 case Token::string:
254 attribute = parseAttribute(type);
255 return success(attribute != nullptr);
256
257 default:
258 // Parse an optional type attribute.
259 Type type;
260 OptionalParseResult result = parseOptionalType(type);
261 if (result.hasValue() && succeeded(*result))
262 attribute = TypeAttr::get(type);
263 return result;
264 }
265 }
parseOptionalAttribute(ArrayAttr & attribute,Type type)266 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
267 Type type) {
268 return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
269 }
parseOptionalAttribute(StringAttr & attribute,Type type)270 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
271 Type type) {
272 return parseOptionalAttributeWithToken(Token::string, attribute, type);
273 }
274
275 /// Attribute dictionary.
276 ///
277 /// attribute-dict ::= `{` `}`
278 /// | `{` attribute-entry (`,` attribute-entry)* `}`
279 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
280 ///
parseAttributeDict(NamedAttrList & attributes)281 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
282 llvm::SmallDenseSet<StringAttr> seenKeys;
283 auto parseElt = [&]() -> ParseResult {
284 // The name of an attribute can either be a bare identifier, or a string.
285 Optional<StringAttr> nameId;
286 if (getToken().is(Token::string))
287 nameId = builder.getStringAttr(getToken().getStringValue());
288 else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
289 getToken().isKeyword())
290 nameId = builder.getStringAttr(getTokenSpelling());
291 else
292 return emitWrongTokenError("expected attribute name");
293
294 if (nameId->size() == 0)
295 return emitError("expected valid attribute name");
296
297 if (!seenKeys.insert(*nameId).second)
298 return emitError("duplicate key '")
299 << nameId->getValue() << "' in dictionary attribute";
300 consumeToken();
301
302 // Lazy load a dialect in the context if there is a possible namespace.
303 auto splitName = nameId->strref().split('.');
304 if (!splitName.second.empty())
305 getContext()->getOrLoadDialect(splitName.first);
306
307 // Try to parse the '=' for the attribute value.
308 if (!consumeIf(Token::equal)) {
309 // If there is no '=', we treat this as a unit attribute.
310 attributes.push_back({*nameId, builder.getUnitAttr()});
311 return success();
312 }
313
314 auto attr = parseAttribute();
315 if (!attr)
316 return failure();
317 attributes.push_back({*nameId, attr});
318 return success();
319 };
320
321 return parseCommaSeparatedList(Delimiter::Braces, parseElt,
322 " in attribute dictionary");
323 }
324
325 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)326 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
327 auto val = getToken().getFloatingPointValue();
328 if (!val)
329 return (emitError("floating point value too large for attribute"), nullptr);
330 consumeToken(Token::floatliteral);
331 if (!type) {
332 // Default to F64 when no type is specified.
333 if (!consumeIf(Token::colon))
334 type = builder.getF64Type();
335 else if (!(type = parseType()))
336 return nullptr;
337 }
338 if (!type.isa<FloatType>())
339 return (emitError("floating point value not valid for specified type"),
340 nullptr);
341 return FloatAttr::get(type, isNegative ? -*val : *val);
342 }
343
344 /// Construct an APint from a parsed value, a known attribute type and
345 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)346 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
347 StringRef spelling) {
348 // Parse the integer value into an APInt that is big enough to hold the value.
349 APInt result;
350 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
351 if (spelling.getAsInteger(isHex ? 0 : 10, result))
352 return llvm::None;
353
354 // Extend or truncate the bitwidth to the right size.
355 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
356 : type.getIntOrFloatBitWidth();
357
358 if (width > result.getBitWidth()) {
359 result = result.zext(width);
360 } else if (width < result.getBitWidth()) {
361 // The parser can return an unnecessarily wide result with leading zeros.
362 // This isn't a problem, but truncating off bits is bad.
363 if (result.countLeadingZeros() < result.getBitWidth() - width)
364 return llvm::None;
365
366 result = result.trunc(width);
367 }
368
369 if (width == 0) {
370 // 0 bit integers cannot be negative and manipulation of their sign bit will
371 // assert, so short-cut validation here.
372 if (isNegative)
373 return llvm::None;
374 } else if (isNegative) {
375 // The value is negative, we have an overflow if the sign bit is not set
376 // in the negated apInt.
377 result.negate();
378 if (!result.isSignBitSet())
379 return llvm::None;
380 } else if ((type.isSignedInteger() || type.isIndex()) &&
381 result.isSignBitSet()) {
382 // The value is a positive signed integer or index,
383 // we have an overflow if the sign bit is set.
384 return llvm::None;
385 }
386
387 return result;
388 }
389
390 /// Parse a decimal or a hexadecimal literal, which can be either an integer
391 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)392 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
393 Token tok = getToken();
394 StringRef spelling = tok.getSpelling();
395 SMLoc loc = tok.getLoc();
396
397 consumeToken(Token::integer);
398 if (!type) {
399 // Default to i64 if not type is specified.
400 if (!consumeIf(Token::colon))
401 type = builder.getIntegerType(64);
402 else if (!(type = parseType()))
403 return nullptr;
404 }
405
406 if (auto floatType = type.dyn_cast<FloatType>()) {
407 Optional<APFloat> result;
408 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
409 floatType.getFloatSemantics(),
410 floatType.getWidth())))
411 return Attribute();
412 return FloatAttr::get(floatType, *result);
413 }
414
415 if (!type.isa<IntegerType, IndexType>())
416 return emitError(loc, "integer literal not valid for specified type"),
417 nullptr;
418
419 if (isNegative && type.isUnsignedInteger()) {
420 emitError(loc,
421 "negative integer literal not valid for unsigned integer type");
422 return nullptr;
423 }
424
425 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
426 if (!apInt)
427 return emitError(loc, "integer constant out of range for attribute"),
428 nullptr;
429 return builder.getIntegerAttr(type, *apInt);
430 }
431
432 //===----------------------------------------------------------------------===//
433 // TensorLiteralParser
434 //===----------------------------------------------------------------------===//
435
436 /// Parse elements values stored within a hex string. On success, the values are
437 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)438 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
439 std::string &result) {
440 if (Optional<std::string> value = tok.getHexStringValue()) {
441 result = std::move(*value);
442 return success();
443 }
444 return parser.emitError(
445 tok.getLoc(), "expected string containing hex digits starting with `0x`");
446 }
447
448 namespace {
449 /// This class implements a parser for TensorLiterals. A tensor literal is
450 /// either a single element (e.g, 5) or a multi-dimensional list of elements
451 /// (e.g., [[5, 5]]).
452 class TensorLiteralParser {
453 public:
TensorLiteralParser(Parser & p)454 TensorLiteralParser(Parser &p) : p(p) {}
455
456 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
457 /// may also parse a tensor literal that is store as a hex string.
458 ParseResult parse(bool allowHex);
459
460 /// Build a dense attribute instance with the parsed elements and the given
461 /// shaped type.
462 DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
463
getShape() const464 ArrayRef<int64_t> getShape() const { return shape; }
465
466 private:
467 /// Get the parsed elements for an integer attribute.
468 ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
469 std::vector<APInt> &intValues);
470
471 /// Get the parsed elements for a float attribute.
472 ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
473 std::vector<APFloat> &floatValues);
474
475 /// Build a Dense String attribute for the given type.
476 DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
477
478 /// Build a Dense attribute with hex data for the given type.
479 DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
480
481 /// Parse a single element, returning failure if it isn't a valid element
482 /// literal. For example:
483 /// parseElement(1) -> Success, 1
484 /// parseElement([1]) -> Failure
485 ParseResult parseElement();
486
487 /// Parse a list of either lists or elements, returning the dimensions of the
488 /// parsed sub-tensors in dims. For example:
489 /// parseList([1, 2, 3]) -> Success, [3]
490 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
491 /// parseList([[1, 2], 3]) -> Failure
492 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
493 ParseResult parseList(SmallVectorImpl<int64_t> &dims);
494
495 /// Parse a literal that was printed as a hex string.
496 ParseResult parseHexElements();
497
498 Parser &p;
499
500 /// The shape inferred from the parsed elements.
501 SmallVector<int64_t, 4> shape;
502
503 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
504 std::vector<std::pair<bool, Token>> storage;
505
506 /// Storage used when parsing elements that were stored as hex values.
507 Optional<Token> hexStorage;
508 };
509 } // namespace
510
511 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
512 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)513 ParseResult TensorLiteralParser::parse(bool allowHex) {
514 // If hex is allowed, check for a string literal.
515 if (allowHex && p.getToken().is(Token::string)) {
516 hexStorage = p.getToken();
517 p.consumeToken(Token::string);
518 return success();
519 }
520 // Otherwise, parse a list or an individual element.
521 if (p.getToken().is(Token::l_square))
522 return parseList(shape);
523 return parseElement();
524 }
525
526 /// Build a dense attribute instance with the parsed elements and the given
527 /// shaped type.
getAttr(SMLoc loc,ShapedType type)528 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
529 Type eltType = type.getElementType();
530
531 // Check to see if we parse the literal from a hex string.
532 if (hexStorage &&
533 (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
534 return getHexAttr(loc, type);
535
536 // Check that the parsed storage size has the same number of elements to the
537 // type, or is a known splat.
538 if (!shape.empty() && getShape() != type.getShape()) {
539 p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
540 << "]) does not match type ([" << type.getShape() << "])";
541 return nullptr;
542 }
543
544 // Handle the case where no elements were parsed.
545 if (!hexStorage && storage.empty() && type.getNumElements()) {
546 p.emitError(loc) << "parsed zero elements, but type (" << type
547 << ") expected at least 1";
548 return nullptr;
549 }
550
551 // Handle complex types in the specific element type cases below.
552 bool isComplex = false;
553 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
554 eltType = complexTy.getElementType();
555 isComplex = true;
556 }
557
558 // Handle integer and index types.
559 if (eltType.isIntOrIndex()) {
560 std::vector<APInt> intValues;
561 if (failed(getIntAttrElements(loc, eltType, intValues)))
562 return nullptr;
563 if (isComplex) {
564 // If this is a complex, treat the parsed values as complex values.
565 auto complexData = llvm::makeArrayRef(
566 reinterpret_cast<std::complex<APInt> *>(intValues.data()),
567 intValues.size() / 2);
568 return DenseElementsAttr::get(type, complexData);
569 }
570 return DenseElementsAttr::get(type, intValues);
571 }
572 // Handle floating point types.
573 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
574 std::vector<APFloat> floatValues;
575 if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
576 return nullptr;
577 if (isComplex) {
578 // If this is a complex, treat the parsed values as complex values.
579 auto complexData = llvm::makeArrayRef(
580 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
581 floatValues.size() / 2);
582 return DenseElementsAttr::get(type, complexData);
583 }
584 return DenseElementsAttr::get(type, floatValues);
585 }
586
587 // Other types are assumed to be string representations.
588 return getStringAttr(loc, type, type.getElementType());
589 }
590
591 /// Build a Dense Integer attribute for the given type.
592 ParseResult
getIntAttrElements(SMLoc loc,Type eltTy,std::vector<APInt> & intValues)593 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
594 std::vector<APInt> &intValues) {
595 intValues.reserve(storage.size());
596 bool isUintType = eltTy.isUnsignedInteger();
597 for (const auto &signAndToken : storage) {
598 bool isNegative = signAndToken.first;
599 const Token &token = signAndToken.second;
600 auto tokenLoc = token.getLoc();
601
602 if (isNegative && isUintType) {
603 return p.emitError(tokenLoc)
604 << "expected unsigned integer elements, but parsed negative value";
605 }
606
607 // Check to see if floating point values were parsed.
608 if (token.is(Token::floatliteral)) {
609 return p.emitError(tokenLoc)
610 << "expected integer elements, but parsed floating-point";
611 }
612
613 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
614 "unexpected token type");
615 if (token.isAny(Token::kw_true, Token::kw_false)) {
616 if (!eltTy.isInteger(1)) {
617 return p.emitError(tokenLoc)
618 << "expected i1 type for 'true' or 'false' values";
619 }
620 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
621 intValues.push_back(apInt);
622 continue;
623 }
624
625 // Create APInt values for each element with the correct bitwidth.
626 Optional<APInt> apInt =
627 buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
628 if (!apInt)
629 return p.emitError(tokenLoc, "integer constant out of range for type");
630 intValues.push_back(*apInt);
631 }
632 return success();
633 }
634
635 /// Build a Dense Float attribute for the given type.
636 ParseResult
getFloatAttrElements(SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)637 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
638 std::vector<APFloat> &floatValues) {
639 floatValues.reserve(storage.size());
640 for (const auto &signAndToken : storage) {
641 bool isNegative = signAndToken.first;
642 const Token &token = signAndToken.second;
643
644 // Handle hexadecimal float literals.
645 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
646 Optional<APFloat> result;
647 if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
648 eltTy.getFloatSemantics(),
649 eltTy.getWidth())))
650 return failure();
651
652 floatValues.push_back(*result);
653 continue;
654 }
655
656 // Check to see if any decimal integers or booleans were parsed.
657 if (!token.is(Token::floatliteral))
658 return p.emitError()
659 << "expected floating-point elements, but parsed integer";
660
661 // Build the float values from tokens.
662 auto val = token.getFloatingPointValue();
663 if (!val)
664 return p.emitError("floating point value too large for attribute");
665
666 APFloat apVal(isNegative ? -*val : *val);
667 if (!eltTy.isF64()) {
668 bool unused;
669 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
670 &unused);
671 }
672 floatValues.push_back(apVal);
673 }
674 return success();
675 }
676
677 /// Build a Dense String attribute for the given type.
getStringAttr(SMLoc loc,ShapedType type,Type eltTy)678 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
679 Type eltTy) {
680 if (hexStorage.has_value()) {
681 auto stringValue = hexStorage.value().getStringValue();
682 return DenseStringElementsAttr::get(type, {stringValue});
683 }
684
685 std::vector<std::string> stringValues;
686 std::vector<StringRef> stringRefValues;
687 stringValues.reserve(storage.size());
688 stringRefValues.reserve(storage.size());
689
690 for (auto val : storage) {
691 stringValues.push_back(val.second.getStringValue());
692 stringRefValues.emplace_back(stringValues.back());
693 }
694
695 return DenseStringElementsAttr::get(type, stringRefValues);
696 }
697
698 /// Build a Dense attribute with hex data for the given type.
getHexAttr(SMLoc loc,ShapedType type)699 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
700 Type elementType = type.getElementType();
701 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
702 p.emitError(loc)
703 << "expected floating-point, integer, or complex element type, got "
704 << elementType;
705 return nullptr;
706 }
707
708 std::string data;
709 if (parseElementAttrHexValues(p, *hexStorage, data))
710 return nullptr;
711
712 ArrayRef<char> rawData(data.data(), data.size());
713 bool detectedSplat = false;
714 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
715 p.emitError(loc) << "elements hex data size is invalid for provided type: "
716 << type;
717 return nullptr;
718 }
719
720 if (llvm::support::endian::system_endianness() ==
721 llvm::support::endianness::big) {
722 // Convert endianess in big-endian(BE) machines. `rawData` is
723 // little-endian(LE) because HEX in raw data of dense element attribute
724 // is always LE format. It is converted into BE here to be used in BE
725 // machines.
726 SmallVector<char, 64> outDataVec(rawData.size());
727 MutableArrayRef<char> convRawData(outDataVec);
728 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
729 rawData, convRawData, type);
730 return DenseElementsAttr::getFromRawBuffer(type, convRawData);
731 }
732
733 return DenseElementsAttr::getFromRawBuffer(type, rawData);
734 }
735
parseElement()736 ParseResult TensorLiteralParser::parseElement() {
737 switch (p.getToken().getKind()) {
738 // Parse a boolean element.
739 case Token::kw_true:
740 case Token::kw_false:
741 case Token::floatliteral:
742 case Token::integer:
743 storage.emplace_back(/*isNegative=*/false, p.getToken());
744 p.consumeToken();
745 break;
746
747 // Parse a signed integer or a negative floating-point element.
748 case Token::minus:
749 p.consumeToken(Token::minus);
750 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
751 return p.emitError("expected integer or floating point literal");
752 storage.emplace_back(/*isNegative=*/true, p.getToken());
753 p.consumeToken();
754 break;
755
756 case Token::string:
757 storage.emplace_back(/*isNegative=*/false, p.getToken());
758 p.consumeToken();
759 break;
760
761 // Parse a complex element of the form '(' element ',' element ')'.
762 case Token::l_paren:
763 p.consumeToken(Token::l_paren);
764 if (parseElement() ||
765 p.parseToken(Token::comma, "expected ',' between complex elements") ||
766 parseElement() ||
767 p.parseToken(Token::r_paren, "expected ')' after complex elements"))
768 return failure();
769 break;
770
771 default:
772 return p.emitError("expected element literal of primitive type");
773 }
774
775 return success();
776 }
777
778 /// Parse a list of either lists or elements, returning the dimensions of the
779 /// parsed sub-tensors in dims. For example:
780 /// parseList([1, 2, 3]) -> Success, [3]
781 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
782 /// parseList([[1, 2], 3]) -> Failure
783 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)784 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
785 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
786 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
787 if (prevDims == newDims)
788 return success();
789 return p.emitError("tensor literal is invalid; ranks are not consistent "
790 "between elements");
791 };
792
793 bool first = true;
794 SmallVector<int64_t, 4> newDims;
795 unsigned size = 0;
796 auto parseOneElement = [&]() -> ParseResult {
797 SmallVector<int64_t, 4> thisDims;
798 if (p.getToken().getKind() == Token::l_square) {
799 if (parseList(thisDims))
800 return failure();
801 } else if (parseElement()) {
802 return failure();
803 }
804 ++size;
805 if (!first)
806 return checkDims(newDims, thisDims);
807 newDims = thisDims;
808 first = false;
809 return success();
810 };
811 if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
812 return failure();
813
814 // Return the sublists' dimensions with 'size' prepended.
815 dims.clear();
816 dims.push_back(size);
817 dims.append(newDims.begin(), newDims.end());
818 return success();
819 }
820
821 //===----------------------------------------------------------------------===//
822 // ElementsAttr Parser
823 //===----------------------------------------------------------------------===//
824
825 namespace {
826 /// This class provides an implementation of AsmParser, allowing to call back
827 /// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
828 /// in libMLIRIR.
829 class CustomAsmParser : public AsmParserImpl<AsmParser> {
830 public:
CustomAsmParser(Parser & parser)831 CustomAsmParser(Parser &parser)
832 : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
833 };
834 } // namespace
835
836 /// Parse a dense array attribute.
parseDenseArrayAttr()837 Attribute Parser::parseDenseArrayAttr() {
838 auto typeLoc = getToken().getLoc();
839 auto type = parseType();
840 if (!type)
841 return {};
842 CustomAsmParser parser(*this);
843 Attribute result;
844 // Check for empty list.
845 bool isEmptyList = getToken().is(Token::r_square);
846
847 if (auto intType = type.dyn_cast<IntegerType>()) {
848 switch (type.getIntOrFloatBitWidth()) {
849 case 8:
850 if (isEmptyList)
851 result = DenseI8ArrayAttr::get(parser.getContext(), {});
852 else
853 result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
854 break;
855 case 16:
856 if (isEmptyList)
857 result = DenseI16ArrayAttr::get(parser.getContext(), {});
858 else
859 result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
860 break;
861 case 32:
862 if (isEmptyList)
863 result = DenseI32ArrayAttr::get(parser.getContext(), {});
864 else
865 result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
866 break;
867 case 64:
868 if (isEmptyList)
869 result = DenseI64ArrayAttr::get(parser.getContext(), {});
870 else
871 result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
872 break;
873 default:
874 emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
875 return {};
876 }
877 } else if (auto floatType = type.dyn_cast<FloatType>()) {
878 switch (type.getIntOrFloatBitWidth()) {
879 case 32:
880 if (isEmptyList)
881 result = DenseF32ArrayAttr::get(parser.getContext(), {});
882 else
883 result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
884 break;
885 case 64:
886 if (isEmptyList)
887 result = DenseF64ArrayAttr::get(parser.getContext(), {});
888 else
889 result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
890 break;
891 default:
892 emitError(typeLoc, "expected f32 or f64 but got: ") << type;
893 return {};
894 }
895 } else {
896 emitError(typeLoc, "expected integer or float type, got: ") << type;
897 return {};
898 }
899 if (!consumeIf(Token::r_square)) {
900 emitError("expected ']' to close an array attribute");
901 return {};
902 }
903 return result;
904 }
905
906 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)907 Attribute Parser::parseDenseElementsAttr(Type attrType) {
908 auto attribLoc = getToken().getLoc();
909 consumeToken(Token::kw_dense);
910 if (parseToken(Token::less, "expected '<' after 'dense'"))
911 return nullptr;
912
913 // Parse the literal data if necessary.
914 TensorLiteralParser literalParser(*this);
915 if (!consumeIf(Token::greater)) {
916 if (literalParser.parse(/*allowHex=*/true) ||
917 parseToken(Token::greater, "expected '>'"))
918 return nullptr;
919 }
920
921 // If the type is specified `parseElementsLiteralType` will not parse a type.
922 // Use the attribute location as the location for error reporting in that
923 // case.
924 auto loc = attrType ? attribLoc : getToken().getLoc();
925 auto type = parseElementsLiteralType(attrType);
926 if (!type)
927 return nullptr;
928 return literalParser.getAttr(loc, type);
929 }
930
931 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)932 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
933 SMLoc loc = getToken().getLoc();
934 consumeToken(Token::kw_opaque);
935 if (parseToken(Token::less, "expected '<' after 'opaque'"))
936 return nullptr;
937
938 if (getToken().isNot(Token::string))
939 return (emitError("expected dialect namespace"), nullptr);
940
941 std::string name = getToken().getStringValue();
942 consumeToken(Token::string);
943
944 if (parseToken(Token::comma, "expected ','"))
945 return nullptr;
946
947 Token hexTok = getToken();
948 if (parseToken(Token::string, "elements hex string should start with '0x'") ||
949 parseToken(Token::greater, "expected '>'"))
950 return nullptr;
951 auto type = parseElementsLiteralType(attrType);
952 if (!type)
953 return nullptr;
954
955 std::string data;
956 if (parseElementAttrHexValues(*this, hexTok, data))
957 return nullptr;
958 return getChecked<OpaqueElementsAttr>(loc, builder.getStringAttr(name), type,
959 data);
960 }
961
962 /// Shaped type for elements attribute.
963 ///
964 /// elements-literal-type ::= vector-type | ranked-tensor-type
965 ///
966 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)967 ShapedType Parser::parseElementsLiteralType(Type type) {
968 // If the user didn't provide a type, parse the colon type for the literal.
969 if (!type) {
970 if (parseToken(Token::colon, "expected ':'"))
971 return nullptr;
972 if (!(type = parseType()))
973 return nullptr;
974 }
975
976 if (!type.isa<RankedTensorType, VectorType>()) {
977 emitError("elements literal must be a ranked tensor or vector type");
978 return nullptr;
979 }
980
981 auto sType = type.cast<ShapedType>();
982 if (!sType.hasStaticShape())
983 return (emitError("elements literal type must have static shape"), nullptr);
984
985 return sType;
986 }
987
988 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)989 Attribute Parser::parseSparseElementsAttr(Type attrType) {
990 SMLoc loc = getToken().getLoc();
991 consumeToken(Token::kw_sparse);
992 if (parseToken(Token::less, "Expected '<' after 'sparse'"))
993 return nullptr;
994
995 // Check for the case where all elements are sparse. The indices are
996 // represented by a 2-dimensional shape where the second dimension is the rank
997 // of the type.
998 Type indiceEltType = builder.getIntegerType(64);
999 if (consumeIf(Token::greater)) {
1000 ShapedType type = parseElementsLiteralType(attrType);
1001 if (!type)
1002 return nullptr;
1003
1004 // Construct the sparse elements attr using zero element indice/value
1005 // attributes.
1006 ShapedType indicesType =
1007 RankedTensorType::get({0, type.getRank()}, indiceEltType);
1008 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1009 return getChecked<SparseElementsAttr>(
1010 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1011 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1012 }
1013
1014 /// Parse the indices. We don't allow hex values here as we may need to use
1015 /// the inferred shape.
1016 auto indicesLoc = getToken().getLoc();
1017 TensorLiteralParser indiceParser(*this);
1018 if (indiceParser.parse(/*allowHex=*/false))
1019 return nullptr;
1020
1021 if (parseToken(Token::comma, "expected ','"))
1022 return nullptr;
1023
1024 /// Parse the values.
1025 auto valuesLoc = getToken().getLoc();
1026 TensorLiteralParser valuesParser(*this);
1027 if (valuesParser.parse(/*allowHex=*/true))
1028 return nullptr;
1029
1030 if (parseToken(Token::greater, "expected '>'"))
1031 return nullptr;
1032
1033 auto type = parseElementsLiteralType(attrType);
1034 if (!type)
1035 return nullptr;
1036
1037 // If the indices are a splat, i.e. the literal parser parsed an element and
1038 // not a list, we set the shape explicitly. The indices are represented by a
1039 // 2-dimensional shape where the second dimension is the rank of the type.
1040 // Given that the parsed indices is a splat, we know that we only have one
1041 // indice and thus one for the first dimension.
1042 ShapedType indicesType;
1043 if (indiceParser.getShape().empty()) {
1044 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1045 } else {
1046 // Otherwise, set the shape to the one parsed by the literal parser.
1047 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1048 }
1049 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1050
1051 // If the values are a splat, set the shape explicitly based on the number of
1052 // indices. The number of indices is encoded in the first dimension of the
1053 // indice shape type.
1054 auto valuesEltType = type.getElementType();
1055 ShapedType valuesType =
1056 valuesParser.getShape().empty()
1057 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1058 : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1059 auto values = valuesParser.getAttr(valuesLoc, valuesType);
1060
1061 // Build the sparse elements attribute by the indices and values.
1062 return getChecked<SparseElementsAttr>(loc, type, indices, values);
1063 }
1064