1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR Types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 15 #include "AsmParserImpl.h" 16 #include "mlir/AsmParser/AsmParserState.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/Support/Endian.h" 24 25 using namespace mlir; 26 using namespace mlir::detail; 27 28 /// Parse an arbitrary attribute. 29 /// 30 /// attribute-value ::= `unit` 31 /// | bool-literal 32 /// | integer-literal (`:` (index-type | integer-type))? 33 /// | float-literal (`:` float-type)? 34 /// | string-literal (`:` type)? 35 /// | type 36 /// | `[` `:` (integer-type | float-type) tensor-literal `]` 37 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 38 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 39 /// | symbol-ref-id (`::` symbol-ref-id)* 40 /// | `dense` `<` tensor-literal `>` `:` 41 /// (tensor-type | vector-type) 42 /// | `sparse` `<` attribute-value `,` attribute-value `>` 43 /// `:` (tensor-type | vector-type) 44 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 45 /// `>` `:` (tensor-type | vector-type) 46 /// | extended-attribute 47 /// 48 Attribute Parser::parseAttribute(Type type) { 49 switch (getToken().getKind()) { 50 // Parse an AffineMap or IntegerSet attribute. 51 case Token::kw_affine_map: { 52 consumeToken(Token::kw_affine_map); 53 54 AffineMap map; 55 if (parseToken(Token::less, "expected '<' in affine map") || 56 parseAffineMapReference(map) || 57 parseToken(Token::greater, "expected '>' in affine map")) 58 return Attribute(); 59 return AffineMapAttr::get(map); 60 } 61 case Token::kw_affine_set: { 62 consumeToken(Token::kw_affine_set); 63 64 IntegerSet set; 65 if (parseToken(Token::less, "expected '<' in integer set") || 66 parseIntegerSetReference(set) || 67 parseToken(Token::greater, "expected '>' in integer set")) 68 return Attribute(); 69 return IntegerSetAttr::get(set); 70 } 71 72 // Parse an array attribute. 73 case Token::l_square: { 74 consumeToken(Token::l_square); 75 if (consumeIf(Token::colon)) 76 return parseDenseArrayAttr(); 77 SmallVector<Attribute, 4> elements; 78 auto parseElt = [&]() -> ParseResult { 79 elements.push_back(parseAttribute()); 80 return elements.back() ? success() : failure(); 81 }; 82 83 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 84 return nullptr; 85 return builder.getArrayAttr(elements); 86 } 87 88 // Parse a boolean attribute. 89 case Token::kw_false: 90 consumeToken(Token::kw_false); 91 return builder.getBoolAttr(false); 92 case Token::kw_true: 93 consumeToken(Token::kw_true); 94 return builder.getBoolAttr(true); 95 96 // Parse a dense elements attribute. 97 case Token::kw_dense: 98 return parseDenseElementsAttr(type); 99 100 // Parse a dictionary attribute. 101 case Token::l_brace: { 102 NamedAttrList elements; 103 if (parseAttributeDict(elements)) 104 return nullptr; 105 return elements.getDictionary(getContext()); 106 } 107 108 // Parse an extended attribute, i.e. alias or dialect attribute. 109 case Token::hash_identifier: 110 return parseExtendedAttr(type); 111 112 // Parse floating point and integer attributes. 113 case Token::floatliteral: 114 return parseFloatAttr(type, /*isNegative=*/false); 115 case Token::integer: 116 return parseDecOrHexAttr(type, /*isNegative=*/false); 117 case Token::minus: { 118 consumeToken(Token::minus); 119 if (getToken().is(Token::integer)) 120 return parseDecOrHexAttr(type, /*isNegative=*/true); 121 if (getToken().is(Token::floatliteral)) 122 return parseFloatAttr(type, /*isNegative=*/true); 123 124 return (emitWrongTokenError( 125 "expected constant integer or floating point value"), 126 nullptr); 127 } 128 129 // Parse a location attribute. 130 case Token::kw_loc: { 131 consumeToken(Token::kw_loc); 132 133 LocationAttr locAttr; 134 if (parseToken(Token::l_paren, "expected '(' in inline location") || 135 parseLocationInstance(locAttr) || 136 parseToken(Token::r_paren, "expected ')' in inline location")) 137 return Attribute(); 138 return locAttr; 139 } 140 141 // Parse an opaque elements attribute. 142 case Token::kw_opaque: 143 return parseOpaqueElementsAttr(type); 144 145 // Parse a sparse elements attribute. 146 case Token::kw_sparse: 147 return parseSparseElementsAttr(type); 148 149 // Parse a string attribute. 150 case Token::string: { 151 auto val = getToken().getStringValue(); 152 consumeToken(Token::string); 153 // Parse the optional trailing colon type if one wasn't explicitly provided. 154 if (!type && consumeIf(Token::colon) && !(type = parseType())) 155 return Attribute(); 156 157 return type ? StringAttr::get(val, type) 158 : StringAttr::get(getContext(), val); 159 } 160 161 // Parse a symbol reference attribute. 162 case Token::at_identifier: { 163 // When populating the parser state, this is a list of locations for all of 164 // the nested references. 165 SmallVector<SMRange> referenceLocations; 166 if (state.asmState) 167 referenceLocations.push_back(getToken().getLocRange()); 168 169 // Parse the top-level reference. 170 std::string nameStr = getToken().getSymbolReference(); 171 consumeToken(Token::at_identifier); 172 173 // Parse any nested references. 174 std::vector<FlatSymbolRefAttr> nestedRefs; 175 while (getToken().is(Token::colon)) { 176 // Check for the '::' prefix. 177 const char *curPointer = getToken().getLoc().getPointer(); 178 consumeToken(Token::colon); 179 if (!consumeIf(Token::colon)) { 180 if (getToken().isNot(Token::eof, Token::error)) { 181 state.lex.resetPointer(curPointer); 182 consumeToken(); 183 } 184 break; 185 } 186 // Parse the reference itself. 187 auto curLoc = getToken().getLoc(); 188 if (getToken().isNot(Token::at_identifier)) { 189 emitError(curLoc, "expected nested symbol reference identifier"); 190 return Attribute(); 191 } 192 193 // If we are populating the assembly state, add the location for this 194 // reference. 195 if (state.asmState) 196 referenceLocations.push_back(getToken().getLocRange()); 197 198 std::string nameStr = getToken().getSymbolReference(); 199 consumeToken(Token::at_identifier); 200 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); 201 } 202 SymbolRefAttr symbolRefAttr = 203 SymbolRefAttr::get(getContext(), nameStr, nestedRefs); 204 205 // If we are populating the assembly state, record this symbol reference. 206 if (state.asmState) 207 state.asmState->addUses(symbolRefAttr, referenceLocations); 208 return symbolRefAttr; 209 } 210 211 // Parse a 'unit' attribute. 212 case Token::kw_unit: 213 consumeToken(Token::kw_unit); 214 return builder.getUnitAttr(); 215 216 // Handle completion of an attribute. 217 case Token::code_complete: 218 if (getToken().isCodeCompletionFor(Token::hash_identifier)) 219 return parseExtendedAttr(type); 220 return codeCompleteAttribute(); 221 222 default: 223 // Parse a type attribute. We parse `Optional` here to allow for providing a 224 // better error message. 225 Type type; 226 OptionalParseResult result = parseOptionalType(type); 227 if (!result.hasValue()) 228 return emitWrongTokenError("expected attribute value"), Attribute(); 229 return failed(*result) ? Attribute() : TypeAttr::get(type); 230 } 231 } 232 233 /// Parse an optional attribute with the provided type. 234 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, 235 Type type) { 236 switch (getToken().getKind()) { 237 case Token::at_identifier: 238 case Token::floatliteral: 239 case Token::integer: 240 case Token::hash_identifier: 241 case Token::kw_affine_map: 242 case Token::kw_affine_set: 243 case Token::kw_dense: 244 case Token::kw_false: 245 case Token::kw_loc: 246 case Token::kw_opaque: 247 case Token::kw_sparse: 248 case Token::kw_true: 249 case Token::kw_unit: 250 case Token::l_brace: 251 case Token::l_square: 252 case Token::minus: 253 case Token::string: 254 attribute = parseAttribute(type); 255 return success(attribute != nullptr); 256 257 default: 258 // Parse an optional type attribute. 259 Type type; 260 OptionalParseResult result = parseOptionalType(type); 261 if (result.hasValue() && succeeded(*result)) 262 attribute = TypeAttr::get(type); 263 return result; 264 } 265 } 266 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, 267 Type type) { 268 return parseOptionalAttributeWithToken(Token::l_square, attribute, type); 269 } 270 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, 271 Type type) { 272 return parseOptionalAttributeWithToken(Token::string, attribute, type); 273 } 274 275 /// Attribute dictionary. 276 /// 277 /// attribute-dict ::= `{` `}` 278 /// | `{` attribute-entry (`,` attribute-entry)* `}` 279 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 280 /// 281 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { 282 llvm::SmallDenseSet<StringAttr> seenKeys; 283 auto parseElt = [&]() -> ParseResult { 284 // The name of an attribute can either be a bare identifier, or a string. 285 Optional<StringAttr> nameId; 286 if (getToken().is(Token::string)) 287 nameId = builder.getStringAttr(getToken().getStringValue()); 288 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 289 getToken().isKeyword()) 290 nameId = builder.getStringAttr(getTokenSpelling()); 291 else 292 return emitWrongTokenError("expected attribute name"); 293 294 if (nameId->size() == 0) 295 return emitError("expected valid attribute name"); 296 297 if (!seenKeys.insert(*nameId).second) 298 return emitError("duplicate key '") 299 << nameId->getValue() << "' in dictionary attribute"; 300 consumeToken(); 301 302 // Lazy load a dialect in the context if there is a possible namespace. 303 auto splitName = nameId->strref().split('.'); 304 if (!splitName.second.empty()) 305 getContext()->getOrLoadDialect(splitName.first); 306 307 // Try to parse the '=' for the attribute value. 308 if (!consumeIf(Token::equal)) { 309 // If there is no '=', we treat this as a unit attribute. 310 attributes.push_back({*nameId, builder.getUnitAttr()}); 311 return success(); 312 } 313 314 auto attr = parseAttribute(); 315 if (!attr) 316 return failure(); 317 attributes.push_back({*nameId, attr}); 318 return success(); 319 }; 320 321 return parseCommaSeparatedList(Delimiter::Braces, parseElt, 322 " in attribute dictionary"); 323 } 324 325 /// Parse a float attribute. 326 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 327 auto val = getToken().getFloatingPointValue(); 328 if (!val) 329 return (emitError("floating point value too large for attribute"), nullptr); 330 consumeToken(Token::floatliteral); 331 if (!type) { 332 // Default to F64 when no type is specified. 333 if (!consumeIf(Token::colon)) 334 type = builder.getF64Type(); 335 else if (!(type = parseType())) 336 return nullptr; 337 } 338 if (!type.isa<FloatType>()) 339 return (emitError("floating point value not valid for specified type"), 340 nullptr); 341 return FloatAttr::get(type, isNegative ? -*val : *val); 342 } 343 344 /// Construct an APint from a parsed value, a known attribute type and 345 /// sign. 346 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 347 StringRef spelling) { 348 // Parse the integer value into an APInt that is big enough to hold the value. 349 APInt result; 350 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 351 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 352 return llvm::None; 353 354 // Extend or truncate the bitwidth to the right size. 355 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth 356 : type.getIntOrFloatBitWidth(); 357 358 if (width > result.getBitWidth()) { 359 result = result.zext(width); 360 } else if (width < result.getBitWidth()) { 361 // The parser can return an unnecessarily wide result with leading zeros. 362 // This isn't a problem, but truncating off bits is bad. 363 if (result.countLeadingZeros() < result.getBitWidth() - width) 364 return llvm::None; 365 366 result = result.trunc(width); 367 } 368 369 if (width == 0) { 370 // 0 bit integers cannot be negative and manipulation of their sign bit will 371 // assert, so short-cut validation here. 372 if (isNegative) 373 return llvm::None; 374 } else if (isNegative) { 375 // The value is negative, we have an overflow if the sign bit is not set 376 // in the negated apInt. 377 result.negate(); 378 if (!result.isSignBitSet()) 379 return llvm::None; 380 } else if ((type.isSignedInteger() || type.isIndex()) && 381 result.isSignBitSet()) { 382 // The value is a positive signed integer or index, 383 // we have an overflow if the sign bit is set. 384 return llvm::None; 385 } 386 387 return result; 388 } 389 390 /// Parse a decimal or a hexadecimal literal, which can be either an integer 391 /// or a float attribute. 392 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 393 Token tok = getToken(); 394 StringRef spelling = tok.getSpelling(); 395 SMLoc loc = tok.getLoc(); 396 397 consumeToken(Token::integer); 398 if (!type) { 399 // Default to i64 if not type is specified. 400 if (!consumeIf(Token::colon)) 401 type = builder.getIntegerType(64); 402 else if (!(type = parseType())) 403 return nullptr; 404 } 405 406 if (auto floatType = type.dyn_cast<FloatType>()) { 407 Optional<APFloat> result; 408 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, 409 floatType.getFloatSemantics(), 410 floatType.getWidth()))) 411 return Attribute(); 412 return FloatAttr::get(floatType, *result); 413 } 414 415 if (!type.isa<IntegerType, IndexType>()) 416 return emitError(loc, "integer literal not valid for specified type"), 417 nullptr; 418 419 if (isNegative && type.isUnsignedInteger()) { 420 emitError(loc, 421 "negative integer literal not valid for unsigned integer type"); 422 return nullptr; 423 } 424 425 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 426 if (!apInt) 427 return emitError(loc, "integer constant out of range for attribute"), 428 nullptr; 429 return builder.getIntegerAttr(type, *apInt); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // TensorLiteralParser 434 //===----------------------------------------------------------------------===// 435 436 /// Parse elements values stored within a hex string. On success, the values are 437 /// stored into 'result'. 438 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 439 std::string &result) { 440 if (Optional<std::string> value = tok.getHexStringValue()) { 441 result = std::move(*value); 442 return success(); 443 } 444 return parser.emitError( 445 tok.getLoc(), "expected string containing hex digits starting with `0x`"); 446 } 447 448 namespace { 449 /// This class implements a parser for TensorLiterals. A tensor literal is 450 /// either a single element (e.g, 5) or a multi-dimensional list of elements 451 /// (e.g., [[5, 5]]). 452 class TensorLiteralParser { 453 public: 454 TensorLiteralParser(Parser &p) : p(p) {} 455 456 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 457 /// may also parse a tensor literal that is store as a hex string. 458 ParseResult parse(bool allowHex); 459 460 /// Build a dense attribute instance with the parsed elements and the given 461 /// shaped type. 462 DenseElementsAttr getAttr(SMLoc loc, ShapedType type); 463 464 ArrayRef<int64_t> getShape() const { return shape; } 465 466 private: 467 /// Get the parsed elements for an integer attribute. 468 ParseResult getIntAttrElements(SMLoc loc, Type eltTy, 469 std::vector<APInt> &intValues); 470 471 /// Get the parsed elements for a float attribute. 472 ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, 473 std::vector<APFloat> &floatValues); 474 475 /// Build a Dense String attribute for the given type. 476 DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); 477 478 /// Build a Dense attribute with hex data for the given type. 479 DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); 480 481 /// Parse a single element, returning failure if it isn't a valid element 482 /// literal. For example: 483 /// parseElement(1) -> Success, 1 484 /// parseElement([1]) -> Failure 485 ParseResult parseElement(); 486 487 /// Parse a list of either lists or elements, returning the dimensions of the 488 /// parsed sub-tensors in dims. For example: 489 /// parseList([1, 2, 3]) -> Success, [3] 490 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 491 /// parseList([[1, 2], 3]) -> Failure 492 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 493 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 494 495 /// Parse a literal that was printed as a hex string. 496 ParseResult parseHexElements(); 497 498 Parser &p; 499 500 /// The shape inferred from the parsed elements. 501 SmallVector<int64_t, 4> shape; 502 503 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 504 std::vector<std::pair<bool, Token>> storage; 505 506 /// Storage used when parsing elements that were stored as hex values. 507 Optional<Token> hexStorage; 508 }; 509 } // namespace 510 511 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 512 /// may also parse a tensor literal that is store as a hex string. 513 ParseResult TensorLiteralParser::parse(bool allowHex) { 514 // If hex is allowed, check for a string literal. 515 if (allowHex && p.getToken().is(Token::string)) { 516 hexStorage = p.getToken(); 517 p.consumeToken(Token::string); 518 return success(); 519 } 520 // Otherwise, parse a list or an individual element. 521 if (p.getToken().is(Token::l_square)) 522 return parseList(shape); 523 return parseElement(); 524 } 525 526 /// Build a dense attribute instance with the parsed elements and the given 527 /// shaped type. 528 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { 529 Type eltType = type.getElementType(); 530 531 // Check to see if we parse the literal from a hex string. 532 if (hexStorage && 533 (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>())) 534 return getHexAttr(loc, type); 535 536 // Check that the parsed storage size has the same number of elements to the 537 // type, or is a known splat. 538 if (!shape.empty() && getShape() != type.getShape()) { 539 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 540 << "]) does not match type ([" << type.getShape() << "])"; 541 return nullptr; 542 } 543 544 // Handle the case where no elements were parsed. 545 if (!hexStorage && storage.empty() && type.getNumElements()) { 546 p.emitError(loc) << "parsed zero elements, but type (" << type 547 << ") expected at least 1"; 548 return nullptr; 549 } 550 551 // Handle complex types in the specific element type cases below. 552 bool isComplex = false; 553 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) { 554 eltType = complexTy.getElementType(); 555 isComplex = true; 556 } 557 558 // Handle integer and index types. 559 if (eltType.isIntOrIndex()) { 560 std::vector<APInt> intValues; 561 if (failed(getIntAttrElements(loc, eltType, intValues))) 562 return nullptr; 563 if (isComplex) { 564 // If this is a complex, treat the parsed values as complex values. 565 auto complexData = llvm::makeArrayRef( 566 reinterpret_cast<std::complex<APInt> *>(intValues.data()), 567 intValues.size() / 2); 568 return DenseElementsAttr::get(type, complexData); 569 } 570 return DenseElementsAttr::get(type, intValues); 571 } 572 // Handle floating point types. 573 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) { 574 std::vector<APFloat> floatValues; 575 if (failed(getFloatAttrElements(loc, floatTy, floatValues))) 576 return nullptr; 577 if (isComplex) { 578 // If this is a complex, treat the parsed values as complex values. 579 auto complexData = llvm::makeArrayRef( 580 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), 581 floatValues.size() / 2); 582 return DenseElementsAttr::get(type, complexData); 583 } 584 return DenseElementsAttr::get(type, floatValues); 585 } 586 587 // Other types are assumed to be string representations. 588 return getStringAttr(loc, type, type.getElementType()); 589 } 590 591 /// Build a Dense Integer attribute for the given type. 592 ParseResult 593 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, 594 std::vector<APInt> &intValues) { 595 intValues.reserve(storage.size()); 596 bool isUintType = eltTy.isUnsignedInteger(); 597 for (const auto &signAndToken : storage) { 598 bool isNegative = signAndToken.first; 599 const Token &token = signAndToken.second; 600 auto tokenLoc = token.getLoc(); 601 602 if (isNegative && isUintType) { 603 return p.emitError(tokenLoc) 604 << "expected unsigned integer elements, but parsed negative value"; 605 } 606 607 // Check to see if floating point values were parsed. 608 if (token.is(Token::floatliteral)) { 609 return p.emitError(tokenLoc) 610 << "expected integer elements, but parsed floating-point"; 611 } 612 613 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 614 "unexpected token type"); 615 if (token.isAny(Token::kw_true, Token::kw_false)) { 616 if (!eltTy.isInteger(1)) { 617 return p.emitError(tokenLoc) 618 << "expected i1 type for 'true' or 'false' values"; 619 } 620 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); 621 intValues.push_back(apInt); 622 continue; 623 } 624 625 // Create APInt values for each element with the correct bitwidth. 626 Optional<APInt> apInt = 627 buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 628 if (!apInt) 629 return p.emitError(tokenLoc, "integer constant out of range for type"); 630 intValues.push_back(*apInt); 631 } 632 return success(); 633 } 634 635 /// Build a Dense Float attribute for the given type. 636 ParseResult 637 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, 638 std::vector<APFloat> &floatValues) { 639 floatValues.reserve(storage.size()); 640 for (const auto &signAndToken : storage) { 641 bool isNegative = signAndToken.first; 642 const Token &token = signAndToken.second; 643 644 // Handle hexadecimal float literals. 645 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 646 Optional<APFloat> result; 647 if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative, 648 eltTy.getFloatSemantics(), 649 eltTy.getWidth()))) 650 return failure(); 651 652 floatValues.push_back(*result); 653 continue; 654 } 655 656 // Check to see if any decimal integers or booleans were parsed. 657 if (!token.is(Token::floatliteral)) 658 return p.emitError() 659 << "expected floating-point elements, but parsed integer"; 660 661 // Build the float values from tokens. 662 auto val = token.getFloatingPointValue(); 663 if (!val) 664 return p.emitError("floating point value too large for attribute"); 665 666 APFloat apVal(isNegative ? -*val : *val); 667 if (!eltTy.isF64()) { 668 bool unused; 669 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 670 &unused); 671 } 672 floatValues.push_back(apVal); 673 } 674 return success(); 675 } 676 677 /// Build a Dense String attribute for the given type. 678 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, 679 Type eltTy) { 680 if (hexStorage.has_value()) { 681 auto stringValue = hexStorage.value().getStringValue(); 682 return DenseStringElementsAttr::get(type, {stringValue}); 683 } 684 685 std::vector<std::string> stringValues; 686 std::vector<StringRef> stringRefValues; 687 stringValues.reserve(storage.size()); 688 stringRefValues.reserve(storage.size()); 689 690 for (auto val : storage) { 691 stringValues.push_back(val.second.getStringValue()); 692 stringRefValues.emplace_back(stringValues.back()); 693 } 694 695 return DenseStringElementsAttr::get(type, stringRefValues); 696 } 697 698 /// Build a Dense attribute with hex data for the given type. 699 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { 700 Type elementType = type.getElementType(); 701 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) { 702 p.emitError(loc) 703 << "expected floating-point, integer, or complex element type, got " 704 << elementType; 705 return nullptr; 706 } 707 708 std::string data; 709 if (parseElementAttrHexValues(p, *hexStorage, data)) 710 return nullptr; 711 712 ArrayRef<char> rawData(data.data(), data.size()); 713 bool detectedSplat = false; 714 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { 715 p.emitError(loc) << "elements hex data size is invalid for provided type: " 716 << type; 717 return nullptr; 718 } 719 720 if (llvm::support::endian::system_endianness() == 721 llvm::support::endianness::big) { 722 // Convert endianess in big-endian(BE) machines. `rawData` is 723 // little-endian(LE) because HEX in raw data of dense element attribute 724 // is always LE format. It is converted into BE here to be used in BE 725 // machines. 726 SmallVector<char, 64> outDataVec(rawData.size()); 727 MutableArrayRef<char> convRawData(outDataVec); 728 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 729 rawData, convRawData, type); 730 return DenseElementsAttr::getFromRawBuffer(type, convRawData); 731 } 732 733 return DenseElementsAttr::getFromRawBuffer(type, rawData); 734 } 735 736 ParseResult TensorLiteralParser::parseElement() { 737 switch (p.getToken().getKind()) { 738 // Parse a boolean element. 739 case Token::kw_true: 740 case Token::kw_false: 741 case Token::floatliteral: 742 case Token::integer: 743 storage.emplace_back(/*isNegative=*/false, p.getToken()); 744 p.consumeToken(); 745 break; 746 747 // Parse a signed integer or a negative floating-point element. 748 case Token::minus: 749 p.consumeToken(Token::minus); 750 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 751 return p.emitError("expected integer or floating point literal"); 752 storage.emplace_back(/*isNegative=*/true, p.getToken()); 753 p.consumeToken(); 754 break; 755 756 case Token::string: 757 storage.emplace_back(/*isNegative=*/false, p.getToken()); 758 p.consumeToken(); 759 break; 760 761 // Parse a complex element of the form '(' element ',' element ')'. 762 case Token::l_paren: 763 p.consumeToken(Token::l_paren); 764 if (parseElement() || 765 p.parseToken(Token::comma, "expected ',' between complex elements") || 766 parseElement() || 767 p.parseToken(Token::r_paren, "expected ')' after complex elements")) 768 return failure(); 769 break; 770 771 default: 772 return p.emitError("expected element literal of primitive type"); 773 } 774 775 return success(); 776 } 777 778 /// Parse a list of either lists or elements, returning the dimensions of the 779 /// parsed sub-tensors in dims. For example: 780 /// parseList([1, 2, 3]) -> Success, [3] 781 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 782 /// parseList([[1, 2], 3]) -> Failure 783 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 784 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 785 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 786 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 787 if (prevDims == newDims) 788 return success(); 789 return p.emitError("tensor literal is invalid; ranks are not consistent " 790 "between elements"); 791 }; 792 793 bool first = true; 794 SmallVector<int64_t, 4> newDims; 795 unsigned size = 0; 796 auto parseOneElement = [&]() -> ParseResult { 797 SmallVector<int64_t, 4> thisDims; 798 if (p.getToken().getKind() == Token::l_square) { 799 if (parseList(thisDims)) 800 return failure(); 801 } else if (parseElement()) { 802 return failure(); 803 } 804 ++size; 805 if (!first) 806 return checkDims(newDims, thisDims); 807 newDims = thisDims; 808 first = false; 809 return success(); 810 }; 811 if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement)) 812 return failure(); 813 814 // Return the sublists' dimensions with 'size' prepended. 815 dims.clear(); 816 dims.push_back(size); 817 dims.append(newDims.begin(), newDims.end()); 818 return success(); 819 } 820 821 //===----------------------------------------------------------------------===// 822 // ElementsAttr Parser 823 //===----------------------------------------------------------------------===// 824 825 namespace { 826 /// This class provides an implementation of AsmParser, allowing to call back 827 /// into the libMLIRIR-provided APIs for invoking attribute parsing code defined 828 /// in libMLIRIR. 829 class CustomAsmParser : public AsmParserImpl<AsmParser> { 830 public: 831 CustomAsmParser(Parser &parser) 832 : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {} 833 }; 834 } // namespace 835 836 /// Parse a dense array attribute. 837 Attribute Parser::parseDenseArrayAttr() { 838 auto typeLoc = getToken().getLoc(); 839 auto type = parseType(); 840 if (!type) 841 return {}; 842 CustomAsmParser parser(*this); 843 Attribute result; 844 // Check for empty list. 845 bool isEmptyList = getToken().is(Token::r_square); 846 847 if (auto intType = type.dyn_cast<IntegerType>()) { 848 switch (type.getIntOrFloatBitWidth()) { 849 case 8: 850 if (isEmptyList) 851 result = DenseI8ArrayAttr::get(parser.getContext(), {}); 852 else 853 result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{}); 854 break; 855 case 16: 856 if (isEmptyList) 857 result = DenseI16ArrayAttr::get(parser.getContext(), {}); 858 else 859 result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{}); 860 break; 861 case 32: 862 if (isEmptyList) 863 result = DenseI32ArrayAttr::get(parser.getContext(), {}); 864 else 865 result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{}); 866 break; 867 case 64: 868 if (isEmptyList) 869 result = DenseI64ArrayAttr::get(parser.getContext(), {}); 870 else 871 result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); 872 break; 873 default: 874 emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type; 875 return {}; 876 } 877 } else if (auto floatType = type.dyn_cast<FloatType>()) { 878 switch (type.getIntOrFloatBitWidth()) { 879 case 32: 880 if (isEmptyList) 881 result = DenseF32ArrayAttr::get(parser.getContext(), {}); 882 else 883 result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{}); 884 break; 885 case 64: 886 if (isEmptyList) 887 result = DenseF64ArrayAttr::get(parser.getContext(), {}); 888 else 889 result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{}); 890 break; 891 default: 892 emitError(typeLoc, "expected f32 or f64 but got: ") << type; 893 return {}; 894 } 895 } else { 896 emitError(typeLoc, "expected integer or float type, got: ") << type; 897 return {}; 898 } 899 if (!consumeIf(Token::r_square)) { 900 emitError("expected ']' to close an array attribute"); 901 return {}; 902 } 903 return result; 904 } 905 906 /// Parse a dense elements attribute. 907 Attribute Parser::parseDenseElementsAttr(Type attrType) { 908 auto attribLoc = getToken().getLoc(); 909 consumeToken(Token::kw_dense); 910 if (parseToken(Token::less, "expected '<' after 'dense'")) 911 return nullptr; 912 913 // Parse the literal data if necessary. 914 TensorLiteralParser literalParser(*this); 915 if (!consumeIf(Token::greater)) { 916 if (literalParser.parse(/*allowHex=*/true) || 917 parseToken(Token::greater, "expected '>'")) 918 return nullptr; 919 } 920 921 // If the type is specified `parseElementsLiteralType` will not parse a type. 922 // Use the attribute location as the location for error reporting in that 923 // case. 924 auto loc = attrType ? attribLoc : getToken().getLoc(); 925 auto type = parseElementsLiteralType(attrType); 926 if (!type) 927 return nullptr; 928 return literalParser.getAttr(loc, type); 929 } 930 931 /// Parse an opaque elements attribute. 932 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 933 SMLoc loc = getToken().getLoc(); 934 consumeToken(Token::kw_opaque); 935 if (parseToken(Token::less, "expected '<' after 'opaque'")) 936 return nullptr; 937 938 if (getToken().isNot(Token::string)) 939 return (emitError("expected dialect namespace"), nullptr); 940 941 std::string name = getToken().getStringValue(); 942 consumeToken(Token::string); 943 944 if (parseToken(Token::comma, "expected ','")) 945 return nullptr; 946 947 Token hexTok = getToken(); 948 if (parseToken(Token::string, "elements hex string should start with '0x'") || 949 parseToken(Token::greater, "expected '>'")) 950 return nullptr; 951 auto type = parseElementsLiteralType(attrType); 952 if (!type) 953 return nullptr; 954 955 std::string data; 956 if (parseElementAttrHexValues(*this, hexTok, data)) 957 return nullptr; 958 return getChecked<OpaqueElementsAttr>(loc, builder.getStringAttr(name), type, 959 data); 960 } 961 962 /// Shaped type for elements attribute. 963 /// 964 /// elements-literal-type ::= vector-type | ranked-tensor-type 965 /// 966 /// This method also checks the type has static shape. 967 ShapedType Parser::parseElementsLiteralType(Type type) { 968 // If the user didn't provide a type, parse the colon type for the literal. 969 if (!type) { 970 if (parseToken(Token::colon, "expected ':'")) 971 return nullptr; 972 if (!(type = parseType())) 973 return nullptr; 974 } 975 976 if (!type.isa<RankedTensorType, VectorType>()) { 977 emitError("elements literal must be a ranked tensor or vector type"); 978 return nullptr; 979 } 980 981 auto sType = type.cast<ShapedType>(); 982 if (!sType.hasStaticShape()) 983 return (emitError("elements literal type must have static shape"), nullptr); 984 985 return sType; 986 } 987 988 /// Parse a sparse elements attribute. 989 Attribute Parser::parseSparseElementsAttr(Type attrType) { 990 SMLoc loc = getToken().getLoc(); 991 consumeToken(Token::kw_sparse); 992 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 993 return nullptr; 994 995 // Check for the case where all elements are sparse. The indices are 996 // represented by a 2-dimensional shape where the second dimension is the rank 997 // of the type. 998 Type indiceEltType = builder.getIntegerType(64); 999 if (consumeIf(Token::greater)) { 1000 ShapedType type = parseElementsLiteralType(attrType); 1001 if (!type) 1002 return nullptr; 1003 1004 // Construct the sparse elements attr using zero element indice/value 1005 // attributes. 1006 ShapedType indicesType = 1007 RankedTensorType::get({0, type.getRank()}, indiceEltType); 1008 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); 1009 return getChecked<SparseElementsAttr>( 1010 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), 1011 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); 1012 } 1013 1014 /// Parse the indices. We don't allow hex values here as we may need to use 1015 /// the inferred shape. 1016 auto indicesLoc = getToken().getLoc(); 1017 TensorLiteralParser indiceParser(*this); 1018 if (indiceParser.parse(/*allowHex=*/false)) 1019 return nullptr; 1020 1021 if (parseToken(Token::comma, "expected ','")) 1022 return nullptr; 1023 1024 /// Parse the values. 1025 auto valuesLoc = getToken().getLoc(); 1026 TensorLiteralParser valuesParser(*this); 1027 if (valuesParser.parse(/*allowHex=*/true)) 1028 return nullptr; 1029 1030 if (parseToken(Token::greater, "expected '>'")) 1031 return nullptr; 1032 1033 auto type = parseElementsLiteralType(attrType); 1034 if (!type) 1035 return nullptr; 1036 1037 // If the indices are a splat, i.e. the literal parser parsed an element and 1038 // not a list, we set the shape explicitly. The indices are represented by a 1039 // 2-dimensional shape where the second dimension is the rank of the type. 1040 // Given that the parsed indices is a splat, we know that we only have one 1041 // indice and thus one for the first dimension. 1042 ShapedType indicesType; 1043 if (indiceParser.getShape().empty()) { 1044 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1045 } else { 1046 // Otherwise, set the shape to the one parsed by the literal parser. 1047 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1048 } 1049 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1050 1051 // If the values are a splat, set the shape explicitly based on the number of 1052 // indices. The number of indices is encoded in the first dimension of the 1053 // indice shape type. 1054 auto valuesEltType = type.getElementType(); 1055 ShapedType valuesType = 1056 valuesParser.getShape().empty() 1057 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1058 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1059 auto values = valuesParser.getAttr(valuesLoc, valuesType); 1060 1061 // Build the sparse elements attribute by the indices and values. 1062 return getChecked<SparseElementsAttr>(loc, type, indices, values); 1063 } 1064