1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR textual form. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Parser.h" 14 #include "Lexer.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/IR/Module.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/StandardTypes.h" 27 #include "mlir/IR/Verifier.h" 28 #include "llvm/ADT/APInt.h" 29 #include "llvm/ADT/DenseMap.h" 30 #include "llvm/ADT/StringExtras.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/ADT/bit.h" 33 #include "llvm/Support/PrettyStackTrace.h" 34 #include "llvm/Support/SMLoc.h" 35 #include "llvm/Support/SourceMgr.h" 36 #include <algorithm> 37 using namespace mlir; 38 using llvm::MemoryBuffer; 39 using llvm::SMLoc; 40 using llvm::SourceMgr; 41 42 namespace { 43 class Parser; 44 45 //===----------------------------------------------------------------------===// 46 // SymbolState 47 //===----------------------------------------------------------------------===// 48 49 /// This class contains record of any parsed top-level symbols. 50 struct SymbolState { 51 // A map from attribute alias identifier to Attribute. 52 llvm::StringMap<Attribute> attributeAliasDefinitions; 53 54 // A map from type alias identifier to Type. 
55 llvm::StringMap<Type> typeAliasDefinitions; 56 57 /// A set of locations into the main parser memory buffer for each of the 58 /// active nested parsers. Given that some nested parsers, i.e. custom dialect 59 /// parsers, operate on a temporary memory buffer, this provides an anchor 60 /// point for emitting diagnostics. 61 SmallVector<llvm::SMLoc, 1> nestedParserLocs; 62 63 /// The top-level lexer that contains the original memory buffer provided by 64 /// the user. This is used by nested parsers to get a properly encoded source 65 /// location. 66 Lexer *topLevelLexer = nullptr; 67 }; 68 69 //===----------------------------------------------------------------------===// 70 // ParserState 71 //===----------------------------------------------------------------------===// 72 73 /// This class refers to all of the state maintained globally by the parser, 74 /// such as the current lexer position etc. 75 struct ParserState { 76 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, 77 SymbolState &symbols) 78 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), 79 symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { 80 // Set the top level lexer for the symbol state if one doesn't exist. 81 if (!symbols.topLevelLexer) 82 symbols.topLevelLexer = &lex; 83 } 84 ~ParserState() { 85 // Reset the top level lexer if it refers the lexer in our state. 86 if (symbols.topLevelLexer == &lex) 87 symbols.topLevelLexer = nullptr; 88 } 89 ParserState(const ParserState &) = delete; 90 void operator=(const ParserState &) = delete; 91 92 /// The context we're parsing into. 93 MLIRContext *const context; 94 95 /// The lexer for the source file we're parsing. 96 Lexer lex; 97 98 /// This is the next token that hasn't been consumed yet. 99 Token curToken; 100 101 /// The current state for symbol parsing. 102 SymbolState &symbols; 103 104 /// The depth of this parser in the nested parsing stack. 
105 size_t parserDepth; 106 }; 107 108 //===----------------------------------------------------------------------===// 109 // Parser 110 //===----------------------------------------------------------------------===// 111 112 /// This class implement support for parsing global entities like types and 113 /// shared entities like SSA names. It is intended to be subclassed by 114 /// specialized subparsers that include state, e.g. when a local symbol table. 115 class Parser { 116 public: 117 Builder builder; 118 119 Parser(ParserState &state) : builder(state.context), state(state) {} 120 121 // Helper methods to get stuff from the parser-global state. 122 ParserState &getState() const { return state; } 123 MLIRContext *getContext() const { return state.context; } 124 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 125 126 /// Parse a comma-separated list of elements up until the specified end token. 127 ParseResult 128 parseCommaSeparatedListUntil(Token::Kind rightToken, 129 const std::function<ParseResult()> &parseElement, 130 bool allowEmptyList = true); 131 132 /// Parse a comma separated list of elements that must have at least one entry 133 /// in it. 134 ParseResult 135 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 136 137 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 138 139 // We have two forms of parsing methods - those that return a non-null 140 // pointer on success, and those that return a ParseResult to indicate whether 141 // they returned a failure. The second class fills in by-reference arguments 142 // as the results of their action. 143 144 //===--------------------------------------------------------------------===// 145 // Error Handling 146 //===--------------------------------------------------------------------===// 147 148 /// Emit an error and return failure. 
149 InFlightDiagnostic emitError(const Twine &message = {}) { 150 return emitError(state.curToken.getLoc(), message); 151 } 152 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 153 154 /// Encode the specified source location information into an attribute for 155 /// attachment to the IR. 156 Location getEncodedSourceLocation(llvm::SMLoc loc) { 157 // If there are no active nested parsers, we can get the encoded source 158 // location directly. 159 if (state.parserDepth == 0) 160 return state.lex.getEncodedSourceLocation(loc); 161 // Otherwise, we need to re-encode it to point to the top level buffer. 162 return state.symbols.topLevelLexer->getEncodedSourceLocation( 163 remapLocationToTopLevelBuffer(loc)); 164 } 165 166 /// Remaps the given SMLoc to the top level lexer of the parser. This is used 167 /// to adjust locations of potentially nested parsers to ensure that they can 168 /// be emitted properly as diagnostics. 169 llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { 170 // If there are no active nested parsers, we can return location directly. 171 SymbolState &symbols = state.symbols; 172 if (state.parserDepth == 0) 173 return loc; 174 assert(symbols.topLevelLexer && "expected valid top-level lexer"); 175 176 // Otherwise, we need to remap the location to the main parser. This is 177 // simply offseting the location onto the location of the last nested 178 // parser. 179 size_t offset = loc.getPointer() - state.lex.getBufferBegin(); 180 auto *rawLoc = 181 symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; 182 return llvm::SMLoc::getFromPointer(rawLoc); 183 } 184 185 //===--------------------------------------------------------------------===// 186 // Token Parsing 187 //===--------------------------------------------------------------------===// 188 189 /// Return the current token the parser is inspecting. 
190 const Token &getToken() const { return state.curToken; } 191 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 192 193 /// If the current token has the specified kind, consume it and return true. 194 /// If not, return false. 195 bool consumeIf(Token::Kind kind) { 196 if (state.curToken.isNot(kind)) 197 return false; 198 consumeToken(kind); 199 return true; 200 } 201 202 /// Advance the current lexer onto the next token. 203 void consumeToken() { 204 assert(state.curToken.isNot(Token::eof, Token::error) && 205 "shouldn't advance past EOF or errors"); 206 state.curToken = state.lex.lexToken(); 207 } 208 209 /// Advance the current lexer onto the next token, asserting what the expected 210 /// current token is. This is preferred to the above method because it leads 211 /// to more self-documenting code with better checking. 212 void consumeToken(Token::Kind kind) { 213 assert(state.curToken.is(kind) && "consumed an unexpected token"); 214 consumeToken(); 215 } 216 217 /// Consume the specified token if present and return success. On failure, 218 /// output a diagnostic and return failure. 219 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 220 221 //===--------------------------------------------------------------------===// 222 // Type Parsing 223 //===--------------------------------------------------------------------===// 224 225 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 226 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 227 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 228 229 /// Optionally parse a type. 230 OptionalParseResult parseOptionalType(Type &type); 231 232 /// Parse an arbitrary type. 233 Type parseType(); 234 235 /// Parse a complex type. 236 Type parseComplexType(); 237 238 /// Parse an extended type. 239 Type parseExtendedType(); 240 241 /// Parse a function type. 242 Type parseFunctionType(); 243 244 /// Parse a memref type. 
245 Type parseMemRefType(); 246 247 /// Parse a non function type. 248 Type parseNonFunctionType(); 249 250 /// Parse a tensor type. 251 Type parseTensorType(); 252 253 /// Parse a tuple type. 254 Type parseTupleType(); 255 256 /// Parse a vector type. 257 VectorType parseVectorType(); 258 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 259 bool allowDynamic = true); 260 ParseResult parseXInDimensionList(); 261 262 /// Parse strided layout specification. 263 ParseResult parseStridedLayout(int64_t &offset, 264 SmallVectorImpl<int64_t> &strides); 265 266 // Parse a brace-delimiter list of comma-separated integers with `?` as an 267 // unknown marker. 268 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 269 270 //===--------------------------------------------------------------------===// 271 // Attribute Parsing 272 //===--------------------------------------------------------------------===// 273 274 /// Parse an arbitrary attribute with an optional type. 275 Attribute parseAttribute(Type type = {}); 276 277 /// Parse an attribute dictionary. 278 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 279 280 /// Parse an extended attribute. 281 Attribute parseExtendedAttr(Type type); 282 283 /// Parse a float attribute. 284 Attribute parseFloatAttr(Type type, bool isNegative); 285 286 /// Parse a decimal or a hexadecimal literal, which can be either an integer 287 /// or a float attribute. 288 Attribute parseDecOrHexAttr(Type type, bool isNegative); 289 290 /// Parse an opaque elements attribute. 291 Attribute parseOpaqueElementsAttr(Type attrType); 292 293 /// Parse a dense elements attribute. 294 Attribute parseDenseElementsAttr(Type attrType); 295 ShapedType parseElementsLiteralType(Type type); 296 297 /// Parse a sparse elements attribute. 
298 Attribute parseSparseElementsAttr(Type attrType); 299 300 //===--------------------------------------------------------------------===// 301 // Location Parsing 302 //===--------------------------------------------------------------------===// 303 304 /// Parse an inline location. 305 ParseResult parseLocation(LocationAttr &loc); 306 307 /// Parse a raw location instance. 308 ParseResult parseLocationInstance(LocationAttr &loc); 309 310 /// Parse a callsite location instance. 311 ParseResult parseCallSiteLocation(LocationAttr &loc); 312 313 /// Parse a fused location instance. 314 ParseResult parseFusedLocation(LocationAttr &loc); 315 316 /// Parse a name or FileLineCol location instance. 317 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 318 319 /// Parse an optional trailing location. 320 /// 321 /// trailing-location ::= (`loc` `(` location `)`)? 322 /// 323 ParseResult parseOptionalTrailingLocation(Location &loc) { 324 // If there is a 'loc' we parse a trailing location. 325 if (!getToken().is(Token::kw_loc)) 326 return success(); 327 328 // Parse the location. 329 LocationAttr directLoc; 330 if (parseLocation(directLoc)) 331 return failure(); 332 loc = directLoc; 333 return success(); 334 } 335 336 //===--------------------------------------------------------------------===// 337 // Affine Parsing 338 //===--------------------------------------------------------------------===// 339 340 /// Parse a reference to either an affine map, or an integer set. 341 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 342 IntegerSet &set); 343 ParseResult parseAffineMapReference(AffineMap &map); 344 ParseResult parseIntegerSetReference(IntegerSet &set); 345 346 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 
347 ParseResult 348 parseAffineMapOfSSAIds(AffineMap &map, 349 function_ref<ParseResult(bool)> parseElement, 350 OpAsmParser::Delimiter delimiter); 351 352 private: 353 /// The Parser is subclassed and reinstantiated. Do not add additional 354 /// non-trivial state here, add it to the ParserState class. 355 ParserState &state; 356 }; 357 } // end anonymous namespace 358 359 //===----------------------------------------------------------------------===// 360 // Helper methods. 361 //===----------------------------------------------------------------------===// 362 363 /// Parse a comma separated list of elements that must have at least one entry 364 /// in it. 365 ParseResult Parser::parseCommaSeparatedList( 366 const std::function<ParseResult()> &parseElement) { 367 // Non-empty case starts with an element. 368 if (parseElement()) 369 return failure(); 370 371 // Otherwise we have a list of comma separated elements. 372 while (consumeIf(Token::comma)) { 373 if (parseElement()) 374 return failure(); 375 } 376 return success(); 377 } 378 379 /// Parse a comma-separated list of elements, terminated with an arbitrary 380 /// token. This allows empty lists if allowEmptyList is true. 381 /// 382 /// abstract-list ::= rightToken // if allowEmptyList == true 383 /// abstract-list ::= element (',' element)* rightToken 384 /// 385 ParseResult Parser::parseCommaSeparatedListUntil( 386 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 387 bool allowEmptyList) { 388 // Handle the empty case. 
389 if (getToken().is(rightToken)) { 390 if (!allowEmptyList) 391 return emitError("expected list element"); 392 consumeToken(rightToken); 393 return success(); 394 } 395 396 if (parseCommaSeparatedList(parseElement) || 397 parseToken(rightToken, "expected ',' or '" + 398 Token::getTokenSpelling(rightToken) + "'")) 399 return failure(); 400 401 return success(); 402 } 403 404 //===----------------------------------------------------------------------===// 405 // DialectAsmParser 406 //===----------------------------------------------------------------------===// 407 408 namespace { 409 /// This class provides the main implementation of the DialectAsmParser that 410 /// allows for dialects to parse attributes and types. This allows for dialect 411 /// hooking into the main MLIR parsing logic. 412 class CustomDialectAsmParser : public DialectAsmParser { 413 public: 414 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 415 : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), 416 parser(parser) {} 417 ~CustomDialectAsmParser() override {} 418 419 /// Emit a diagnostic at the specified location and return failure. 420 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 421 return parser.emitError(loc, message); 422 } 423 424 /// Return a builder which provides useful access to MLIRContext, global 425 /// objects like types and attributes. 426 Builder &getBuilder() const override { return parser.builder; } 427 428 /// Get the location of the next token and store it into the argument. This 429 /// always succeeds. 430 llvm::SMLoc getCurrentLocation() override { 431 return parser.getToken().getLoc(); 432 } 433 434 /// Return the location of the original name token. 435 llvm::SMLoc getNameLoc() const override { return nameLoc; } 436 437 /// Re-encode the given source location as an MLIR location and return it. 
438 Location getEncodedSourceLoc(llvm::SMLoc loc) override { 439 return parser.getEncodedSourceLocation(loc); 440 } 441 442 /// Returns the full specification of the symbol being parsed. This allows 443 /// for using a separate parser if necessary. 444 StringRef getFullSymbolSpec() const override { return fullSpec; } 445 446 /// Parse a floating point value from the stream. 447 ParseResult parseFloat(double &result) override { 448 bool negative = parser.consumeIf(Token::minus); 449 Token curTok = parser.getToken(); 450 451 // Check for a floating point value. 452 if (curTok.is(Token::floatliteral)) { 453 auto val = curTok.getFloatingPointValue(); 454 if (!val.hasValue()) 455 return emitError(curTok.getLoc(), "floating point value too large"); 456 parser.consumeToken(Token::floatliteral); 457 result = negative ? -*val : *val; 458 return success(); 459 } 460 461 // TODO(riverriddle) support hex floating point values. 462 return emitError(getCurrentLocation(), "expected floating point literal"); 463 } 464 465 /// Parse an optional integer value from the stream. 466 OptionalParseResult parseOptionalInteger(uint64_t &result) override { 467 Token curToken = parser.getToken(); 468 if (curToken.isNot(Token::integer, Token::minus)) 469 return llvm::None; 470 471 bool negative = parser.consumeIf(Token::minus); 472 Token curTok = parser.getToken(); 473 if (parser.parseToken(Token::integer, "expected integer value")) 474 return failure(); 475 476 auto val = curTok.getUInt64IntegerValue(); 477 if (!val) 478 return emitError(curTok.getLoc(), "integer value too large"); 479 result = negative ? -*val : *val; 480 return success(); 481 } 482 483 //===--------------------------------------------------------------------===// 484 // Token Parsing 485 //===--------------------------------------------------------------------===// 486 487 /// Parse a `->` token. 
488 ParseResult parseArrow() override { 489 return parser.parseToken(Token::arrow, "expected '->'"); 490 } 491 492 /// Parses a `->` if present. 493 ParseResult parseOptionalArrow() override { 494 return success(parser.consumeIf(Token::arrow)); 495 } 496 497 /// Parse a '{' token. 498 ParseResult parseLBrace() override { 499 return parser.parseToken(Token::l_brace, "expected '{'"); 500 } 501 502 /// Parse a '{' token if present 503 ParseResult parseOptionalLBrace() override { 504 return success(parser.consumeIf(Token::l_brace)); 505 } 506 507 /// Parse a `}` token. 508 ParseResult parseRBrace() override { 509 return parser.parseToken(Token::r_brace, "expected '}'"); 510 } 511 512 /// Parse a `}` token if present 513 ParseResult parseOptionalRBrace() override { 514 return success(parser.consumeIf(Token::r_brace)); 515 } 516 517 /// Parse a `:` token. 518 ParseResult parseColon() override { 519 return parser.parseToken(Token::colon, "expected ':'"); 520 } 521 522 /// Parse a `:` token if present. 523 ParseResult parseOptionalColon() override { 524 return success(parser.consumeIf(Token::colon)); 525 } 526 527 /// Parse a `,` token. 528 ParseResult parseComma() override { 529 return parser.parseToken(Token::comma, "expected ','"); 530 } 531 532 /// Parse a `,` token if present. 533 ParseResult parseOptionalComma() override { 534 return success(parser.consumeIf(Token::comma)); 535 } 536 537 /// Parses a `...` if present. 538 ParseResult parseOptionalEllipsis() override { 539 return success(parser.consumeIf(Token::ellipsis)); 540 } 541 542 /// Parse a `=` token. 543 ParseResult parseEqual() override { 544 return parser.parseToken(Token::equal, "expected '='"); 545 } 546 547 /// Parse a '<' token. 548 ParseResult parseLess() override { 549 return parser.parseToken(Token::less, "expected '<'"); 550 } 551 552 /// Parse a `<` token if present. 
553 ParseResult parseOptionalLess() override { 554 return success(parser.consumeIf(Token::less)); 555 } 556 557 /// Parse a '>' token. 558 ParseResult parseGreater() override { 559 return parser.parseToken(Token::greater, "expected '>'"); 560 } 561 562 /// Parse a `>` token if present. 563 ParseResult parseOptionalGreater() override { 564 return success(parser.consumeIf(Token::greater)); 565 } 566 567 /// Parse a `(` token. 568 ParseResult parseLParen() override { 569 return parser.parseToken(Token::l_paren, "expected '('"); 570 } 571 572 /// Parses a '(' if present. 573 ParseResult parseOptionalLParen() override { 574 return success(parser.consumeIf(Token::l_paren)); 575 } 576 577 /// Parse a `)` token. 578 ParseResult parseRParen() override { 579 return parser.parseToken(Token::r_paren, "expected ')'"); 580 } 581 582 /// Parses a ')' if present. 583 ParseResult parseOptionalRParen() override { 584 return success(parser.consumeIf(Token::r_paren)); 585 } 586 587 /// Parse a `[` token. 588 ParseResult parseLSquare() override { 589 return parser.parseToken(Token::l_square, "expected '['"); 590 } 591 592 /// Parses a '[' if present. 593 ParseResult parseOptionalLSquare() override { 594 return success(parser.consumeIf(Token::l_square)); 595 } 596 597 /// Parse a `]` token. 598 ParseResult parseRSquare() override { 599 return parser.parseToken(Token::r_square, "expected ']'"); 600 } 601 602 /// Parses a ']' if present. 603 ParseResult parseOptionalRSquare() override { 604 return success(parser.consumeIf(Token::r_square)); 605 } 606 607 /// Parses a '?' if present. 608 ParseResult parseOptionalQuestion() override { 609 return success(parser.consumeIf(Token::question)); 610 } 611 612 /// Parses a '*' if present. 613 ParseResult parseOptionalStar() override { 614 return success(parser.consumeIf(Token::star)); 615 } 616 617 /// Returns if the current token corresponds to a keyword. 
618 bool isCurrentTokenAKeyword() const { 619 return parser.getToken().is(Token::bare_identifier) || 620 parser.getToken().isKeyword(); 621 } 622 623 /// Parse the given keyword if present. 624 ParseResult parseOptionalKeyword(StringRef keyword) override { 625 // Check that the current token has the same spelling. 626 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 627 return failure(); 628 parser.consumeToken(); 629 return success(); 630 } 631 632 /// Parse a keyword, if present, into 'keyword'. 633 ParseResult parseOptionalKeyword(StringRef *keyword) override { 634 // Check that the current token is a keyword. 635 if (!isCurrentTokenAKeyword()) 636 return failure(); 637 638 *keyword = parser.getTokenSpelling(); 639 parser.consumeToken(); 640 return success(); 641 } 642 643 //===--------------------------------------------------------------------===// 644 // Attribute Parsing 645 //===--------------------------------------------------------------------===// 646 647 /// Parse an arbitrary attribute and return it in result. 648 ParseResult parseAttribute(Attribute &result, Type type) override { 649 result = parser.parseAttribute(type); 650 return success(static_cast<bool>(result)); 651 } 652 653 /// Parse an affine map instance into 'map'. 654 ParseResult parseAffineMap(AffineMap &map) override { 655 return parser.parseAffineMapReference(map); 656 } 657 658 /// Parse an integer set instance into 'set'. 
659 ParseResult printIntegerSet(IntegerSet &set) override { 660 return parser.parseIntegerSetReference(set); 661 } 662 663 //===--------------------------------------------------------------------===// 664 // Type Parsing 665 //===--------------------------------------------------------------------===// 666 667 ParseResult parseType(Type &result) override { 668 result = parser.parseType(); 669 return success(static_cast<bool>(result)); 670 } 671 672 ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 673 bool allowDynamic) override { 674 return parser.parseDimensionListRanked(dimensions, allowDynamic); 675 } 676 677 private: 678 /// The full symbol specification. 679 StringRef fullSpec; 680 681 /// The source location of the dialect symbol. 682 SMLoc nameLoc; 683 684 /// The main parser. 685 Parser &parser; 686 }; 687 } // namespace 688 689 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 690 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 691 /// the entire pretty name. 692 /// 693 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 694 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 695 /// | '(' pretty-dialect-sym-contents+ ')' 696 /// | '[' pretty-dialect-sym-contents+ ']' 697 /// | '{' pretty-dialect-sym-contents+ '}' 698 /// | '[^[<({>\])}\0]+' 699 /// 700 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 701 // Pretty symbol names are a relatively unstructured format that contains a 702 // series of properly nested punctuation, with anything else in the middle. 703 // Scan ahead to find it and consume it if successful, otherwise emit an 704 // error. 705 auto *curPtr = getTokenSpelling().data(); 706 707 SmallVector<char, 8> nestedPunctuation; 708 709 // Scan over the nested punctuation, bailing out on error and consuming until 710 // we find the end. 
We know that we're currently looking at the '<', so we 711 // can go until we find the matching '>' character. 712 assert(*curPtr == '<'); 713 do { 714 char c = *curPtr++; 715 switch (c) { 716 case '\0': 717 // This also handles the EOF case. 718 return emitError("unexpected nul or EOF in pretty dialect name"); 719 case '<': 720 case '[': 721 case '(': 722 case '{': 723 nestedPunctuation.push_back(c); 724 continue; 725 726 case '-': 727 // The sequence `->` is treated as special token. 728 if (*curPtr == '>') 729 ++curPtr; 730 continue; 731 732 case '>': 733 if (nestedPunctuation.pop_back_val() != '<') 734 return emitError("unbalanced '>' character in pretty dialect name"); 735 break; 736 case ']': 737 if (nestedPunctuation.pop_back_val() != '[') 738 return emitError("unbalanced ']' character in pretty dialect name"); 739 break; 740 case ')': 741 if (nestedPunctuation.pop_back_val() != '(') 742 return emitError("unbalanced ')' character in pretty dialect name"); 743 break; 744 case '}': 745 if (nestedPunctuation.pop_back_val() != '{') 746 return emitError("unbalanced '}' character in pretty dialect name"); 747 break; 748 749 default: 750 continue; 751 } 752 } while (!nestedPunctuation.empty()); 753 754 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 755 // consuming all this stuff, and return. 756 state.lex.resetPointer(curPtr); 757 758 unsigned length = curPtr - prettyName.begin(); 759 prettyName = StringRef(prettyName.begin(), length); 760 consumeToken(); 761 return success(); 762 } 763 764 /// Parse an extended dialect symbol. 765 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 766 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 767 SymbolAliasMap &aliases, 768 CreateFn &&createSymbol) { 769 // Parse the dialect namespace. 
770 StringRef identifier = p.getTokenSpelling().drop_front(); 771 auto loc = p.getToken().getLoc(); 772 p.consumeToken(identifierTok); 773 774 // If there is no '<' token following this, and if the typename contains no 775 // dot, then we are parsing a symbol alias. 776 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 777 // Check for an alias for this type. 778 auto aliasIt = aliases.find(identifier); 779 if (aliasIt == aliases.end()) 780 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 781 nullptr); 782 return aliasIt->second; 783 } 784 785 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 786 // a dot, then this is the "pretty" form. If not, it is the verbose form that 787 // looks like <"...">. 788 std::string symbolData; 789 auto dialectName = identifier; 790 791 // Handle the verbose form, where "identifier" is a simple dialect name. 792 if (!identifier.contains('.')) { 793 // Consume the '<'. 794 if (p.parseToken(Token::less, "expected '<' in dialect type")) 795 return nullptr; 796 797 // Parse the symbol specific data. 798 if (p.getToken().isNot(Token::string)) 799 return (p.emitError("expected string literal data in dialect symbol"), 800 nullptr); 801 symbolData = p.getToken().getStringValue(); 802 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); 803 p.consumeToken(Token::string); 804 805 // Consume the '>'. 806 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 807 return nullptr; 808 } else { 809 // Ok, the dialect name is the part of the identifier before the dot, the 810 // part after the dot is the dialect's symbol, or the start thereof. 811 auto dotHalves = identifier.split('.'); 812 dialectName = dotHalves.first; 813 auto prettyName = dotHalves.second; 814 loc = llvm::SMLoc::getFromPointer(prettyName.data()); 815 816 // If the dialect's symbol is followed immediately by a <, then lex the body 817 // of it into prettyName. 
818 if (p.getToken().is(Token::less) && 819 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 820 if (p.parsePrettyDialectSymbolName(prettyName)) 821 return nullptr; 822 } 823 824 symbolData = prettyName.str(); 825 } 826 827 // Record the name location of the type remapped to the top level buffer. 828 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); 829 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); 830 831 // Call into the provided symbol construction function. 832 Symbol sym = createSymbol(dialectName, symbolData, loc); 833 834 // Pop the last parser location. 835 p.getState().symbols.nestedParserLocs.pop_back(); 836 return sym; 837 } 838 839 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 840 /// parsing failed, nullptr is returned. The number of bytes read from the input 841 /// string is returned in 'numRead'. 842 template <typename T, typename ParserFn> 843 static T parseSymbol(StringRef inputStr, MLIRContext *context, 844 SymbolState &symbolState, ParserFn &&parserFn, 845 size_t *numRead = nullptr) { 846 SourceMgr sourceMgr; 847 auto memBuffer = MemoryBuffer::getMemBuffer( 848 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 849 /*RequiresNullTerminator=*/false); 850 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 851 ParserState state(sourceMgr, context, symbolState); 852 Parser parser(state); 853 854 Token startTok = parser.getToken(); 855 T symbol = parserFn(parser); 856 if (!symbol) 857 return T(); 858 859 // If 'numRead' is valid, then provide the number of bytes that were read. 860 Token endTok = parser.getToken(); 861 if (numRead) { 862 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 863 startTok.getLoc().getPointer()); 864 865 // Otherwise, ensure that all of the tokens were parsed. 
866 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { 867 parser.emitError(endTok.getLoc(), "encountered unexpected token"); 868 return T(); 869 } 870 return symbol; 871 } 872 873 //===----------------------------------------------------------------------===// 874 // Error Handling 875 //===----------------------------------------------------------------------===// 876 877 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 878 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 879 880 // If we hit a parse error in response to a lexer error, then the lexer 881 // already reported the error. 882 if (getToken().is(Token::error)) 883 diag.abandon(); 884 return diag; 885 } 886 887 //===----------------------------------------------------------------------===// 888 // Token Parsing 889 //===----------------------------------------------------------------------===// 890 891 /// Consume the specified token if present and return success. On failure, 892 /// output a diagnostic and return failure. 893 ParseResult Parser::parseToken(Token::Kind expectedToken, 894 const Twine &message) { 895 if (consumeIf(expectedToken)) 896 return success(); 897 return emitError(message); 898 } 899 900 //===----------------------------------------------------------------------===// 901 // Type Parsing 902 //===----------------------------------------------------------------------===// 903 904 /// Optionally parse a type. 905 OptionalParseResult Parser::parseOptionalType(Type &type) { 906 // There are many different starting tokens for a type, check them here. 
907 switch (getToken().getKind()) { 908 case Token::l_paren: 909 case Token::kw_memref: 910 case Token::kw_tensor: 911 case Token::kw_complex: 912 case Token::kw_tuple: 913 case Token::kw_vector: 914 case Token::inttype: 915 case Token::kw_bf16: 916 case Token::kw_f16: 917 case Token::kw_f32: 918 case Token::kw_f64: 919 case Token::kw_index: 920 case Token::kw_none: 921 case Token::exclamation_identifier: 922 return failure(!(type = parseType())); 923 924 default: 925 return llvm::None; 926 } 927 } 928 929 /// Parse an arbitrary type. 930 /// 931 /// type ::= function-type 932 /// | non-function-type 933 /// 934 Type Parser::parseType() { 935 if (getToken().is(Token::l_paren)) 936 return parseFunctionType(); 937 return parseNonFunctionType(); 938 } 939 940 /// Parse a function result type. 941 /// 942 /// function-result-type ::= type-list-parens 943 /// | non-function-type 944 /// 945 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 946 if (getToken().is(Token::l_paren)) 947 return parseTypeListParens(elements); 948 949 Type t = parseNonFunctionType(); 950 if (!t) 951 return failure(); 952 elements.push_back(t); 953 return success(); 954 } 955 956 /// Parse a list of types without an enclosing parenthesis. The list must have 957 /// at least one member. 958 /// 959 /// type-list-no-parens ::= type (`,` type)* 960 /// 961 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 962 auto parseElt = [&]() -> ParseResult { 963 auto elt = parseType(); 964 elements.push_back(elt); 965 return elt ? success() : failure(); 966 }; 967 968 return parseCommaSeparatedList(parseElt); 969 } 970 971 /// Parse a parenthesized list of types. 972 /// 973 /// type-list-parens ::= `(` `)` 974 /// | `(` type-list-no-parens `)` 975 /// 976 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 977 if (parseToken(Token::l_paren, "expected '('")) 978 return failure(); 979 980 // Handle empty lists. 
981 if (getToken().is(Token::r_paren)) 982 return consumeToken(), success(); 983 984 if (parseTypeListNoParens(elements) || 985 parseToken(Token::r_paren, "expected ')'")) 986 return failure(); 987 return success(); 988 } 989 990 /// Parse a complex type. 991 /// 992 /// complex-type ::= `complex` `<` type `>` 993 /// 994 Type Parser::parseComplexType() { 995 consumeToken(Token::kw_complex); 996 997 // Parse the '<'. 998 if (parseToken(Token::less, "expected '<' in complex type")) 999 return nullptr; 1000 1001 llvm::SMLoc elementTypeLoc = getToken().getLoc(); 1002 auto elementType = parseType(); 1003 if (!elementType || 1004 parseToken(Token::greater, "expected '>' in complex type")) 1005 return nullptr; 1006 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) 1007 return emitError(elementTypeLoc, "invalid element type for complex"), 1008 nullptr; 1009 1010 return ComplexType::get(elementType); 1011 } 1012 1013 /// Parse an extended type. 1014 /// 1015 /// extended-type ::= (dialect-type | type-alias) 1016 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 1017 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 1018 /// type-alias ::= `!` alias-name 1019 /// 1020 Type Parser::parseExtendedType() { 1021 return parseExtendedSymbol<Type>( 1022 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 1023 [&](StringRef dialectName, StringRef symbolData, 1024 llvm::SMLoc loc) -> Type { 1025 // If we found a registered dialect, then ask it to parse the type. 1026 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1027 return parseSymbol<Type>( 1028 symbolData, state.context, state.symbols, [&](Parser &parser) { 1029 CustomDialectAsmParser customParser(symbolData, parser); 1030 return dialect->parseType(customParser); 1031 }); 1032 } 1033 1034 // Otherwise, form a new opaque type. 
1035 return OpaqueType::getChecked( 1036 Identifier::get(dialectName, state.context), symbolData, 1037 state.context, getEncodedSourceLocation(loc)); 1038 }); 1039 } 1040 1041 /// Parse a function type. 1042 /// 1043 /// function-type ::= type-list-parens `->` function-result-type 1044 /// 1045 Type Parser::parseFunctionType() { 1046 assert(getToken().is(Token::l_paren)); 1047 1048 SmallVector<Type, 4> arguments, results; 1049 if (parseTypeListParens(arguments) || 1050 parseToken(Token::arrow, "expected '->' in function type") || 1051 parseFunctionResultTypes(results)) 1052 return nullptr; 1053 1054 return builder.getFunctionType(arguments, results); 1055 } 1056 1057 /// Parse the offset and strides from a strided layout specification. 1058 /// 1059 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1060 /// 1061 ParseResult Parser::parseStridedLayout(int64_t &offset, 1062 SmallVectorImpl<int64_t> &strides) { 1063 // Parse offset. 1064 consumeToken(Token::kw_offset); 1065 if (!consumeIf(Token::colon)) 1066 return emitError("expected colon after `offset` keyword"); 1067 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1068 bool question = getToken().is(Token::question); 1069 if (!maybeOffset && !question) 1070 return emitError("invalid offset"); 1071 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1072 : MemRefType::getDynamicStrideOrOffset(); 1073 consumeToken(); 1074 1075 if (!consumeIf(Token::comma)) 1076 return emitError("expected comma after offset value"); 1077 1078 // Parse stride list. 
1079 if (!consumeIf(Token::kw_strides)) 1080 return emitError("expected `strides` keyword after offset specification"); 1081 if (!consumeIf(Token::colon)) 1082 return emitError("expected colon after `strides` keyword"); 1083 if (failed(parseStrideList(strides))) 1084 return emitError("invalid braces-enclosed stride list"); 1085 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1086 return emitError("invalid memref stride"); 1087 1088 return success(); 1089 } 1090 1091 /// Parse a memref type. 1092 /// 1093 /// memref-type ::= ranked-memref-type | unranked-memref-type 1094 /// 1095 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1096 /// (`,` semi-affine-map-composition)? (`,` 1097 /// memory-space)? `>` 1098 /// 1099 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1100 /// 1101 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1102 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1103 /// 1104 Type Parser::parseMemRefType() { 1105 consumeToken(Token::kw_memref); 1106 1107 if (parseToken(Token::less, "expected '<' in memref type")) 1108 return nullptr; 1109 1110 bool isUnranked; 1111 SmallVector<int64_t, 4> dimensions; 1112 1113 if (consumeIf(Token::star)) { 1114 // This is an unranked memref type. 1115 isUnranked = true; 1116 if (parseXInDimensionList()) 1117 return nullptr; 1118 1119 } else { 1120 isUnranked = false; 1121 if (parseDimensionListRanked(dimensions)) 1122 return nullptr; 1123 } 1124 1125 // Parse the element type. 1126 auto typeLoc = getToken().getLoc(); 1127 auto elementType = parseType(); 1128 if (!elementType) 1129 return nullptr; 1130 1131 // Check that memref is formed from allowed types. 1132 if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && 1133 !elementType.isa<ComplexType>()) 1134 return emitError(typeLoc, "invalid memref element type"), nullptr; 1135 1136 // Parse semi-affine-map-composition. 
1137 SmallVector<AffineMap, 2> affineMapComposition; 1138 Optional<unsigned> memorySpace; 1139 unsigned numDims = dimensions.size(); 1140 1141 auto parseElt = [&]() -> ParseResult { 1142 // Check for the memory space. 1143 if (getToken().is(Token::integer)) { 1144 if (memorySpace) 1145 return emitError("multiple memory spaces specified in memref type"); 1146 memorySpace = getToken().getUnsignedIntegerValue(); 1147 if (!memorySpace.hasValue()) 1148 return emitError("invalid memory space in memref type"); 1149 consumeToken(Token::integer); 1150 return success(); 1151 } 1152 if (isUnranked) 1153 return emitError("cannot have affine map for unranked memref type"); 1154 if (memorySpace) 1155 return emitError("expected memory space to be last in memref type"); 1156 1157 AffineMap map; 1158 llvm::SMLoc mapLoc = getToken().getLoc(); 1159 if (getToken().is(Token::kw_offset)) { 1160 int64_t offset; 1161 SmallVector<int64_t, 4> strides; 1162 if (failed(parseStridedLayout(offset, strides))) 1163 return failure(); 1164 // Construct strided affine map. 1165 map = makeStridedLinearLayoutMap(strides, offset, state.context); 1166 } else { 1167 // Parse an affine map attribute. 1168 auto affineMap = parseAttribute(); 1169 if (!affineMap) 1170 return failure(); 1171 auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>(); 1172 if (!affineMapAttr) 1173 return emitError("expected affine map in memref type"); 1174 map = affineMapAttr.getValue(); 1175 } 1176 1177 if (map.getNumDims() != numDims) { 1178 size_t i = affineMapComposition.size(); 1179 return emitError(mapLoc, "memref affine map dimension mismatch between ") 1180 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 1181 << " and affine map" << i + 1 << ": " << numDims 1182 << " != " << map.getNumDims(); 1183 } 1184 numDims = map.getNumResults(); 1185 affineMapComposition.push_back(map); 1186 return success(); 1187 }; 1188 1189 // Parse a list of mappings and address space if present. 
1190 if (!consumeIf(Token::greater)) { 1191 // Parse comma separated list of affine maps, followed by memory space. 1192 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 1193 parseCommaSeparatedListUntil(Token::greater, parseElt, 1194 /*allowEmptyList=*/false)) { 1195 return nullptr; 1196 } 1197 } 1198 1199 if (isUnranked) 1200 return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); 1201 1202 return MemRefType::get(dimensions, elementType, affineMapComposition, 1203 memorySpace.getValueOr(0)); 1204 } 1205 1206 /// Parse any type except the function type. 1207 /// 1208 /// non-function-type ::= integer-type 1209 /// | index-type 1210 /// | float-type 1211 /// | extended-type 1212 /// | vector-type 1213 /// | tensor-type 1214 /// | memref-type 1215 /// | complex-type 1216 /// | tuple-type 1217 /// | none-type 1218 /// 1219 /// index-type ::= `index` 1220 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1221 /// none-type ::= `none` 1222 /// 1223 Type Parser::parseNonFunctionType() { 1224 switch (getToken().getKind()) { 1225 default: 1226 return (emitError("expected non-function type"), nullptr); 1227 case Token::kw_memref: 1228 return parseMemRefType(); 1229 case Token::kw_tensor: 1230 return parseTensorType(); 1231 case Token::kw_complex: 1232 return parseComplexType(); 1233 case Token::kw_tuple: 1234 return parseTupleType(); 1235 case Token::kw_vector: 1236 return parseVectorType(); 1237 // integer-type 1238 case Token::inttype: { 1239 auto width = getToken().getIntTypeBitwidth(); 1240 if (!width.hasValue()) 1241 return (emitError("invalid integer width"), nullptr); 1242 if (width.getValue() > IntegerType::kMaxWidth) { 1243 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 1244 << IntegerType::kMaxWidth << " bits"; 1245 return nullptr; 1246 } 1247 1248 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 1249 if (Optional<bool> signedness = getToken().getIntTypeSignedness()) 1250 signSemantics 
= *signedness ? IntegerType::Signed : IntegerType::Unsigned; 1251 1252 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1253 consumeToken(Token::inttype); 1254 return IntegerType::getChecked(width.getValue(), signSemantics, loc); 1255 } 1256 1257 // float-type 1258 case Token::kw_bf16: 1259 consumeToken(Token::kw_bf16); 1260 return builder.getBF16Type(); 1261 case Token::kw_f16: 1262 consumeToken(Token::kw_f16); 1263 return builder.getF16Type(); 1264 case Token::kw_f32: 1265 consumeToken(Token::kw_f32); 1266 return builder.getF32Type(); 1267 case Token::kw_f64: 1268 consumeToken(Token::kw_f64); 1269 return builder.getF64Type(); 1270 1271 // index-type 1272 case Token::kw_index: 1273 consumeToken(Token::kw_index); 1274 return builder.getIndexType(); 1275 1276 // none-type 1277 case Token::kw_none: 1278 consumeToken(Token::kw_none); 1279 return builder.getNoneType(); 1280 1281 // extended type 1282 case Token::exclamation_identifier: 1283 return parseExtendedType(); 1284 } 1285 } 1286 1287 /// Parse a tensor type. 1288 /// 1289 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1290 /// dimension-list ::= dimension-list-ranked | `*x` 1291 /// 1292 Type Parser::parseTensorType() { 1293 consumeToken(Token::kw_tensor); 1294 1295 if (parseToken(Token::less, "expected '<' in tensor type")) 1296 return nullptr; 1297 1298 bool isUnranked; 1299 SmallVector<int64_t, 4> dimensions; 1300 1301 if (consumeIf(Token::star)) { 1302 // This is an unranked tensor type. 1303 isUnranked = true; 1304 1305 if (parseXInDimensionList()) 1306 return nullptr; 1307 1308 } else { 1309 isUnranked = false; 1310 if (parseDimensionListRanked(dimensions)) 1311 return nullptr; 1312 } 1313 1314 // Parse the element type. 
1315 auto elementTypeLoc = getToken().getLoc(); 1316 auto elementType = parseType(); 1317 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1318 return nullptr; 1319 if (!TensorType::isValidElementType(elementType)) 1320 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 1321 1322 if (isUnranked) 1323 return UnrankedTensorType::get(elementType); 1324 return RankedTensorType::get(dimensions, elementType); 1325 } 1326 1327 /// Parse a tuple type. 1328 /// 1329 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1330 /// 1331 Type Parser::parseTupleType() { 1332 consumeToken(Token::kw_tuple); 1333 1334 // Parse the '<'. 1335 if (parseToken(Token::less, "expected '<' in tuple type")) 1336 return nullptr; 1337 1338 // Check for an empty tuple by directly parsing '>'. 1339 if (consumeIf(Token::greater)) 1340 return TupleType::get(getContext()); 1341 1342 // Parse the element types and the '>'. 1343 SmallVector<Type, 4> types; 1344 if (parseTypeListNoParens(types) || 1345 parseToken(Token::greater, "expected '>' in tuple type")) 1346 return nullptr; 1347 1348 return TupleType::get(types, getContext()); 1349 } 1350 1351 /// Parse a vector type. 
1352 /// 1353 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1354 /// non-empty-static-dimension-list ::= decimal-literal `x` 1355 /// static-dimension-list 1356 /// static-dimension-list ::= (decimal-literal `x`)* 1357 /// 1358 VectorType Parser::parseVectorType() { 1359 consumeToken(Token::kw_vector); 1360 1361 if (parseToken(Token::less, "expected '<' in vector type")) 1362 return nullptr; 1363 1364 SmallVector<int64_t, 4> dimensions; 1365 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1366 return nullptr; 1367 if (dimensions.empty()) 1368 return (emitError("expected dimension size in vector type"), nullptr); 1369 if (any_of(dimensions, [](int64_t i) { return i <= 0; })) 1370 return emitError(getToken().getLoc(), 1371 "vector types must have positive constant sizes"), 1372 nullptr; 1373 1374 // Parse the element type. 1375 auto typeLoc = getToken().getLoc(); 1376 auto elementType = parseType(); 1377 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1378 return nullptr; 1379 if (!VectorType::isValidElementType(elementType)) 1380 return emitError(typeLoc, "vector elements must be int or float type"), 1381 nullptr; 1382 1383 return VectorType::get(dimensions, elementType); 1384 } 1385 1386 /// Parse a dimension list of a tensor or memref type. This populates the 1387 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1388 /// errors out on `?` otherwise. 
1389 /// 1390 /// dimension-list-ranked ::= (dimension `x`)* 1391 /// dimension ::= `?` | decimal-literal 1392 /// 1393 /// When `allowDynamic` is not set, this is used to parse: 1394 /// 1395 /// static-dimension-list ::= (decimal-literal `x`)* 1396 ParseResult 1397 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1398 bool allowDynamic) { 1399 while (getToken().isAny(Token::integer, Token::question)) { 1400 if (consumeIf(Token::question)) { 1401 if (!allowDynamic) 1402 return emitError("expected static shape"); 1403 dimensions.push_back(-1); 1404 } else { 1405 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1406 // aggregate type declarations. Therefore, `0xf32` should be processed as 1407 // a sequence of separate elements `0`, `x`, `f32`. 1408 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1409 // We can get here only if the token is an integer literal. Hexadecimal 1410 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1411 // literal, just `1` would, at which point we don't get into this 1412 // branch). 1413 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1414 dimensions.push_back(0); 1415 state.lex.resetPointer(getTokenSpelling().data() + 1); 1416 consumeToken(); 1417 } else { 1418 // Make sure this integer value is in bound and valid. 1419 auto dimension = getToken().getUnsignedIntegerValue(); 1420 if (!dimension.hasValue()) 1421 return emitError("invalid dimension"); 1422 dimensions.push_back((int64_t)dimension.getValue()); 1423 consumeToken(Token::integer); 1424 } 1425 } 1426 1427 // Make sure we have an 'x' or something like 'xbf32'. 1428 if (parseXInDimensionList()) 1429 return failure(); 1430 } 1431 1432 return success(); 1433 } 1434 1435 /// Parse an 'x' token in a dimension list, handling the case where the x is 1436 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1437 /// token. 
1438 ParseResult Parser::parseXInDimensionList() { 1439 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1440 return emitError("expected 'x' in dimension list"); 1441 1442 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1443 if (getTokenSpelling().size() != 1) 1444 state.lex.resetPointer(getTokenSpelling().data() + 1); 1445 1446 // Consume the 'x'. 1447 consumeToken(Token::bare_identifier); 1448 1449 return success(); 1450 } 1451 1452 // Parse a comma-separated list of dimensions, possibly empty: 1453 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1454 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1455 if (!consumeIf(Token::l_square)) 1456 return failure(); 1457 // Empty list early exit. 1458 if (consumeIf(Token::r_square)) 1459 return success(); 1460 while (true) { 1461 if (consumeIf(Token::question)) { 1462 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1463 } else { 1464 // This must be an integer value. 1465 int64_t val; 1466 if (getToken().getSpelling().getAsInteger(10, val)) 1467 return emitError("invalid integer value: ") << getToken().getSpelling(); 1468 // Make sure it is not the one value for `?`. 1469 if (ShapedType::isDynamic(val)) 1470 return emitError("invalid integer value: ") 1471 << getToken().getSpelling() 1472 << ", use `?` to specify a dynamic dimension"; 1473 dimensions.push_back(val); 1474 consumeToken(Token::integer); 1475 } 1476 if (!consumeIf(Token::comma)) 1477 break; 1478 } 1479 if (!consumeIf(Token::r_square)) 1480 return failure(); 1481 return success(); 1482 } 1483 1484 //===----------------------------------------------------------------------===// 1485 // Attribute parsing. 1486 //===----------------------------------------------------------------------===// 1487 1488 /// Return the symbol reference referred to by the given token, that is known to 1489 /// be an @-identifier. 
1490 static std::string extractSymbolReference(Token tok) { 1491 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1492 StringRef nameStr = tok.getSpelling().drop_front(); 1493 1494 // Check to see if the reference is a string literal, or a bare identifier. 1495 if (nameStr.front() == '"') 1496 return tok.getStringValue(); 1497 return std::string(nameStr); 1498 } 1499 1500 /// Parse an arbitrary attribute. 1501 /// 1502 /// attribute-value ::= `unit` 1503 /// | bool-literal 1504 /// | integer-literal (`:` (index-type | integer-type))? 1505 /// | float-literal (`:` float-type)? 1506 /// | string-literal (`:` type)? 1507 /// | type 1508 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1509 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1510 /// | symbol-ref-id (`::` symbol-ref-id)* 1511 /// | `dense` `<` attribute-value `>` `:` 1512 /// (tensor-type | vector-type) 1513 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1514 /// `:` (tensor-type | vector-type) 1515 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1516 /// `>` `:` (tensor-type | vector-type) 1517 /// | extended-attribute 1518 /// 1519 Attribute Parser::parseAttribute(Type type) { 1520 switch (getToken().getKind()) { 1521 // Parse an AffineMap or IntegerSet attribute. 
1522 case Token::kw_affine_map: { 1523 consumeToken(Token::kw_affine_map); 1524 1525 AffineMap map; 1526 if (parseToken(Token::less, "expected '<' in affine map") || 1527 parseAffineMapReference(map) || 1528 parseToken(Token::greater, "expected '>' in affine map")) 1529 return Attribute(); 1530 return AffineMapAttr::get(map); 1531 } 1532 case Token::kw_affine_set: { 1533 consumeToken(Token::kw_affine_set); 1534 1535 IntegerSet set; 1536 if (parseToken(Token::less, "expected '<' in integer set") || 1537 parseIntegerSetReference(set) || 1538 parseToken(Token::greater, "expected '>' in integer set")) 1539 return Attribute(); 1540 return IntegerSetAttr::get(set); 1541 } 1542 1543 // Parse an array attribute. 1544 case Token::l_square: { 1545 consumeToken(Token::l_square); 1546 1547 SmallVector<Attribute, 4> elements; 1548 auto parseElt = [&]() -> ParseResult { 1549 elements.push_back(parseAttribute()); 1550 return elements.back() ? success() : failure(); 1551 }; 1552 1553 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1554 return nullptr; 1555 return builder.getArrayAttr(elements); 1556 } 1557 1558 // Parse a boolean attribute. 1559 case Token::kw_false: 1560 consumeToken(Token::kw_false); 1561 return builder.getBoolAttr(false); 1562 case Token::kw_true: 1563 consumeToken(Token::kw_true); 1564 return builder.getBoolAttr(true); 1565 1566 // Parse a dense elements attribute. 1567 case Token::kw_dense: 1568 return parseDenseElementsAttr(type); 1569 1570 // Parse a dictionary attribute. 1571 case Token::l_brace: { 1572 SmallVector<NamedAttribute, 4> elements; 1573 if (parseAttributeDict(elements)) 1574 return nullptr; 1575 return builder.getDictionaryAttr(elements); 1576 } 1577 1578 // Parse an extended attribute, i.e. alias or dialect attribute. 1579 case Token::hash_identifier: 1580 return parseExtendedAttr(type); 1581 1582 // Parse floating point and integer attributes. 
1583 case Token::floatliteral: 1584 return parseFloatAttr(type, /*isNegative=*/false); 1585 case Token::integer: 1586 return parseDecOrHexAttr(type, /*isNegative=*/false); 1587 case Token::minus: { 1588 consumeToken(Token::minus); 1589 if (getToken().is(Token::integer)) 1590 return parseDecOrHexAttr(type, /*isNegative=*/true); 1591 if (getToken().is(Token::floatliteral)) 1592 return parseFloatAttr(type, /*isNegative=*/true); 1593 1594 return (emitError("expected constant integer or floating point value"), 1595 nullptr); 1596 } 1597 1598 // Parse a location attribute. 1599 case Token::kw_loc: { 1600 LocationAttr attr; 1601 return failed(parseLocation(attr)) ? Attribute() : attr; 1602 } 1603 1604 // Parse an opaque elements attribute. 1605 case Token::kw_opaque: 1606 return parseOpaqueElementsAttr(type); 1607 1608 // Parse a sparse elements attribute. 1609 case Token::kw_sparse: 1610 return parseSparseElementsAttr(type); 1611 1612 // Parse a string attribute. 1613 case Token::string: { 1614 auto val = getToken().getStringValue(); 1615 consumeToken(Token::string); 1616 // Parse the optional trailing colon type if one wasn't explicitly provided. 1617 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1618 return Attribute(); 1619 1620 return type ? StringAttr::get(val, type) 1621 : StringAttr::get(val, getContext()); 1622 } 1623 1624 // Parse a symbol reference attribute. 1625 case Token::at_identifier: { 1626 std::string nameStr = extractSymbolReference(getToken()); 1627 consumeToken(Token::at_identifier); 1628 1629 // Parse any nested references. 1630 std::vector<FlatSymbolRefAttr> nestedRefs; 1631 while (getToken().is(Token::colon)) { 1632 // Check for the '::' prefix. 1633 const char *curPointer = getToken().getLoc().getPointer(); 1634 consumeToken(Token::colon); 1635 if (!consumeIf(Token::colon)) { 1636 state.lex.resetPointer(curPointer); 1637 consumeToken(); 1638 break; 1639 } 1640 // Parse the reference itself. 
1641 auto curLoc = getToken().getLoc(); 1642 if (getToken().isNot(Token::at_identifier)) { 1643 emitError(curLoc, "expected nested symbol reference identifier"); 1644 return Attribute(); 1645 } 1646 1647 std::string nameStr = extractSymbolReference(getToken()); 1648 consumeToken(Token::at_identifier); 1649 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1650 } 1651 1652 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1653 } 1654 1655 // Parse a 'unit' attribute. 1656 case Token::kw_unit: 1657 consumeToken(Token::kw_unit); 1658 return builder.getUnitAttr(); 1659 1660 default: 1661 // Parse a type attribute. 1662 if (Type type = parseType()) 1663 return TypeAttr::get(type); 1664 return nullptr; 1665 } 1666 } 1667 1668 /// Attribute dictionary. 1669 /// 1670 /// attribute-dict ::= `{` `}` 1671 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1672 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 1673 /// 1674 ParseResult 1675 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1676 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1677 return failure(); 1678 1679 llvm::SmallDenseSet<Identifier> seenKeys; 1680 auto parseElt = [&]() -> ParseResult { 1681 // The name of an attribute can either be a bare identifier, or a string. 1682 Optional<Identifier> nameId; 1683 if (getToken().is(Token::string)) 1684 nameId = builder.getIdentifier(getToken().getStringValue()); 1685 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 1686 getToken().isKeyword()) 1687 nameId = builder.getIdentifier(getTokenSpelling()); 1688 else 1689 return emitError("expected attribute name"); 1690 if (!seenKeys.insert(*nameId).second) 1691 return emitError("duplicate key in dictionary attribute"); 1692 consumeToken(); 1693 1694 // Try to parse the '=' for the attribute value. 1695 if (!consumeIf(Token::equal)) { 1696 // If there is no '=', we treat this as a unit attribute. 
1697 attributes.push_back({*nameId, builder.getUnitAttr()}); 1698 return success(); 1699 } 1700 1701 auto attr = parseAttribute(); 1702 if (!attr) 1703 return failure(); 1704 1705 attributes.push_back({*nameId, attr}); 1706 return success(); 1707 }; 1708 1709 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1710 return failure(); 1711 1712 return success(); 1713 } 1714 1715 /// Parse an extended attribute. 1716 /// 1717 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1718 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1719 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1720 /// attribute-alias ::= `#` alias-name 1721 /// 1722 Attribute Parser::parseExtendedAttr(Type type) { 1723 Attribute attr = parseExtendedSymbol<Attribute>( 1724 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1725 [&](StringRef dialectName, StringRef symbolData, 1726 llvm::SMLoc loc) -> Attribute { 1727 // Parse an optional trailing colon type. 1728 Type attrType = type; 1729 if (consumeIf(Token::colon) && !(attrType = parseType())) 1730 return Attribute(); 1731 1732 // If we found a registered dialect, then ask it to parse the attribute. 1733 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1734 return parseSymbol<Attribute>( 1735 symbolData, state.context, state.symbols, [&](Parser &parser) { 1736 CustomDialectAsmParser customParser(symbolData, parser); 1737 return dialect->parseAttribute(customParser, attrType); 1738 }); 1739 } 1740 1741 // Otherwise, form a new opaque attribute. 1742 return OpaqueAttr::getChecked( 1743 Identifier::get(dialectName, state.context), symbolData, 1744 attrType ? attrType : NoneType::get(state.context), 1745 getEncodedSourceLocation(loc)); 1746 }); 1747 1748 // Ensure that the attribute has the same type as requested. 
1749 if (attr && type && attr.getType() != type) { 1750 emitError("attribute type different than expected: expected ") 1751 << type << ", but got " << attr.getType(); 1752 return nullptr; 1753 } 1754 return attr; 1755 } 1756 1757 /// Parse a float attribute. 1758 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1759 auto val = getToken().getFloatingPointValue(); 1760 if (!val.hasValue()) 1761 return (emitError("floating point value too large for attribute"), nullptr); 1762 consumeToken(Token::floatliteral); 1763 if (!type) { 1764 // Default to F64 when no type is specified. 1765 if (!consumeIf(Token::colon)) 1766 type = builder.getF64Type(); 1767 else if (!(type = parseType())) 1768 return nullptr; 1769 } 1770 if (!type.isa<FloatType>()) 1771 return (emitError("floating point value not valid for specified type"), 1772 nullptr); 1773 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1774 } 1775 1776 /// Construct a float attribute bitwise equivalent to the integer literal. 1777 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1778 uint64_t value) { 1779 // FIXME: bfloat is currently stored as a double internally because it doesn't 1780 // have valid APFloat semantics. 1781 if (type.isF64() || type.isBF16()) 1782 return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1783 1784 APInt apInt(type.getWidth(), value); 1785 if (apInt != value) { 1786 p->emitError("hexadecimal float constant out of range for type"); 1787 return llvm::None; 1788 } 1789 return APFloat(type.getFloatSemantics(), apInt); 1790 } 1791 1792 /// Construct an APint from a parsed value, a known attribute type and 1793 /// sign. 1794 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 1795 StringRef spelling) { 1796 // Parse the integer value into an APInt that is big enough to hold the value. 
1797 APInt result; 1798 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1799 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 1800 return llvm::None; 1801 1802 // Extend or truncate the bitwidth to the right size. 1803 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth 1804 : type.getIntOrFloatBitWidth(); 1805 if (width > result.getBitWidth()) { 1806 result = result.zext(width); 1807 } else if (width < result.getBitWidth()) { 1808 // The parser can return an unnecessarily wide result with leading zeros. 1809 // This isn't a problem, but truncating off bits is bad. 1810 if (result.countLeadingZeros() < result.getBitWidth() - width) 1811 return llvm::None; 1812 1813 result = result.trunc(width); 1814 } 1815 1816 if (isNegative) { 1817 // The value is negative, we have an overflow if the sign bit is not set 1818 // in the negated apInt. 1819 result.negate(); 1820 if (!result.isSignBitSet()) 1821 return llvm::None; 1822 } else if ((type.isSignedInteger() || type.isIndex()) && 1823 result.isSignBitSet()) { 1824 // The value is a positive signed integer or index, 1825 // we have an overflow if the sign bit is set. 1826 return llvm::None; 1827 } 1828 1829 return result; 1830 } 1831 1832 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1833 /// or a float attribute. 1834 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1835 // Remember if the literal is hexadecimal. 1836 StringRef spelling = getToken().getSpelling(); 1837 auto loc = state.curToken.getLoc(); 1838 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1839 1840 consumeToken(Token::integer); 1841 if (!type) { 1842 // Default to i64 if not type is specified. 
1843 if (!consumeIf(Token::colon)) 1844 type = builder.getIntegerType(64); 1845 else if (!(type = parseType())) 1846 return nullptr; 1847 } 1848 1849 if (auto floatType = type.dyn_cast<FloatType>()) { 1850 if (isNegative) 1851 return emitError( 1852 loc, 1853 "hexadecimal float literal should not have a leading minus"), 1854 nullptr; 1855 if (!isHex) { 1856 emitError(loc, "unexpected decimal integer literal for a float attribute") 1857 .attachNote() 1858 << "add a trailing dot to make the literal a float"; 1859 return nullptr; 1860 } 1861 1862 auto val = Token::getUInt64IntegerValue(spelling); 1863 if (!val.hasValue()) 1864 return emitError("integer constant out of range for attribute"), nullptr; 1865 1866 // Construct a float attribute bitwise equivalent to the integer literal. 1867 Optional<APFloat> apVal = 1868 buildHexadecimalFloatLiteral(this, floatType, *val); 1869 return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); 1870 } 1871 1872 if (!type.isa<IntegerType>() && !type.isa<IndexType>()) 1873 return emitError(loc, "integer literal not valid for specified type"), 1874 nullptr; 1875 1876 if (isNegative && type.isUnsignedInteger()) { 1877 emitError(loc, 1878 "negative integer literal not valid for unsigned integer type"); 1879 return nullptr; 1880 } 1881 1882 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 1883 if (!apInt) 1884 return emitError(loc, "integer constant out of range for attribute"), 1885 nullptr; 1886 return builder.getIntegerAttr(type, *apInt); 1887 } 1888 1889 /// Parse elements values stored within a hex etring. On success, the values are 1890 /// stored into 'result'. 
1891 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 1892 std::string &result) { 1893 std::string val = tok.getStringValue(); 1894 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1895 return parser.emitError(tok.getLoc(), 1896 "elements hex string should start with '0x'"); 1897 1898 StringRef hexValues = StringRef(val).drop_front(2); 1899 if (!llvm::all_of(hexValues, llvm::isHexDigit)) 1900 return parser.emitError(tok.getLoc(), 1901 "elements hex string only contains hex digits"); 1902 1903 result = llvm::fromHex(hexValues); 1904 return success(); 1905 } 1906 1907 /// Parse an opaque elements attribute. 1908 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1909 consumeToken(Token::kw_opaque); 1910 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1911 return nullptr; 1912 1913 if (getToken().isNot(Token::string)) 1914 return (emitError("expected dialect namespace"), nullptr); 1915 1916 auto name = getToken().getStringValue(); 1917 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1918 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1919 // attribute. Otherwise, it can't be roundtripped without having the dialect 1920 // registered. 
1921 if (!dialect) 1922 return (emitError("no registered dialect with namespace '" + name + "'"), 1923 nullptr); 1924 consumeToken(Token::string); 1925 1926 if (parseToken(Token::comma, "expected ','")) 1927 return nullptr; 1928 1929 Token hexTok = getToken(); 1930 if (parseToken(Token::string, "elements hex string should start with '0x'") || 1931 parseToken(Token::greater, "expected '>'")) 1932 return nullptr; 1933 auto type = parseElementsLiteralType(attrType); 1934 if (!type) 1935 return nullptr; 1936 1937 std::string data; 1938 if (parseElementAttrHexValues(*this, hexTok, data)) 1939 return nullptr; 1940 return OpaqueElementsAttr::get(dialect, type, data); 1941 } 1942 1943 namespace { 1944 class TensorLiteralParser { 1945 public: 1946 TensorLiteralParser(Parser &p) : p(p) {} 1947 1948 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1949 /// may also parse a tensor literal that is store as a hex string. 1950 ParseResult parse(bool allowHex); 1951 1952 /// Build a dense attribute instance with the parsed elements and the given 1953 /// shaped type. 1954 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1955 1956 ArrayRef<int64_t> getShape() const { return shape; } 1957 1958 private: 1959 /// Get the parsed elements for an integer attribute. 1960 ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy, 1961 std::vector<APInt> &intValues); 1962 1963 /// Get the parsed elements for a float attribute. 1964 ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, 1965 std::vector<APFloat> &floatValues); 1966 1967 /// Build a Dense String attribute for the given type. 1968 DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); 1969 1970 /// Build a Dense attribute with hex data for the given type. 1971 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); 1972 1973 /// Parse a single element, returning failure if it isn't a valid element 1974 /// literal. 
For example: 1975 /// parseElement(1) -> Success, 1 1976 /// parseElement([1]) -> Failure 1977 ParseResult parseElement(); 1978 1979 /// Parse a list of either lists or elements, returning the dimensions of the 1980 /// parsed sub-tensors in dims. For example: 1981 /// parseList([1, 2, 3]) -> Success, [3] 1982 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1983 /// parseList([[1, 2], 3]) -> Failure 1984 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1985 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1986 1987 /// Parse a literal that was printed as a hex string. 1988 ParseResult parseHexElements(); 1989 1990 Parser &p; 1991 1992 /// The shape inferred from the parsed elements. 1993 SmallVector<int64_t, 4> shape; 1994 1995 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1996 std::vector<std::pair<bool, Token>> storage; 1997 1998 /// Storage used when parsing elements that were stored as hex values. 1999 Optional<Token> hexStorage; 2000 }; 2001 } // namespace 2002 2003 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 2004 /// may also parse a tensor literal that is store as a hex string. 2005 ParseResult TensorLiteralParser::parse(bool allowHex) { 2006 // If hex is allowed, check for a string literal. 2007 if (allowHex && p.getToken().is(Token::string)) { 2008 hexStorage = p.getToken(); 2009 p.consumeToken(Token::string); 2010 return success(); 2011 } 2012 // Otherwise, parse a list or an individual element. 2013 if (p.getToken().is(Token::l_square)) 2014 return parseList(shape); 2015 return parseElement(); 2016 } 2017 2018 /// Build a dense attribute instance with the parsed elements and the given 2019 /// shaped type. 2020 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 2021 ShapedType type) { 2022 Type eltType = type.getElementType(); 2023 2024 // Check to see if we parse the literal from a hex string. 
2025 if (hexStorage.hasValue() && 2026 (eltType.isIntOrFloat() || eltType.isa<ComplexType>())) 2027 return getHexAttr(loc, type); 2028 2029 // Check that the parsed storage size has the same number of elements to the 2030 // type, or is a known splat. 2031 if (!shape.empty() && getShape() != type.getShape()) { 2032 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 2033 << "]) does not match type ([" << type.getShape() << "])"; 2034 return nullptr; 2035 } 2036 2037 // Handle complex types in the specific element type cases below. 2038 bool isComplex = false; 2039 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) { 2040 eltType = complexTy.getElementType(); 2041 isComplex = true; 2042 } 2043 2044 // Handle integer and index types. 2045 if (eltType.isIntOrIndex()) { 2046 std::vector<APInt> intValues; 2047 if (failed(getIntAttrElements(loc, eltType, intValues))) 2048 return nullptr; 2049 if (isComplex) { 2050 // If this is a complex, treat the parsed values as complex values. 2051 auto complexData = llvm::makeArrayRef( 2052 reinterpret_cast<std::complex<APInt> *>(intValues.data()), 2053 intValues.size() / 2); 2054 return DenseElementsAttr::get(type, complexData); 2055 } 2056 return DenseElementsAttr::get(type, intValues); 2057 } 2058 // Handle floating point types. 2059 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) { 2060 std::vector<APFloat> floatValues; 2061 if (failed(getFloatAttrElements(loc, floatTy, floatValues))) 2062 return nullptr; 2063 if (isComplex) { 2064 // If this is a complex, treat the parsed values as complex values. 2065 auto complexData = llvm::makeArrayRef( 2066 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), 2067 floatValues.size() / 2); 2068 return DenseElementsAttr::get(type, complexData); 2069 } 2070 return DenseElementsAttr::get(type, floatValues); 2071 } 2072 2073 // Other types are assumed to be string representations. 
2074 return getStringAttr(loc, type, type.getElementType()); 2075 } 2076 2077 /// Build a Dense Integer attribute for the given type. 2078 ParseResult 2079 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy, 2080 std::vector<APInt> &intValues) { 2081 intValues.reserve(storage.size()); 2082 bool isUintType = eltTy.isUnsignedInteger(); 2083 for (const auto &signAndToken : storage) { 2084 bool isNegative = signAndToken.first; 2085 const Token &token = signAndToken.second; 2086 auto tokenLoc = token.getLoc(); 2087 2088 if (isNegative && isUintType) { 2089 return p.emitError(tokenLoc) 2090 << "expected unsigned integer elements, but parsed negative value"; 2091 } 2092 2093 // Check to see if floating point values were parsed. 2094 if (token.is(Token::floatliteral)) { 2095 return p.emitError(tokenLoc) 2096 << "expected integer elements, but parsed floating-point"; 2097 } 2098 2099 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 2100 "unexpected token type"); 2101 if (token.isAny(Token::kw_true, Token::kw_false)) { 2102 if (!eltTy.isInteger(1)) { 2103 return p.emitError(tokenLoc) 2104 << "expected i1 type for 'true' or 'false' values"; 2105 } 2106 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); 2107 intValues.push_back(apInt); 2108 continue; 2109 } 2110 2111 // Create APInt values for each element with the correct bitwidth. 2112 Optional<APInt> apInt = 2113 buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 2114 if (!apInt) 2115 return p.emitError(tokenLoc, "integer constant out of range for type"); 2116 intValues.push_back(*apInt); 2117 } 2118 return success(); 2119 } 2120 2121 /// Build a Dense Float attribute for the given type. 
2122 ParseResult 2123 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, 2124 std::vector<APFloat> &floatValues) { 2125 floatValues.reserve(storage.size()); 2126 for (const auto &signAndToken : storage) { 2127 bool isNegative = signAndToken.first; 2128 const Token &token = signAndToken.second; 2129 2130 // Handle hexadecimal float literals. 2131 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 2132 if (isNegative) { 2133 return p.emitError(token.getLoc()) 2134 << "hexadecimal float literal should not have a leading minus"; 2135 } 2136 auto val = token.getUInt64IntegerValue(); 2137 if (!val.hasValue()) { 2138 return p.emitError( 2139 "hexadecimal float constant out of range for attribute"); 2140 } 2141 Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); 2142 if (!apVal) 2143 return failure(); 2144 floatValues.push_back(*apVal); 2145 continue; 2146 } 2147 2148 // Check to see if any decimal integers or booleans were parsed. 2149 if (!token.is(Token::floatliteral)) 2150 return p.emitError() 2151 << "expected floating-point elements, but parsed integer"; 2152 2153 // Build the float values from tokens. 2154 auto val = token.getFloatingPointValue(); 2155 if (!val.hasValue()) 2156 return p.emitError("floating point value too large for attribute"); 2157 2158 // Treat BF16 as double because it is not supported in LLVM's APFloat. 2159 APFloat apVal(isNegative ? -*val : *val); 2160 if (!eltTy.isBF16() && !eltTy.isF64()) { 2161 bool unused; 2162 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 2163 &unused); 2164 } 2165 floatValues.push_back(apVal); 2166 } 2167 return success(); 2168 } 2169 2170 /// Build a Dense String attribute for the given type. 
2171 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc, 2172 ShapedType type, 2173 Type eltTy) { 2174 if (hexStorage.hasValue()) { 2175 auto stringValue = hexStorage.getValue().getStringValue(); 2176 return DenseStringElementsAttr::get(type, {stringValue}); 2177 } 2178 2179 std::vector<std::string> stringValues; 2180 std::vector<StringRef> stringRefValues; 2181 stringValues.reserve(storage.size()); 2182 stringRefValues.reserve(storage.size()); 2183 2184 for (auto val : storage) { 2185 stringValues.push_back(val.second.getStringValue()); 2186 stringRefValues.push_back(stringValues.back()); 2187 } 2188 2189 return DenseStringElementsAttr::get(type, stringRefValues); 2190 } 2191 2192 /// Build a Dense attribute with hex data for the given type. 2193 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, 2194 ShapedType type) { 2195 Type elementType = type.getElementType(); 2196 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) { 2197 p.emitError(loc) 2198 << "expected floating-point, integer, or complex element type, got " 2199 << elementType; 2200 return nullptr; 2201 } 2202 2203 std::string data; 2204 if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) 2205 return nullptr; 2206 2207 ArrayRef<char> rawData(data.data(), data.size()); 2208 bool detectedSplat = false; 2209 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { 2210 p.emitError(loc) << "elements hex data size is invalid for provided type: " 2211 << type; 2212 return nullptr; 2213 } 2214 2215 return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); 2216 } 2217 2218 ParseResult TensorLiteralParser::parseElement() { 2219 switch (p.getToken().getKind()) { 2220 // Parse a boolean element. 
2221 case Token::kw_true: 2222 case Token::kw_false: 2223 case Token::floatliteral: 2224 case Token::integer: 2225 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2226 p.consumeToken(); 2227 break; 2228 2229 // Parse a signed integer or a negative floating-point element. 2230 case Token::minus: 2231 p.consumeToken(Token::minus); 2232 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2233 return p.emitError("expected integer or floating point literal"); 2234 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2235 p.consumeToken(); 2236 break; 2237 2238 case Token::string: 2239 storage.emplace_back(/*isNegative=*/ false, p.getToken()); 2240 p.consumeToken(); 2241 break; 2242 2243 // Parse a complex element of the form '(' element ',' element ')'. 2244 case Token::l_paren: 2245 p.consumeToken(Token::l_paren); 2246 if (parseElement() || 2247 p.parseToken(Token::comma, "expected ',' between complex elements") || 2248 parseElement() || 2249 p.parseToken(Token::r_paren, "expected ')' after complex elements")) 2250 return failure(); 2251 break; 2252 2253 default: 2254 return p.emitError("expected element literal of primitive type"); 2255 } 2256 2257 return success(); 2258 } 2259 2260 /// Parse a list of either lists or elements, returning the dimensions of the 2261 /// parsed sub-tensors in dims. 
For example: 2262 /// parseList([1, 2, 3]) -> Success, [3] 2263 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2264 /// parseList([[1, 2], 3]) -> Failure 2265 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2266 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2267 p.consumeToken(Token::l_square); 2268 2269 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2270 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2271 if (prevDims == newDims) 2272 return success(); 2273 return p.emitError("tensor literal is invalid; ranks are not consistent " 2274 "between elements"); 2275 }; 2276 2277 bool first = true; 2278 SmallVector<int64_t, 4> newDims; 2279 unsigned size = 0; 2280 auto parseCommaSeparatedList = [&]() -> ParseResult { 2281 SmallVector<int64_t, 4> thisDims; 2282 if (p.getToken().getKind() == Token::l_square) { 2283 if (parseList(thisDims)) 2284 return failure(); 2285 } else if (parseElement()) { 2286 return failure(); 2287 } 2288 ++size; 2289 if (!first) 2290 return checkDims(newDims, thisDims); 2291 newDims = thisDims; 2292 first = false; 2293 return success(); 2294 }; 2295 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2296 return failure(); 2297 2298 // Return the sublists' dimensions with 'size' prepended. 2299 dims.clear(); 2300 dims.push_back(size); 2301 dims.append(newDims.begin(), newDims.end()); 2302 return success(); 2303 } 2304 2305 /// Parse a dense elements attribute. 2306 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2307 consumeToken(Token::kw_dense); 2308 if (parseToken(Token::less, "expected '<' after 'dense'")) 2309 return nullptr; 2310 2311 // Parse the literal data. 
2312 TensorLiteralParser literalParser(*this); 2313 if (literalParser.parse(/*allowHex=*/true)) 2314 return nullptr; 2315 2316 if (parseToken(Token::greater, "expected '>'")) 2317 return nullptr; 2318 2319 auto typeLoc = getToken().getLoc(); 2320 auto type = parseElementsLiteralType(attrType); 2321 if (!type) 2322 return nullptr; 2323 return literalParser.getAttr(typeLoc, type); 2324 } 2325 2326 /// Shaped type for elements attribute. 2327 /// 2328 /// elements-literal-type ::= vector-type | ranked-tensor-type 2329 /// 2330 /// This method also checks the type has static shape. 2331 ShapedType Parser::parseElementsLiteralType(Type type) { 2332 // If the user didn't provide a type, parse the colon type for the literal. 2333 if (!type) { 2334 if (parseToken(Token::colon, "expected ':'")) 2335 return nullptr; 2336 if (!(type = parseType())) 2337 return nullptr; 2338 } 2339 2340 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2341 emitError("elements literal must be a ranked tensor or vector type"); 2342 return nullptr; 2343 } 2344 2345 auto sType = type.cast<ShapedType>(); 2346 if (!sType.hasStaticShape()) 2347 return (emitError("elements literal type must have static shape"), nullptr); 2348 2349 return sType; 2350 } 2351 2352 /// Parse a sparse elements attribute. 2353 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2354 consumeToken(Token::kw_sparse); 2355 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2356 return nullptr; 2357 2358 /// Parse the indices. We don't allow hex values here as we may need to use 2359 /// the inferred shape. 2360 auto indicesLoc = getToken().getLoc(); 2361 TensorLiteralParser indiceParser(*this); 2362 if (indiceParser.parse(/*allowHex=*/false)) 2363 return nullptr; 2364 2365 if (parseToken(Token::comma, "expected ','")) 2366 return nullptr; 2367 2368 /// Parse the values. 
2369 auto valuesLoc = getToken().getLoc(); 2370 TensorLiteralParser valuesParser(*this); 2371 if (valuesParser.parse(/*allowHex=*/true)) 2372 return nullptr; 2373 2374 if (parseToken(Token::greater, "expected '>'")) 2375 return nullptr; 2376 2377 auto type = parseElementsLiteralType(attrType); 2378 if (!type) 2379 return nullptr; 2380 2381 // If the indices are a splat, i.e. the literal parser parsed an element and 2382 // not a list, we set the shape explicitly. The indices are represented by a 2383 // 2-dimensional shape where the second dimension is the rank of the type. 2384 // Given that the parsed indices is a splat, we know that we only have one 2385 // indice and thus one for the first dimension. 2386 auto indiceEltType = builder.getIntegerType(64); 2387 ShapedType indicesType; 2388 if (indiceParser.getShape().empty()) { 2389 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2390 } else { 2391 // Otherwise, set the shape to the one parsed by the literal parser. 2392 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2393 } 2394 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2395 2396 // If the values are a splat, set the shape explicitly based on the number of 2397 // indices. The number of indices is encoded in the first dimension of the 2398 // indice shape type. 2399 auto valuesEltType = type.getElementType(); 2400 ShapedType valuesType = 2401 valuesParser.getShape().empty() 2402 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2403 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2404 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2405 2406 /// Sanity check. 
2407 if (valuesType.getRank() != 1) 2408 return (emitError("expected 1-d tensor for values"), nullptr); 2409 2410 auto sameShape = (indicesType.getRank() == 1) || 2411 (type.getRank() == indicesType.getDimSize(1)); 2412 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2413 if (!sameShape || !sameElementNum) { 2414 emitError() << "expected shape ([" << type.getShape() 2415 << "]); inferred shape of indices literal ([" 2416 << indicesType.getShape() 2417 << "]); inferred shape of values literal ([" 2418 << valuesType.getShape() << "])"; 2419 return nullptr; 2420 } 2421 2422 // Build the sparse elements attribute by the indices and values. 2423 return SparseElementsAttr::get(type, indices, values); 2424 } 2425 2426 //===----------------------------------------------------------------------===// 2427 // Location parsing. 2428 //===----------------------------------------------------------------------===// 2429 2430 /// Parse a location. 2431 /// 2432 /// location ::= `loc` inline-location 2433 /// inline-location ::= '(' location-inst ')' 2434 /// 2435 ParseResult Parser::parseLocation(LocationAttr &loc) { 2436 // Check for 'loc' identifier. 2437 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2438 return emitError(); 2439 2440 // Parse the inline-location. 2441 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2442 parseLocationInstance(loc) || 2443 parseToken(Token::r_paren, "expected ')' in inline location")) 2444 return failure(); 2445 return success(); 2446 } 2447 2448 /// Specific location instances. 
2449 /// 2450 /// location-inst ::= filelinecol-location | 2451 /// name-location | 2452 /// callsite-location | 2453 /// fused-location | 2454 /// unknown-location 2455 /// filelinecol-location ::= string-literal ':' integer-literal 2456 /// ':' integer-literal 2457 /// name-location ::= string-literal 2458 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2459 /// fused-location ::= fused ('<' attribute-value '>')? 2460 /// '[' location-inst (location-inst ',')* ']' 2461 /// unknown-location ::= 'unknown' 2462 /// 2463 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2464 consumeToken(Token::bare_identifier); 2465 2466 // Parse the '('. 2467 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2468 return failure(); 2469 2470 // Parse the callee location. 2471 LocationAttr calleeLoc; 2472 if (parseLocationInstance(calleeLoc)) 2473 return failure(); 2474 2475 // Parse the 'at'. 2476 if (getToken().isNot(Token::bare_identifier) || 2477 getToken().getSpelling() != "at") 2478 return emitError("expected 'at' in callsite location"); 2479 consumeToken(Token::bare_identifier); 2480 2481 // Parse the caller location. 2482 LocationAttr callerLoc; 2483 if (parseLocationInstance(callerLoc)) 2484 return failure(); 2485 2486 // Parse the ')'. 2487 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2488 return failure(); 2489 2490 // Return the callsite location. 2491 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2492 return success(); 2493 } 2494 2495 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2496 consumeToken(Token::bare_identifier); 2497 2498 // Try to parse the optional metadata. 2499 Attribute metadata; 2500 if (consumeIf(Token::less)) { 2501 metadata = parseAttribute(); 2502 if (!metadata) 2503 return emitError("expected valid attribute metadata"); 2504 // Parse the '>' token. 
2505 if (parseToken(Token::greater, 2506 "expected '>' after fused location metadata")) 2507 return failure(); 2508 } 2509 2510 SmallVector<Location, 4> locations; 2511 auto parseElt = [&] { 2512 LocationAttr newLoc; 2513 if (parseLocationInstance(newLoc)) 2514 return failure(); 2515 locations.push_back(newLoc); 2516 return success(); 2517 }; 2518 2519 if (parseToken(Token::l_square, "expected '[' in fused location") || 2520 parseCommaSeparatedList(parseElt) || 2521 parseToken(Token::r_square, "expected ']' in fused location")) 2522 return failure(); 2523 2524 // Return the fused location. 2525 loc = FusedLoc::get(locations, metadata, getContext()); 2526 return success(); 2527 } 2528 2529 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2530 auto *ctx = getContext(); 2531 auto str = getToken().getStringValue(); 2532 consumeToken(Token::string); 2533 2534 // If the next token is ':' this is a filelinecol location. 2535 if (consumeIf(Token::colon)) { 2536 // Parse the line number. 2537 if (getToken().isNot(Token::integer)) 2538 return emitError("expected integer line number in FileLineColLoc"); 2539 auto line = getToken().getUnsignedIntegerValue(); 2540 if (!line.hasValue()) 2541 return emitError("expected integer line number in FileLineColLoc"); 2542 consumeToken(Token::integer); 2543 2544 // Parse the ':'. 2545 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2546 return failure(); 2547 2548 // Parse the column number. 2549 if (getToken().isNot(Token::integer)) 2550 return emitError("expected integer column number in FileLineColLoc"); 2551 auto column = getToken().getUnsignedIntegerValue(); 2552 if (!column.hasValue()) 2553 return emitError("expected integer column number in FileLineColLoc"); 2554 consumeToken(Token::integer); 2555 2556 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2557 return success(); 2558 } 2559 2560 // Otherwise, this is a NameLoc. 2561 2562 // Check for a child location. 
2563 if (consumeIf(Token::l_paren)) { 2564 auto childSourceLoc = getToken().getLoc(); 2565 2566 // Parse the child location. 2567 LocationAttr childLoc; 2568 if (parseLocationInstance(childLoc)) 2569 return failure(); 2570 2571 // The child must not be another NameLoc. 2572 if (childLoc.isa<NameLoc>()) 2573 return emitError(childSourceLoc, 2574 "child of NameLoc cannot be another NameLoc"); 2575 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2576 2577 // Parse the closing ')'. 2578 if (parseToken(Token::r_paren, 2579 "expected ')' after child location of NameLoc")) 2580 return failure(); 2581 } else { 2582 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2583 } 2584 2585 return success(); 2586 } 2587 2588 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2589 // Handle either name or filelinecol locations. 2590 if (getToken().is(Token::string)) 2591 return parseNameOrFileLineColLocation(loc); 2592 2593 // Bare tokens required for other cases. 2594 if (!getToken().is(Token::bare_identifier)) 2595 return emitError("expected location instance"); 2596 2597 // Check for the 'callsite' signifying a callsite location. 2598 if (getToken().getSpelling() == "callsite") 2599 return parseCallSiteLocation(loc); 2600 2601 // If the token is 'fused', then this is a fused location. 2602 if (getToken().getSpelling() == "fused") 2603 return parseFusedLocation(loc); 2604 2605 // Check for a 'unknown' for an unknown location. 2606 if (getToken().getSpelling() == "unknown") { 2607 consumeToken(Token::bare_identifier); 2608 loc = UnknownLoc::get(getContext()); 2609 return success(); 2610 } 2611 2612 return emitError("expected location instance"); 2613 } 2614 2615 //===----------------------------------------------------------------------===// 2616 // Affine parsing. 2617 //===----------------------------------------------------------------------===// 2618 2619 /// Lower precedence ops (all at the same precedence level). 
/// LNoOp is false in the boolean sense.
enum AffineLowPrecOp {
  /// Null value.
  LNoOp,
  Add,
  Sub
};

/// Higher precedence ops - all at the same precedence level. HNoOp is false
/// in the boolean sense.
enum AffineHighPrecOp {
  /// Null value.
  HNoOp,
  Mul,
  FloorDiv,
  CeilDiv,
  Mod
};

namespace {
/// This is a specialized parser for affine structures (affine maps, affine
/// expressions, and integer sets), maintaining the state transient to their
/// bodies.
class AffineParser : public Parser {
public:
  /// Construct an affine parser on top of the shared parser state. SSA id
  /// parsing is disabled unless 'allowParsingSSAIds' is set; 'parseElement' is
  /// an optional callback stored for later use.
  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
               function_ref<ParseResult(bool)> parseElement = nullptr)
      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
        parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}

  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
  ParseResult parseAffineMapOfSSAIds(AffineMap &map,
                                     OpAsmParser::Delimiter delimiter);
  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
                              unsigned &numDims);

private:
  // Binary affine op parsing.
  AffineLowPrecOp consumeIfLowPrecOp();
  AffineHighPrecOp consumeIfHighPrecOp();

  // Identifier lists for polyhedral structures.
  ParseResult parseDimIdList(unsigned &numDims);
  ParseResult parseSymbolIdList(unsigned &numSymbols);
  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
                                              unsigned &numSymbols);
  ParseResult parseIdentifierDefinition(AffineExpr idExpr);

  // Affine expression parsing.
  AffineExpr parseAffineExpr();
  AffineExpr parseParentheticalExpr();
  AffineExpr parseNegateExpression(AffineExpr lhs);
  AffineExpr parseIntegerExpr();
  AffineExpr parseBareIdExpr();
  AffineExpr parseSSAIdExpr(bool isSymbol);
  AffineExpr parseSymbolSSAIdExpr();

  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
                                   AffineExpr rhs, SMLoc opLoc);
  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
                                   AffineExpr rhs);
  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
                                       SMLoc llhsOpLoc);
  AffineExpr parseAffineConstraint(bool *isEq);

private:
  // Configuration captured at construction time.
  bool allowParsingSSAIds;
  function_ref<ParseResult(bool)> parseElement;
  // Operand counters, both initialized to zero by the constructor.
  unsigned numDimOperands;
  unsigned numSymbolOperands;
  // Pairs of an identifier name and the affine expression associated with it.
  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
};
} // end anonymous namespace

/// Create an affine binary high precedence op expression (mul's, div's, mod).
/// opLoc is the location of the op token to be used to report errors
/// for non-conforming expressions. Returns a null expression after emitting a
/// diagnostic if the operands would make the expression non-affine.
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
                                               AffineExpr lhs, AffineExpr rhs,
                                               SMLoc opLoc) {
  // TODO: make the error location info accurate.
  switch (op) {
  case Mul:
    // A product is only affine if at least one side is symbolic or constant.
    if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
      emitError(opLoc, "non-affine expression: at least one of the multiply "
                       "operands has to be either a constant or symbolic");
      return nullptr;
    }
    return lhs * rhs;
  case FloorDiv:
    // For the division and modulo forms only the right operand is restricted.
    if (!rhs.isSymbolicOrConstant()) {
      emitError(opLoc, "non-affine expression: right operand of floordiv "
                       "has to be either a constant or symbolic");
      return nullptr;
    }
    return lhs.floorDiv(rhs);
  case CeilDiv:
    if (!rhs.isSymbolicOrConstant()) {
      emitError(opLoc, "non-affine expression: right operand of ceildiv "
                       "has to be either a constant or symbolic");
      return nullptr;
    }
    return lhs.ceilDiv(rhs);
  case Mod:
    if (!rhs.isSymbolicOrConstant()) {
      emitError(opLoc, "non-affine expression: right operand of mod "
                       "has to be either a constant or symbolic");
      return nullptr;
    }
    return lhs % rhs;
  case HNoOp:
    // Callers must never ask for an expression built from the null op.
    llvm_unreachable("can't create affine expression for null high prec op");
    return nullptr;
  }
  llvm_unreachable("Unknown AffineHighPrecOp");
}

/// Create an affine binary low precedence op expression (add, sub).
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
                                               AffineExpr lhs, AffineExpr rhs) {
  switch (op) {
  case AffineLowPrecOp::Add:
    return lhs + rhs;
  case AffineLowPrecOp::Sub:
    return lhs - rhs;
  case AffineLowPrecOp::LNoOp:
    // Callers must never ask for an expression built from the null op.
    llvm_unreachable("can't create affine expression for null low prec op");
    return nullptr;
  }
  llvm_unreachable("Unknown AffineLowPrecOp");
}

/// Consume this token if it is a lower precedence affine op (there are only
/// two precedence levels).
2757 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2758 switch (getToken().getKind()) { 2759 case Token::plus: 2760 consumeToken(Token::plus); 2761 return AffineLowPrecOp::Add; 2762 case Token::minus: 2763 consumeToken(Token::minus); 2764 return AffineLowPrecOp::Sub; 2765 default: 2766 return AffineLowPrecOp::LNoOp; 2767 } 2768 } 2769 2770 /// Consume this token if it is a higher precedence affine op (there are only 2771 /// two precedence levels) 2772 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2773 switch (getToken().getKind()) { 2774 case Token::star: 2775 consumeToken(Token::star); 2776 return Mul; 2777 case Token::kw_floordiv: 2778 consumeToken(Token::kw_floordiv); 2779 return FloorDiv; 2780 case Token::kw_ceildiv: 2781 consumeToken(Token::kw_ceildiv); 2782 return CeilDiv; 2783 case Token::kw_mod: 2784 consumeToken(Token::kw_mod); 2785 return Mod; 2786 default: 2787 return HNoOp; 2788 } 2789 } 2790 2791 /// Parse a high precedence op expression list: mul, div, and mod are high 2792 /// precedence binary ops, i.e., parse a 2793 /// expr_1 op_1 expr_2 op_2 ... expr_n 2794 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2795 /// All affine binary ops are left associative. 2796 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2797 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2798 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2799 /// report an error for non-conforming expressions. 2800 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2801 AffineHighPrecOp llhsOp, 2802 SMLoc llhsOpLoc) { 2803 AffineExpr lhs = parseAffineOperandExpr(llhs); 2804 if (!lhs) 2805 return nullptr; 2806 2807 // Found an LHS. Parse the remaining expression. 
2808 auto opLoc = getToken().getLoc(); 2809 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2810 if (llhs) { 2811 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2812 if (!expr) 2813 return nullptr; 2814 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2815 } 2816 // No LLHS, get RHS 2817 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2818 } 2819 2820 // This is the last operand in this expression. 2821 if (llhs) 2822 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2823 2824 // No llhs, 'lhs' itself is the expression. 2825 return lhs; 2826 } 2827 2828 /// Parse an affine expression inside parentheses. 2829 /// 2830 /// affine-expr ::= `(` affine-expr `)` 2831 AffineExpr AffineParser::parseParentheticalExpr() { 2832 if (parseToken(Token::l_paren, "expected '('")) 2833 return nullptr; 2834 if (getToken().is(Token::r_paren)) 2835 return (emitError("no expression inside parentheses"), nullptr); 2836 2837 auto expr = parseAffineExpr(); 2838 if (!expr) 2839 return nullptr; 2840 if (parseToken(Token::r_paren, "expected ')'")) 2841 return nullptr; 2842 2843 return expr; 2844 } 2845 2846 /// Parse the negation expression. 2847 /// 2848 /// affine-expr ::= `-` affine-expr 2849 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2850 if (parseToken(Token::minus, "expected '-'")) 2851 return nullptr; 2852 2853 AffineExpr operand = parseAffineOperandExpr(lhs); 2854 // Since negation has the highest precedence of all ops (including high 2855 // precedence ops) but lower than parentheses, we are only going to use 2856 // parseAffineOperandExpr instead of parseAffineExpr here. 2857 if (!operand) 2858 // Extra error message although parseAffineOperandExpr would have 2859 // complained. Leads to a better diagnostic. 2860 return (emitError("missing operand of negation"), nullptr); 2861 return (-1) * operand; 2862 } 2863 2864 /// Parse a bare id that may appear in an affine expression. 
2865 /// 2866 /// affine-expr ::= bare-id 2867 AffineExpr AffineParser::parseBareIdExpr() { 2868 if (getToken().isNot(Token::bare_identifier)) 2869 return (emitError("expected bare identifier"), nullptr); 2870 2871 StringRef sRef = getTokenSpelling(); 2872 for (auto entry : dimsAndSymbols) { 2873 if (entry.first == sRef) { 2874 consumeToken(Token::bare_identifier); 2875 return entry.second; 2876 } 2877 } 2878 2879 return (emitError("use of undeclared identifier"), nullptr); 2880 } 2881 2882 /// Parse an SSA id which may appear in an affine expression. 2883 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2884 if (!allowParsingSSAIds) 2885 return (emitError("unexpected ssa identifier"), nullptr); 2886 if (getToken().isNot(Token::percent_identifier)) 2887 return (emitError("expected ssa identifier"), nullptr); 2888 auto name = getTokenSpelling(); 2889 // Check if we already parsed this SSA id. 2890 for (auto entry : dimsAndSymbols) { 2891 if (entry.first == name) { 2892 consumeToken(Token::percent_identifier); 2893 return entry.second; 2894 } 2895 } 2896 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2897 if (parseElement(isSymbol)) 2898 return (emitError("failed to parse ssa identifier"), nullptr); 2899 auto idExpr = isSymbol 2900 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2901 : getAffineDimExpr(numDimOperands++, getContext()); 2902 dimsAndSymbols.push_back({name, idExpr}); 2903 return idExpr; 2904 } 2905 2906 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2907 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2908 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2909 return nullptr; 2910 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2911 if (!symbolExpr) 2912 return nullptr; 2913 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2914 return nullptr; 2915 return symbolExpr; 2916 } 2917 2918 /// Parse a positive integral constant appearing in an affine expression. 2919 /// 2920 /// affine-expr ::= integer-literal 2921 AffineExpr AffineParser::parseIntegerExpr() { 2922 auto val = getToken().getUInt64IntegerValue(); 2923 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2924 return (emitError("constant too large for index"), nullptr); 2925 2926 consumeToken(Token::integer); 2927 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2928 } 2929 2930 /// Parses an expression that can be a valid operand of an affine expression. 2931 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2932 /// operator, the rhs of which is being parsed. This is used to determine 2933 /// whether an error should be emitted for a missing right operand. 2934 // Eg: for an expression without parentheses (like i + j + k + l), each 2935 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2936 // operand expression, it's an op expression and will be parsed via 2937 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2938 // -l are valid operands that will be parsed by this function. 
2939 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2940 switch (getToken().getKind()) { 2941 case Token::bare_identifier: 2942 return parseBareIdExpr(); 2943 case Token::kw_symbol: 2944 return parseSymbolSSAIdExpr(); 2945 case Token::percent_identifier: 2946 return parseSSAIdExpr(/*isSymbol=*/false); 2947 case Token::integer: 2948 return parseIntegerExpr(); 2949 case Token::l_paren: 2950 return parseParentheticalExpr(); 2951 case Token::minus: 2952 return parseNegateExpression(lhs); 2953 case Token::kw_ceildiv: 2954 case Token::kw_floordiv: 2955 case Token::kw_mod: 2956 case Token::plus: 2957 case Token::star: 2958 if (lhs) 2959 emitError("missing right operand of binary operator"); 2960 else 2961 emitError("missing left operand of binary operator"); 2962 return nullptr; 2963 default: 2964 if (lhs) 2965 emitError("missing right operand of binary operator"); 2966 else 2967 emitError("expected affine expression"); 2968 return nullptr; 2969 } 2970 } 2971 2972 /// Parse affine expressions that are bare-id's, integer constants, 2973 /// parenthetical affine expressions, and affine op expressions that are a 2974 /// composition of those. 2975 /// 2976 /// All binary op's associate from left to right. 2977 /// 2978 /// {add, sub} have lower precedence than {mul, div, and mod}. 2979 /// 2980 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2981 /// ceildiv, and mod are at the same higher precedence level. Negation has 2982 /// higher precedence than any binary op. 2983 /// 2984 /// llhs: the affine expression appearing on the left of the one being parsed. 2985 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2986 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2987 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2988 /// associativity. 
2989 /// 2990 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2991 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2992 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2993 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2994 AffineLowPrecOp llhsOp) { 2995 AffineExpr lhs; 2996 if (!(lhs = parseAffineOperandExpr(llhs))) 2997 return nullptr; 2998 2999 // Found an LHS. Deal with the ops. 3000 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 3001 if (llhs) { 3002 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 3003 return parseAffineLowPrecOpExpr(sum, lOp); 3004 } 3005 // No LLHS, get RHS and form the expression. 3006 return parseAffineLowPrecOpExpr(lhs, lOp); 3007 } 3008 auto opLoc = getToken().getLoc(); 3009 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 3010 // We have a higher precedence op here. Get the rhs operand for the llhs 3011 // through parseAffineHighPrecOpExpr. 3012 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 3013 if (!highRes) 3014 return nullptr; 3015 3016 // If llhs is null, the product forms the first operand of the yet to be 3017 // found expression. If non-null, the op to associate with llhs is llhsOp. 3018 AffineExpr expr = 3019 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 3020 3021 // Recurse for subsequent low prec op's after the affine high prec op 3022 // expression. 3023 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 3024 return parseAffineLowPrecOpExpr(expr, nextOp); 3025 return expr; 3026 } 3027 // Last operand in the expression list. 3028 if (llhs) 3029 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 3030 // No llhs, 'lhs' itself is the expression. 3031 return lhs; 3032 } 3033 3034 /// Parse an affine expression. 
3035 /// affine-expr ::= `(` affine-expr `)` 3036 /// | `-` affine-expr 3037 /// | affine-expr `+` affine-expr 3038 /// | affine-expr `-` affine-expr 3039 /// | affine-expr `*` affine-expr 3040 /// | affine-expr `floordiv` affine-expr 3041 /// | affine-expr `ceildiv` affine-expr 3042 /// | affine-expr `mod` affine-expr 3043 /// | bare-id 3044 /// | integer-literal 3045 /// 3046 /// Additional conditions are checked depending on the production. For eg., 3047 /// one of the operands for `*` has to be either constant/symbolic; the second 3048 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 3049 AffineExpr AffineParser::parseAffineExpr() { 3050 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 3051 } 3052 3053 /// Parse a dim or symbol from the lists appearing before the actual 3054 /// expressions of the affine map. Update our state to store the 3055 /// dimensional/symbolic identifier. 3056 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 3057 if (getToken().isNot(Token::bare_identifier)) 3058 return emitError("expected bare identifier"); 3059 3060 auto name = getTokenSpelling(); 3061 for (auto entry : dimsAndSymbols) { 3062 if (entry.first == name) 3063 return emitError("redefinition of identifier '" + name + "'"); 3064 } 3065 consumeToken(Token::bare_identifier); 3066 3067 dimsAndSymbols.push_back({name, idExpr}); 3068 return success(); 3069 } 3070 3071 /// Parse the list of dimensional identifiers to an affine map. 
3072 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 3073 if (parseToken(Token::l_paren, 3074 "expected '(' at start of dimensional identifiers list")) { 3075 return failure(); 3076 } 3077 3078 auto parseElt = [&]() -> ParseResult { 3079 auto dimension = getAffineDimExpr(numDims++, getContext()); 3080 return parseIdentifierDefinition(dimension); 3081 }; 3082 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 3083 } 3084 3085 /// Parse the list of symbolic identifiers to an affine map. 3086 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 3087 consumeToken(Token::l_square); 3088 auto parseElt = [&]() -> ParseResult { 3089 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 3090 return parseIdentifierDefinition(symbol); 3091 }; 3092 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 3093 } 3094 3095 /// Parse the list of symbolic identifiers to an affine map. 3096 ParseResult 3097 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 3098 unsigned &numSymbols) { 3099 if (parseDimIdList(numDims)) { 3100 return failure(); 3101 } 3102 if (!getToken().is(Token::l_square)) { 3103 numSymbols = 0; 3104 return success(); 3105 } 3106 return parseSymbolIdList(numSymbols); 3107 } 3108 3109 /// Parses an ambiguous affine map or integer set definition inline. 3110 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 3111 IntegerSet &set) { 3112 unsigned numDims = 0, numSymbols = 0; 3113 3114 // List of dimensional and optional symbol identifiers. 3115 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 3116 return failure(); 3117 } 3118 3119 // This is needed for parsing attributes as we wouldn't know whether we would 3120 // be parsing an integer set attribute or an affine map attribute. 
3121 bool isArrow = getToken().is(Token::arrow); 3122 bool isColon = getToken().is(Token::colon); 3123 if (!isArrow && !isColon) { 3124 return emitError("expected '->' or ':'"); 3125 } else if (isArrow) { 3126 parseToken(Token::arrow, "expected '->' or '['"); 3127 map = parseAffineMapRange(numDims, numSymbols); 3128 return map ? success() : failure(); 3129 } else if (parseToken(Token::colon, "expected ':' or '['")) { 3130 return failure(); 3131 } 3132 3133 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 3134 return success(); 3135 3136 return failure(); 3137 } 3138 3139 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 3140 ParseResult 3141 AffineParser::parseAffineMapOfSSAIds(AffineMap &map, 3142 OpAsmParser::Delimiter delimiter) { 3143 Token::Kind rightToken; 3144 switch (delimiter) { 3145 case OpAsmParser::Delimiter::Square: 3146 if (parseToken(Token::l_square, "expected '['")) 3147 return failure(); 3148 rightToken = Token::r_square; 3149 break; 3150 case OpAsmParser::Delimiter::Paren: 3151 if (parseToken(Token::l_paren, "expected '('")) 3152 return failure(); 3153 rightToken = Token::r_paren; 3154 break; 3155 default: 3156 return emitError("unexpected delimiter"); 3157 } 3158 3159 SmallVector<AffineExpr, 4> exprs; 3160 auto parseElt = [&]() -> ParseResult { 3161 auto elt = parseAffineExpr(); 3162 exprs.push_back(elt); 3163 return elt ? success() : failure(); 3164 }; 3165 3166 // Parse a multi-dimensional affine expression (a comma-separated list of 3167 // 1-d affine expressions); the list can be empty. Grammar: 3168 // multi-dim-affine-expr ::= `(` `)` 3169 // | `(` affine-expr (`,` affine-expr)* `)` 3170 if (parseCommaSeparatedListUntil(rightToken, parseElt, 3171 /*allowEmptyList=*/true)) 3172 return failure(); 3173 // Parsed a valid affine map. 
3174 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3175 exprs, getContext()); 3176 return success(); 3177 } 3178 3179 /// Parse the range and sizes affine map definition inline. 3180 /// 3181 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 3182 /// 3183 /// multi-dim-affine-expr ::= `(` `)` 3184 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 3185 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 3186 unsigned numSymbols) { 3187 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 3188 3189 SmallVector<AffineExpr, 4> exprs; 3190 auto parseElt = [&]() -> ParseResult { 3191 auto elt = parseAffineExpr(); 3192 ParseResult res = elt ? success() : failure(); 3193 exprs.push_back(elt); 3194 return res; 3195 }; 3196 3197 // Parse a multi-dimensional affine expression (a comma-separated list of 3198 // 1-d affine expressions). Grammar: 3199 // multi-dim-affine-expr ::= `(` `)` 3200 // | `(` affine-expr (`,` affine-expr)* `)` 3201 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3202 return AffineMap(); 3203 3204 // Parsed a valid affine map. 3205 return AffineMap::get(numDims, numSymbols, exprs, getContext()); 3206 } 3207 3208 /// Parse an affine constraint. 3209 /// affine-constraint ::= affine-expr `>=` `0` 3210 /// | affine-expr `==` `0` 3211 /// 3212 /// isEq is set to true if the parsed constraint is an equality, false if it 3213 /// is an inequality (greater than or equal). 
3214 /// 3215 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 3216 AffineExpr expr = parseAffineExpr(); 3217 if (!expr) 3218 return nullptr; 3219 3220 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 3221 getToken().is(Token::integer)) { 3222 auto dim = getToken().getUnsignedIntegerValue(); 3223 if (dim.hasValue() && dim.getValue() == 0) { 3224 consumeToken(Token::integer); 3225 *isEq = false; 3226 return expr; 3227 } 3228 return (emitError("expected '0' after '>='"), nullptr); 3229 } 3230 3231 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3232 getToken().is(Token::integer)) { 3233 auto dim = getToken().getUnsignedIntegerValue(); 3234 if (dim.hasValue() && dim.getValue() == 0) { 3235 consumeToken(Token::integer); 3236 *isEq = true; 3237 return expr; 3238 } 3239 return (emitError("expected '0' after '=='"), nullptr); 3240 } 3241 3242 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3243 nullptr); 3244 } 3245 3246 /// Parse the constraints that are part of an integer set definition. 3247 /// integer-set-inline 3248 /// ::= dim-and-symbol-id-lists `:` 3249 /// '(' affine-constraint-conjunction? ')' 3250 /// affine-constraint-conjunction ::= affine-constraint (`,` 3251 /// affine-constraint)* 3252 /// 3253 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3254 unsigned numSymbols) { 3255 if (parseToken(Token::l_paren, 3256 "expected '(' at start of integer set constraint list")) 3257 return IntegerSet(); 3258 3259 SmallVector<AffineExpr, 4> constraints; 3260 SmallVector<bool, 4> isEqs; 3261 auto parseElt = [&]() -> ParseResult { 3262 bool isEq; 3263 auto elt = parseAffineConstraint(&isEq); 3264 ParseResult res = elt ? success() : failure(); 3265 if (elt) { 3266 constraints.push_back(elt); 3267 isEqs.push_back(isEq); 3268 } 3269 return res; 3270 }; 3271 3272 // Parse a list of affine constraints (comma-separated). 
3273 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3274 return IntegerSet(); 3275 3276 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3277 if (constraints.empty()) { 3278 /* 0 == 0 */ 3279 auto zero = getAffineConstantExpr(0, getContext()); 3280 return IntegerSet::get(numDims, numSymbols, zero, true); 3281 } 3282 3283 // Parsed a valid integer set. 3284 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3285 } 3286 3287 /// Parse an ambiguous reference to either and affine map or an integer set. 3288 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3289 IntegerSet &set) { 3290 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3291 } 3292 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3293 llvm::SMLoc curLoc = getToken().getLoc(); 3294 IntegerSet set; 3295 if (parseAffineMapOrIntegerSetReference(map, set)) 3296 return failure(); 3297 if (set) 3298 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3299 return success(); 3300 } 3301 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3302 llvm::SMLoc curLoc = getToken().getLoc(); 3303 AffineMap map; 3304 if (parseAffineMapOrIntegerSetReference(map, set)) 3305 return failure(); 3306 if (map) 3307 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3308 return success(); 3309 } 3310 3311 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3312 /// parse SSA value uses encountered while parsing affine expressions. 
3313 ParseResult 3314 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3315 function_ref<ParseResult(bool)> parseElement, 3316 OpAsmParser::Delimiter delimiter) { 3317 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3318 .parseAffineMapOfSSAIds(map, delimiter); 3319 } 3320 3321 //===----------------------------------------------------------------------===// 3322 // OperationParser 3323 //===----------------------------------------------------------------------===// 3324 3325 namespace { 3326 /// This class provides support for parsing operations and regions of 3327 /// operations. 3328 class OperationParser : public Parser { 3329 public: 3330 OperationParser(ParserState &state, ModuleOp moduleOp) 3331 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3332 } 3333 3334 ~OperationParser(); 3335 3336 /// After parsing is finished, this function must be called to see if there 3337 /// are any remaining issues. 3338 ParseResult finalize(); 3339 3340 //===--------------------------------------------------------------------===// 3341 // SSA Value Handling 3342 //===--------------------------------------------------------------------===// 3343 3344 /// This represents a use of an SSA value in the program. The first two 3345 /// entries in the tuple are the name and result number of a reference. The 3346 /// third is the location of the reference, which is used in case this ends 3347 /// up being a use of an undefined value. 3348 struct SSAUseInfo { 3349 StringRef name; // Value name, e.g. %42 or %abc 3350 unsigned number; // Number, specified with #12 3351 SMLoc loc; // Location of first definition or use. 3352 }; 3353 3354 /// Push a new SSA name scope to the parser. 3355 void pushSSANameScope(bool isIsolated); 3356 3357 /// Pop the last SSA name scope from the parser. 3358 ParseResult popSSANameScope(); 3359 3360 /// Register a definition of a value with the symbol table. 
3361 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3362 3363 /// Parse an optional list of SSA uses into 'results'. 3364 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3365 3366 /// Parse a single SSA use into 'result'. 3367 ParseResult parseSSAUse(SSAUseInfo &result); 3368 3369 /// Given a reference to an SSA value and its type, return a reference. This 3370 /// returns null on failure. 3371 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3372 3373 ParseResult parseSSADefOrUseAndType( 3374 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3375 3376 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3377 3378 /// Return the location of the value identified by its name and number if it 3379 /// has been already reference. 3380 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3381 auto &values = isolatedNameScopes.back().values; 3382 if (!values.count(name) || number >= values[name].size()) 3383 return {}; 3384 if (values[name][number].first) 3385 return values[name][number].second; 3386 return {}; 3387 } 3388 3389 //===--------------------------------------------------------------------===// 3390 // Operation Parsing 3391 //===--------------------------------------------------------------------===// 3392 3393 /// Parse an operation instance. 3394 ParseResult parseOperation(); 3395 3396 /// Parse a single operation successor. 3397 ParseResult parseSuccessor(Block *&dest); 3398 3399 /// Parse a comma-separated list of operation successors in brackets. 3400 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations); 3401 3402 /// Parse an operation instance that is in the generic form. 3403 Operation *parseGenericOperation(); 3404 3405 /// Parse an operation instance that is in the generic form and insert it at 3406 /// the provided insertion point. 
3407 Operation *parseGenericOperation(Block *insertBlock, 3408 Block::iterator insertPt); 3409 3410 /// This is the structure of a result specifier in the assembly syntax, 3411 /// including the name, number of results, and location. 3412 typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord; 3413 3414 /// Parse an operation instance that is in the op-defined custom form. 3415 /// resultInfo specifies information about the "%name =" specifiers. 3416 Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo); 3417 3418 //===--------------------------------------------------------------------===// 3419 // Region Parsing 3420 //===--------------------------------------------------------------------===// 3421 3422 /// Parse a region into 'region' with the provided entry block arguments. 3423 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3424 /// isolated from those above. 3425 ParseResult parseRegion(Region ®ion, 3426 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3427 bool isIsolatedNameScope = false); 3428 3429 /// Parse a region body into 'region'. 3430 ParseResult parseRegionBody(Region ®ion); 3431 3432 //===--------------------------------------------------------------------===// 3433 // Block Parsing 3434 //===--------------------------------------------------------------------===// 3435 3436 /// Parse a new block into 'block'. 3437 ParseResult parseBlock(Block *&block); 3438 3439 /// Parse a list of operations into 'block'. 3440 ParseResult parseBlockBody(Block *block); 3441 3442 /// Parse a (possibly empty) list of block arguments. 3443 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3444 Block *owner); 3445 3446 /// Get the block with the specified name, creating it if it doesn't 3447 /// already exist. The location specified is the point of use, which allows 3448 /// us to diagnose references to blocks that are not defined precisely. 
3449 Block *getBlockNamed(StringRef name, SMLoc loc); 3450 3451 /// Define the block with the specified name. Returns the Block* or nullptr in 3452 /// the case of redefinition. 3453 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3454 3455 private: 3456 /// Returns the info for a block at the current scope for the given name. 3457 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3458 return blocksByName.back()[name]; 3459 } 3460 3461 /// Insert a new forward reference to the given block. 3462 void insertForwardRef(Block *block, SMLoc loc) { 3463 forwardRef.back().try_emplace(block, loc); 3464 } 3465 3466 /// Erase any forward reference to the given block. 3467 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3468 3469 /// Record that a definition was added at the current scope. 3470 void recordDefinition(StringRef def); 3471 3472 /// Get the value entry for the given SSA name. 3473 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3474 3475 /// Create a forward reference placeholder value with the given location and 3476 /// result type. 3477 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3478 3479 /// Return true if this is a forward reference. 3480 bool isForwardRefPlaceholder(Value value) { 3481 return forwardRefPlaceholders.count(value); 3482 } 3483 3484 /// This struct represents an isolated SSA name scope. This scope may contain 3485 /// other nested non-isolated scopes. These scopes are used for operations 3486 /// that are known to be isolated to allow for reusing names within their 3487 /// regions, even if those names are used above. 3488 struct IsolatedSSANameScope { 3489 /// Record that a definition was added at the current scope. 3490 void recordDefinition(StringRef def) { 3491 definitionsPerScope.back().insert(def); 3492 } 3493 3494 /// Push a nested name scope. 
3495 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3496 3497 /// Pop a nested name scope. 3498 void popSSANameScope() { 3499 for (auto &def : definitionsPerScope.pop_back_val()) 3500 values.erase(def.getKey()); 3501 } 3502 3503 /// This keeps track of all of the SSA values we are tracking for each name 3504 /// scope, indexed by their name. This has one entry per result number. 3505 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3506 3507 /// This keeps track of all of the values defined by a specific name scope. 3508 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3509 }; 3510 3511 /// A list of isolated name scopes. 3512 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3513 3514 /// This keeps track of the block names as well as the location of the first 3515 /// reference for each nested name scope. This is used to diagnose invalid 3516 /// block references and memorize them. 3517 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3518 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3519 3520 /// These are all of the placeholders we've made along with the location of 3521 /// their first reference, to allow checking for use of undefined values. 3522 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3523 3524 /// The builder used when creating parsed operation instances. 3525 OpBuilder opBuilder; 3526 3527 /// The top level module operation. 3528 ModuleOp moduleOp; 3529 }; 3530 } // end anonymous namespace 3531 3532 OperationParser::~OperationParser() { 3533 for (auto &fwd : forwardRefPlaceholders) { 3534 // Drop all uses of undefined forward declared reference and destroy 3535 // defining operation. 3536 fwd.first.dropAllUses(); 3537 fwd.first.getDefiningOp()->destroy(); 3538 } 3539 } 3540 3541 /// After parsing is finished, this function must be called to see if there are 3542 /// any remaining issues. 
3543 ParseResult OperationParser::finalize() { 3544 // Check for any forward references that are left. If we find any, error 3545 // out. 3546 if (!forwardRefPlaceholders.empty()) { 3547 SmallVector<std::pair<const char *, Value>, 4> errors; 3548 // Iteration over the map isn't deterministic, so sort by source location. 3549 for (auto entry : forwardRefPlaceholders) 3550 errors.push_back({entry.second.getPointer(), entry.first}); 3551 llvm::array_pod_sort(errors.begin(), errors.end()); 3552 3553 for (auto entry : errors) { 3554 auto loc = SMLoc::getFromPointer(entry.first); 3555 emitError(loc, "use of undeclared SSA value name"); 3556 } 3557 return failure(); 3558 } 3559 3560 return success(); 3561 } 3562 3563 //===----------------------------------------------------------------------===// 3564 // SSA Value Handling 3565 //===----------------------------------------------------------------------===// 3566 3567 void OperationParser::pushSSANameScope(bool isIsolated) { 3568 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3569 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3570 3571 // Push back a new name definition scope. 3572 if (isIsolated) 3573 isolatedNameScopes.push_back({}); 3574 isolatedNameScopes.back().pushSSANameScope(); 3575 } 3576 3577 ParseResult OperationParser::popSSANameScope() { 3578 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3579 3580 // Verify that all referenced blocks were defined. 3581 if (!forwardRefInCurrentScope.empty()) { 3582 SmallVector<std::pair<const char *, Block *>, 4> errors; 3583 // Iteration over the map isn't deterministic, so sort by source location. 3584 for (auto entry : forwardRefInCurrentScope) { 3585 errors.push_back({entry.second.getPointer(), entry.first}); 3586 // Add this block to the top-level region to allow for automatic cleanup. 
3587 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3588 } 3589 llvm::array_pod_sort(errors.begin(), errors.end()); 3590 3591 for (auto entry : errors) { 3592 auto loc = SMLoc::getFromPointer(entry.first); 3593 emitError(loc, "reference to an undefined block"); 3594 } 3595 return failure(); 3596 } 3597 3598 // Pop the next nested namescope. If there is only one internal namescope, 3599 // just pop the isolated scope. 3600 auto ¤tNameScope = isolatedNameScopes.back(); 3601 if (currentNameScope.definitionsPerScope.size() == 1) 3602 isolatedNameScopes.pop_back(); 3603 else 3604 currentNameScope.popSSANameScope(); 3605 3606 blocksByName.pop_back(); 3607 return success(); 3608 } 3609 3610 /// Register a definition of a value with the symbol table. 3611 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3612 auto &entries = getSSAValueEntry(useInfo.name); 3613 3614 // Make sure there is a slot for this value. 3615 if (entries.size() <= useInfo.number) 3616 entries.resize(useInfo.number + 1); 3617 3618 // If we already have an entry for this, check to see if it was a definition 3619 // or a forward reference. 
3620 if (auto existing = entries[useInfo.number].first) { 3621 if (!isForwardRefPlaceholder(existing)) { 3622 return emitError(useInfo.loc) 3623 .append("redefinition of SSA value '", useInfo.name, "'") 3624 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3625 .append("previously defined here"); 3626 } 3627 3628 if (existing.getType() != value.getType()) { 3629 return emitError(useInfo.loc) 3630 .append("definition of SSA value '", useInfo.name, "#", 3631 useInfo.number, "' has type ", value.getType()) 3632 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3633 .append("previously used here with type ", existing.getType()); 3634 } 3635 3636 // If it was a forward reference, update everything that used it to use 3637 // the actual definition instead, delete the forward ref, and remove it 3638 // from our set of forward references we track. 3639 existing.replaceAllUsesWith(value); 3640 existing.getDefiningOp()->destroy(); 3641 forwardRefPlaceholders.erase(existing); 3642 } 3643 3644 /// Record this definition for the current scope. 3645 entries[useInfo.number] = {value, useInfo.loc}; 3646 recordDefinition(useInfo.name); 3647 return success(); 3648 } 3649 3650 /// Parse a (possibly empty) list of SSA operands. 3651 /// 3652 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3653 /// ssa-use-list-opt ::= ssa-use-list? 3654 /// 3655 ParseResult 3656 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3657 if (getToken().isNot(Token::percent_identifier)) 3658 return success(); 3659 return parseCommaSeparatedList([&]() -> ParseResult { 3660 SSAUseInfo result; 3661 if (parseSSAUse(result)) 3662 return failure(); 3663 results.push_back(result); 3664 return success(); 3665 }); 3666 } 3667 3668 /// Parse a SSA operand for an operation. 
3669 /// 3670 /// ssa-use ::= ssa-id 3671 /// 3672 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3673 result.name = getTokenSpelling(); 3674 result.number = 0; 3675 result.loc = getToken().getLoc(); 3676 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3677 return failure(); 3678 3679 // If we have an attribute ID, it is a result number. 3680 if (getToken().is(Token::hash_identifier)) { 3681 if (auto value = getToken().getHashIdentifierNumber()) 3682 result.number = value.getValue(); 3683 else 3684 return emitError("invalid SSA value result number"); 3685 consumeToken(Token::hash_identifier); 3686 } 3687 3688 return success(); 3689 } 3690 3691 /// Given an unbound reference to an SSA value and its type, return the value 3692 /// it specifies. This returns null on failure. 3693 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3694 auto &entries = getSSAValueEntry(useInfo.name); 3695 3696 // If we have already seen a value of this name, return it. 3697 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3698 auto result = entries[useInfo.number].first; 3699 // Check that the type matches the other uses. 3700 if (result.getType() == type) 3701 return result; 3702 3703 emitError(useInfo.loc, "use of value '") 3704 .append(useInfo.name, 3705 "' expects different type than prior uses: ", type, " vs ", 3706 result.getType()) 3707 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3708 .append("prior use here"); 3709 return nullptr; 3710 } 3711 3712 // Make sure we have enough slots for this. 3713 if (entries.size() <= useInfo.number) 3714 entries.resize(useInfo.number + 1); 3715 3716 // If the value has already been defined and this is an overly large result 3717 // number, diagnose that. 
3718 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3719 return (emitError(useInfo.loc, "reference to invalid result number"), 3720 nullptr); 3721 3722 // Otherwise, this is a forward reference. Create a placeholder and remember 3723 // that we did so. 3724 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3725 entries[useInfo.number].first = result; 3726 entries[useInfo.number].second = useInfo.loc; 3727 return result; 3728 } 3729 3730 /// Parse an SSA use with an associated type. 3731 /// 3732 /// ssa-use-and-type ::= ssa-use `:` type 3733 ParseResult OperationParser::parseSSADefOrUseAndType( 3734 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3735 SSAUseInfo useInfo; 3736 if (parseSSAUse(useInfo) || 3737 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3738 return failure(); 3739 3740 auto type = parseType(); 3741 if (!type) 3742 return failure(); 3743 3744 return action(useInfo, type); 3745 } 3746 3747 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3748 /// followed by a type list. 3749 /// 3750 /// ssa-use-and-type-list 3751 /// ::= ssa-use-list ':' type-list-no-parens 3752 /// 3753 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3754 SmallVectorImpl<Value> &results) { 3755 SmallVector<SSAUseInfo, 4> valueIDs; 3756 if (parseOptionalSSAUseList(valueIDs)) 3757 return failure(); 3758 3759 // If there were no operands, then there is no colon or type lists. 
3760 if (valueIDs.empty()) 3761 return success(); 3762 3763 SmallVector<Type, 4> types; 3764 if (parseToken(Token::colon, "expected ':' in operand list") || 3765 parseTypeListNoParens(types)) 3766 return failure(); 3767 3768 if (valueIDs.size() != types.size()) 3769 return emitError("expected ") 3770 << valueIDs.size() << " types to match operand list"; 3771 3772 results.reserve(valueIDs.size()); 3773 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3774 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3775 results.push_back(value); 3776 else 3777 return failure(); 3778 } 3779 3780 return success(); 3781 } 3782 3783 /// Record that a definition was added at the current scope. 3784 void OperationParser::recordDefinition(StringRef def) { 3785 isolatedNameScopes.back().recordDefinition(def); 3786 } 3787 3788 /// Get the value entry for the given SSA name. 3789 SmallVectorImpl<std::pair<Value, SMLoc>> & 3790 OperationParser::getSSAValueEntry(StringRef name) { 3791 return isolatedNameScopes.back().values[name]; 3792 } 3793 3794 /// Create and remember a new placeholder for a forward reference. 3795 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3796 // Forward references are always created as operations, because we just need 3797 // something with a def/use chain. 3798 // 3799 // We create these placeholders as having an empty name, which we know 3800 // cannot be created through normal user input, allowing us to distinguish 3801 // them. 
3802 auto name = OperationName("placeholder", getContext()); 3803 auto *op = Operation::create( 3804 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3805 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0); 3806 forwardRefPlaceholders[op->getResult(0)] = loc; 3807 return op->getResult(0); 3808 } 3809 3810 //===----------------------------------------------------------------------===// 3811 // Operation Parsing 3812 //===----------------------------------------------------------------------===// 3813 3814 /// Parse an operation. 3815 /// 3816 /// operation ::= op-result-list? 3817 /// (generic-operation | custom-operation) 3818 /// trailing-location? 3819 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3820 /// successor-list? (`(` region-list `)`)? 3821 /// attribute-dict? `:` function-type 3822 /// custom-operation ::= bare-id custom-operation-format 3823 /// op-result-list ::= op-result (`,` op-result)* `=` 3824 /// op-result ::= ssa-id (`:` integer-literal) 3825 /// 3826 ParseResult OperationParser::parseOperation() { 3827 auto loc = getToken().getLoc(); 3828 SmallVector<ResultRecord, 1> resultIDs; 3829 size_t numExpectedResults = 0; 3830 if (getToken().is(Token::percent_identifier)) { 3831 // Parse the group of result ids. 3832 auto parseNextResult = [&]() -> ParseResult { 3833 // Parse the next result id. 3834 if (!getToken().is(Token::percent_identifier)) 3835 return emitError("expected valid ssa identifier"); 3836 3837 Token nameTok = getToken(); 3838 consumeToken(Token::percent_identifier); 3839 3840 // If the next token is a ':', we parse the expected result count. 3841 size_t expectedSubResults = 1; 3842 if (consumeIf(Token::colon)) { 3843 // Check that the next token is an integer. 3844 if (!getToken().is(Token::integer)) 3845 return emitError("expected integer number of results"); 3846 3847 // Check that number of results is > 0. 
3848 auto val = getToken().getUInt64IntegerValue(); 3849 if (!val.hasValue() || val.getValue() < 1) 3850 return emitError("expected named operation to have atleast 1 result"); 3851 consumeToken(Token::integer); 3852 expectedSubResults = *val; 3853 } 3854 3855 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3856 nameTok.getLoc()); 3857 numExpectedResults += expectedSubResults; 3858 return success(); 3859 }; 3860 if (parseCommaSeparatedList(parseNextResult)) 3861 return failure(); 3862 3863 if (parseToken(Token::equal, "expected '=' after SSA name")) 3864 return failure(); 3865 } 3866 3867 Operation *op; 3868 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3869 op = parseCustomOperation(resultIDs); 3870 else if (getToken().is(Token::string)) 3871 op = parseGenericOperation(); 3872 else 3873 return emitError("expected operation name in quotes"); 3874 3875 // If parsing of the basic operation failed, then this whole thing fails. 3876 if (!op) 3877 return failure(); 3878 3879 // If the operation had a name, register it. 3880 if (!resultIDs.empty()) { 3881 if (op->getNumResults() == 0) 3882 return emitError(loc, "cannot name an operation with no results"); 3883 if (numExpectedResults != op->getNumResults()) 3884 return emitError(loc, "operation defines ") 3885 << op->getNumResults() << " results but was provided " 3886 << numExpectedResults << " to bind"; 3887 3888 // Add definitions for each of the result groups. 3889 unsigned opResI = 0; 3890 for (ResultRecord &resIt : resultIDs) { 3891 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3892 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3893 op->getResult(opResI++))) 3894 return failure(); 3895 } 3896 } 3897 } 3898 3899 return success(); 3900 } 3901 3902 /// Parse a single operation successor. 
3903 /// 3904 /// successor ::= block-id 3905 /// 3906 ParseResult OperationParser::parseSuccessor(Block *&dest) { 3907 // Verify branch is identifier and get the matching block. 3908 if (!getToken().is(Token::caret_identifier)) 3909 return emitError("expected block name"); 3910 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3911 consumeToken(); 3912 return success(); 3913 } 3914 3915 /// Parse a comma-separated list of operation successors in brackets. 3916 /// 3917 /// successor-list ::= `[` successor (`,` successor )* `]` 3918 /// 3919 ParseResult 3920 OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) { 3921 if (parseToken(Token::l_square, "expected '['")) 3922 return failure(); 3923 3924 auto parseElt = [this, &destinations] { 3925 Block *dest; 3926 ParseResult res = parseSuccessor(dest); 3927 destinations.push_back(dest); 3928 return res; 3929 }; 3930 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3931 /*allowEmptyList=*/false); 3932 } 3933 3934 namespace { 3935 // RAII-style guard for cleaning up the regions in the operation state before 3936 // deleting them. Within the parser, regions may get deleted if parsing failed, 3937 // and other errors may be present, in particular undominated uses. This makes 3938 // sure such uses are deleted. 3939 struct CleanupOpStateRegions { 3940 ~CleanupOpStateRegions() { 3941 SmallVector<Region *, 4> regionsToClean; 3942 regionsToClean.reserve(state.regions.size()); 3943 for (auto ®ion : state.regions) 3944 if (region) 3945 for (auto &block : *region) 3946 block.dropAllDefinedValueUses(); 3947 } 3948 OperationState &state; 3949 }; 3950 } // namespace 3951 3952 Operation *OperationParser::parseGenericOperation() { 3953 // Get location information for the operation. 
3954 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3955 3956 auto name = getToken().getStringValue(); 3957 if (name.empty()) 3958 return (emitError("empty operation name is invalid"), nullptr); 3959 if (name.find('\0') != StringRef::npos) 3960 return (emitError("null character not allowed in operation name"), nullptr); 3961 3962 consumeToken(Token::string); 3963 3964 OperationState result(srcLocation, name); 3965 3966 // Parse the operand list. 3967 SmallVector<SSAUseInfo, 8> operandInfos; 3968 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3969 parseOptionalSSAUseList(operandInfos) || 3970 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3971 return nullptr; 3972 } 3973 3974 // Parse the successor list. 3975 if (getToken().is(Token::l_square)) { 3976 // Check if the operation is a known terminator. 3977 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3978 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3979 return emitError("successors in non-terminator"), nullptr; 3980 3981 SmallVector<Block *, 2> successors; 3982 if (parseSuccessors(successors)) 3983 return nullptr; 3984 result.addSuccessors(successors); 3985 } 3986 3987 // Parse the region list. 3988 CleanupOpStateRegions guard{result}; 3989 if (consumeIf(Token::l_paren)) { 3990 do { 3991 // Create temporary regions with the top level region as parent. 
3992 result.regions.emplace_back(new Region(moduleOp)); 3993 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3994 return nullptr; 3995 } while (consumeIf(Token::comma)); 3996 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3997 return nullptr; 3998 } 3999 4000 if (getToken().is(Token::l_brace)) { 4001 if (parseAttributeDict(result.attributes)) 4002 return nullptr; 4003 } 4004 4005 if (parseToken(Token::colon, "expected ':' followed by operation type")) 4006 return nullptr; 4007 4008 auto typeLoc = getToken().getLoc(); 4009 auto type = parseType(); 4010 if (!type) 4011 return nullptr; 4012 auto fnType = type.dyn_cast<FunctionType>(); 4013 if (!fnType) 4014 return (emitError(typeLoc, "expected function type"), nullptr); 4015 4016 result.addTypes(fnType.getResults()); 4017 4018 // Check that we have the right number of types for the operands. 4019 auto operandTypes = fnType.getInputs(); 4020 if (operandTypes.size() != operandInfos.size()) { 4021 auto plural = "s"[operandInfos.size() == 1]; 4022 return (emitError(typeLoc, "expected ") 4023 << operandInfos.size() << " operand type" << plural 4024 << " but had " << operandTypes.size(), 4025 nullptr); 4026 } 4027 4028 // Resolve all of the operands. 4029 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 4030 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 4031 if (!result.operands.back()) 4032 return nullptr; 4033 } 4034 4035 // Parse a location if one is present. 
4036 if (parseOptionalTrailingLocation(result.location)) 4037 return nullptr; 4038 4039 return opBuilder.createOperation(result); 4040 } 4041 4042 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 4043 Block::iterator insertPt) { 4044 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 4045 opBuilder.setInsertionPoint(insertBlock, insertPt); 4046 return parseGenericOperation(); 4047 } 4048 4049 namespace { 4050 class CustomOpAsmParser : public OpAsmParser { 4051 public: 4052 CustomOpAsmParser(SMLoc nameLoc, 4053 ArrayRef<OperationParser::ResultRecord> resultIDs, 4054 const AbstractOperation *opDefinition, 4055 OperationParser &parser) 4056 : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition), 4057 parser(parser) {} 4058 4059 /// Parse an instance of the operation described by 'opDefinition' into the 4060 /// provided operation state. 4061 ParseResult parseOperation(OperationState &opState) { 4062 if (opDefinition->parseAssembly(*this, opState)) 4063 return failure(); 4064 return success(); 4065 } 4066 4067 Operation *parseGenericOperation(Block *insertBlock, 4068 Block::iterator insertPt) final { 4069 return parser.parseGenericOperation(insertBlock, insertPt); 4070 } 4071 4072 //===--------------------------------------------------------------------===// 4073 // Utilities 4074 //===--------------------------------------------------------------------===// 4075 4076 /// Return if any errors were emitted during parsing. 4077 bool didEmitError() const { return emittedError; } 4078 4079 /// Emit a diagnostic at the specified location and return failure. 
4080 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 4081 emittedError = true; 4082 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 4083 message); 4084 } 4085 4086 llvm::SMLoc getCurrentLocation() override { 4087 return parser.getToken().getLoc(); 4088 } 4089 4090 Builder &getBuilder() const override { return parser.builder; } 4091 4092 /// Return the name of the specified result in the specified syntax, as well 4093 /// as the subelement in the name. For example, in this operation: 4094 /// 4095 /// %x, %y:2, %z = foo.op 4096 /// 4097 /// getResultName(0) == {"x", 0 } 4098 /// getResultName(1) == {"y", 0 } 4099 /// getResultName(2) == {"y", 1 } 4100 /// getResultName(3) == {"z", 0 } 4101 std::pair<StringRef, unsigned> 4102 getResultName(unsigned resultNo) const override { 4103 // Scan for the resultID that contains this result number. 4104 for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) { 4105 const auto &entry = resultIDs[nameID]; 4106 if (resultNo < std::get<1>(entry)) { 4107 // Don't pass on the leading %. 4108 StringRef name = std::get<0>(entry).drop_front(); 4109 return {name, resultNo}; 4110 } 4111 resultNo -= std::get<1>(entry); 4112 } 4113 4114 // Invalid result number. 4115 return {"", ~0U}; 4116 } 4117 4118 /// Return the number of declared SSA results. This returns 4 for the foo.op 4119 /// example in the comment for getResultName. 4120 size_t getNumResults() const override { 4121 size_t count = 0; 4122 for (auto &entry : resultIDs) 4123 count += std::get<1>(entry); 4124 return count; 4125 } 4126 4127 llvm::SMLoc getNameLoc() const override { return nameLoc; } 4128 4129 //===--------------------------------------------------------------------===// 4130 // Token Parsing 4131 //===--------------------------------------------------------------------===// 4132 4133 /// Parse a `->` token. 
4134 ParseResult parseArrow() override { 4135 return parser.parseToken(Token::arrow, "expected '->'"); 4136 } 4137 4138 /// Parses a `->` if present. 4139 ParseResult parseOptionalArrow() override { 4140 return success(parser.consumeIf(Token::arrow)); 4141 } 4142 4143 /// Parse a `:` token. 4144 ParseResult parseColon() override { 4145 return parser.parseToken(Token::colon, "expected ':'"); 4146 } 4147 4148 /// Parse a `:` token if present. 4149 ParseResult parseOptionalColon() override { 4150 return success(parser.consumeIf(Token::colon)); 4151 } 4152 4153 /// Parse a `,` token. 4154 ParseResult parseComma() override { 4155 return parser.parseToken(Token::comma, "expected ','"); 4156 } 4157 4158 /// Parse a `,` token if present. 4159 ParseResult parseOptionalComma() override { 4160 return success(parser.consumeIf(Token::comma)); 4161 } 4162 4163 /// Parses a `...` if present. 4164 ParseResult parseOptionalEllipsis() override { 4165 return success(parser.consumeIf(Token::ellipsis)); 4166 } 4167 4168 /// Parse a `=` token. 4169 ParseResult parseEqual() override { 4170 return parser.parseToken(Token::equal, "expected '='"); 4171 } 4172 4173 /// Parse a '<' token. 4174 ParseResult parseLess() override { 4175 return parser.parseToken(Token::less, "expected '<'"); 4176 } 4177 4178 /// Parse a '>' token. 4179 ParseResult parseGreater() override { 4180 return parser.parseToken(Token::greater, "expected '>'"); 4181 } 4182 4183 /// Parse a `(` token. 4184 ParseResult parseLParen() override { 4185 return parser.parseToken(Token::l_paren, "expected '('"); 4186 } 4187 4188 /// Parses a '(' if present. 4189 ParseResult parseOptionalLParen() override { 4190 return success(parser.consumeIf(Token::l_paren)); 4191 } 4192 4193 /// Parse a `)` token. 4194 ParseResult parseRParen() override { 4195 return parser.parseToken(Token::r_paren, "expected ')'"); 4196 } 4197 4198 /// Parses a ')' if present. 
4199 ParseResult parseOptionalRParen() override { 4200 return success(parser.consumeIf(Token::r_paren)); 4201 } 4202 4203 /// Parse a `[` token. 4204 ParseResult parseLSquare() override { 4205 return parser.parseToken(Token::l_square, "expected '['"); 4206 } 4207 4208 /// Parses a '[' if present. 4209 ParseResult parseOptionalLSquare() override { 4210 return success(parser.consumeIf(Token::l_square)); 4211 } 4212 4213 /// Parse a `]` token. 4214 ParseResult parseRSquare() override { 4215 return parser.parseToken(Token::r_square, "expected ']'"); 4216 } 4217 4218 /// Parses a ']' if present. 4219 ParseResult parseOptionalRSquare() override { 4220 return success(parser.consumeIf(Token::r_square)); 4221 } 4222 4223 //===--------------------------------------------------------------------===// 4224 // Attribute Parsing 4225 //===--------------------------------------------------------------------===// 4226 4227 /// Parse an arbitrary attribute of a given type and return it in result. This 4228 /// also adds the attribute to the specified attribute list with the specified 4229 /// name. 4230 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 4231 SmallVectorImpl<NamedAttribute> &attrs) override { 4232 result = parser.parseAttribute(type); 4233 if (!result) 4234 return failure(); 4235 4236 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 4237 return success(); 4238 } 4239 4240 /// Parse a named dictionary into 'result' if it is present. 4241 ParseResult 4242 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 4243 if (parser.getToken().isNot(Token::l_brace)) 4244 return success(); 4245 return parser.parseAttributeDict(result); 4246 } 4247 4248 /// Parse a named dictionary into 'result' if the `attributes` keyword is 4249 /// present. 
4250 ParseResult parseOptionalAttrDictWithKeyword( 4251 SmallVectorImpl<NamedAttribute> &result) override { 4252 if (failed(parseOptionalKeyword("attributes"))) 4253 return success(); 4254 return parser.parseAttributeDict(result); 4255 } 4256 4257 /// Parse an affine map instance into 'map'. 4258 ParseResult parseAffineMap(AffineMap &map) override { 4259 return parser.parseAffineMapReference(map); 4260 } 4261 4262 /// Parse an integer set instance into 'set'. 4263 ParseResult printIntegerSet(IntegerSet &set) override { 4264 return parser.parseIntegerSetReference(set); 4265 } 4266 4267 //===--------------------------------------------------------------------===// 4268 // Identifier Parsing 4269 //===--------------------------------------------------------------------===// 4270 4271 /// Returns if the current token corresponds to a keyword. 4272 bool isCurrentTokenAKeyword() const { 4273 return parser.getToken().is(Token::bare_identifier) || 4274 parser.getToken().isKeyword(); 4275 } 4276 4277 /// Parse the given keyword if present. 4278 ParseResult parseOptionalKeyword(StringRef keyword) override { 4279 // Check that the current token has the same spelling. 4280 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4281 return failure(); 4282 parser.consumeToken(); 4283 return success(); 4284 } 4285 4286 /// Parse a keyword, if present, into 'keyword'. 4287 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4288 // Check that the current token is a keyword. 4289 if (!isCurrentTokenAKeyword()) 4290 return failure(); 4291 4292 *keyword = parser.getTokenSpelling(); 4293 parser.consumeToken(); 4294 return success(); 4295 } 4296 4297 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4298 /// string attribute named 'attrName'. 
4299 ParseResult 4300 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4301 SmallVectorImpl<NamedAttribute> &attrs) override { 4302 Token atToken = parser.getToken(); 4303 if (atToken.isNot(Token::at_identifier)) 4304 return failure(); 4305 4306 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4307 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4308 parser.consumeToken(); 4309 return success(); 4310 } 4311 4312 //===--------------------------------------------------------------------===// 4313 // Operand Parsing 4314 //===--------------------------------------------------------------------===// 4315 4316 /// Parse a single operand. 4317 ParseResult parseOperand(OperandType &result) override { 4318 OperationParser::SSAUseInfo useInfo; 4319 if (parser.parseSSAUse(useInfo)) 4320 return failure(); 4321 4322 result = {useInfo.loc, useInfo.name, useInfo.number}; 4323 return success(); 4324 } 4325 4326 /// Parse a single operand if present. 4327 OptionalParseResult parseOptionalOperand(OperandType &result) override { 4328 if (parser.getToken().is(Token::percent_identifier)) 4329 return parseOperand(result); 4330 return llvm::None; 4331 } 4332 4333 /// Parse zero or more SSA comma-separated operand references with a specified 4334 /// surrounding delimiter, and an optional required operand count. 4335 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4336 int requiredOperandCount = -1, 4337 Delimiter delimiter = Delimiter::None) override { 4338 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4339 requiredOperandCount, delimiter); 4340 } 4341 4342 /// Parse zero or more SSA comma-separated operand or region arguments with 4343 /// optional surrounding delimiter and required operand count. 
4344 ParseResult 4345 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4346 bool isOperandList, int requiredOperandCount = -1, 4347 Delimiter delimiter = Delimiter::None) { 4348 auto startLoc = parser.getToken().getLoc(); 4349 4350 // Handle delimiters. 4351 switch (delimiter) { 4352 case Delimiter::None: 4353 // Don't check for the absence of a delimiter if the number of operands 4354 // is unknown (and hence the operand list could be empty). 4355 if (requiredOperandCount == -1) 4356 break; 4357 // Token already matches an identifier and so can't be a delimiter. 4358 if (parser.getToken().is(Token::percent_identifier)) 4359 break; 4360 // Test against known delimiters. 4361 if (parser.getToken().is(Token::l_paren) || 4362 parser.getToken().is(Token::l_square)) 4363 return emitError(startLoc, "unexpected delimiter"); 4364 return emitError(startLoc, "invalid operand"); 4365 case Delimiter::OptionalParen: 4366 if (parser.getToken().isNot(Token::l_paren)) 4367 return success(); 4368 LLVM_FALLTHROUGH; 4369 case Delimiter::Paren: 4370 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4371 return failure(); 4372 break; 4373 case Delimiter::OptionalSquare: 4374 if (parser.getToken().isNot(Token::l_square)) 4375 return success(); 4376 LLVM_FALLTHROUGH; 4377 case Delimiter::Square: 4378 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4379 return failure(); 4380 break; 4381 } 4382 4383 // Check for zero operands. 4384 if (parser.getToken().is(Token::percent_identifier)) { 4385 do { 4386 OperandType operandOrArg; 4387 if (isOperandList ? parseOperand(operandOrArg) 4388 : parseRegionArgument(operandOrArg)) 4389 return failure(); 4390 result.push_back(operandOrArg); 4391 } while (parser.consumeIf(Token::comma)); 4392 } 4393 4394 // Handle delimiters. If we reach here, the optional delimiters were 4395 // present, so we need to parse their closing one. 
4396 switch (delimiter) { 4397 case Delimiter::None: 4398 break; 4399 case Delimiter::OptionalParen: 4400 case Delimiter::Paren: 4401 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4402 return failure(); 4403 break; 4404 case Delimiter::OptionalSquare: 4405 case Delimiter::Square: 4406 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4407 return failure(); 4408 break; 4409 } 4410 4411 if (requiredOperandCount != -1 && 4412 result.size() != static_cast<size_t>(requiredOperandCount)) 4413 return emitError(startLoc, "expected ") 4414 << requiredOperandCount << " operands"; 4415 return success(); 4416 } 4417 4418 /// Parse zero or more trailing SSA comma-separated trailing operand 4419 /// references with a specified surrounding delimiter, and an optional 4420 /// required operand count. A leading comma is expected before the operands. 4421 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4422 int requiredOperandCount, 4423 Delimiter delimiter) override { 4424 if (parser.getToken().is(Token::comma)) { 4425 parseComma(); 4426 return parseOperandList(result, requiredOperandCount, delimiter); 4427 } 4428 if (requiredOperandCount != -1) 4429 return emitError(parser.getToken().getLoc(), "expected ") 4430 << requiredOperandCount << " operands"; 4431 return success(); 4432 } 4433 4434 /// Resolve an operand to an SSA value, emitting an error on failure. 4435 ParseResult resolveOperand(const OperandType &operand, Type type, 4436 SmallVectorImpl<Value> &result) override { 4437 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4438 operand.location}; 4439 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4440 result.push_back(value); 4441 return success(); 4442 } 4443 return failure(); 4444 } 4445 4446 /// Parse an AffineMap of SSA ids. 
4447 ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4448 Attribute &mapAttr, StringRef attrName, 4449 SmallVectorImpl<NamedAttribute> &attrs, 4450 Delimiter delimiter) override { 4451 SmallVector<OperandType, 2> dimOperands; 4452 SmallVector<OperandType, 1> symOperands; 4453 4454 auto parseElement = [&](bool isSymbol) -> ParseResult { 4455 OperandType operand; 4456 if (parseOperand(operand)) 4457 return failure(); 4458 if (isSymbol) 4459 symOperands.push_back(operand); 4460 else 4461 dimOperands.push_back(operand); 4462 return success(); 4463 }; 4464 4465 AffineMap map; 4466 if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) 4467 return failure(); 4468 // Add AffineMap attribute. 4469 if (map) { 4470 mapAttr = AffineMapAttr::get(map); 4471 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4472 } 4473 4474 // Add dim operands before symbol operands in 'operands'. 4475 operands.assign(dimOperands.begin(), dimOperands.end()); 4476 operands.append(symOperands.begin(), symOperands.end()); 4477 return success(); 4478 } 4479 4480 //===--------------------------------------------------------------------===// 4481 // Region Parsing 4482 //===--------------------------------------------------------------------===// 4483 4484 /// Parse a region that takes `arguments` of `argTypes` types. This 4485 /// effectively defines the SSA values of `arguments` and assigns their type. 
4486 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4487 ArrayRef<Type> argTypes, 4488 bool enableNameShadowing) override { 4489 assert(arguments.size() == argTypes.size() && 4490 "mismatching number of arguments and types"); 4491 4492 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4493 regionArguments; 4494 for (auto pair : llvm::zip(arguments, argTypes)) { 4495 const OperandType &operand = std::get<0>(pair); 4496 Type type = std::get<1>(pair); 4497 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4498 operand.location}; 4499 regionArguments.emplace_back(operandInfo, type); 4500 } 4501 4502 // Try to parse the region. 4503 assert((!enableNameShadowing || 4504 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4505 "name shadowing is only allowed on isolated regions"); 4506 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4507 return failure(); 4508 return success(); 4509 } 4510 4511 /// Parses a region if present. 4512 ParseResult parseOptionalRegion(Region ®ion, 4513 ArrayRef<OperandType> arguments, 4514 ArrayRef<Type> argTypes, 4515 bool enableNameShadowing) override { 4516 if (parser.getToken().isNot(Token::l_brace)) 4517 return success(); 4518 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4519 } 4520 4521 /// Parse a region argument. The type of the argument will be resolved later 4522 /// by a call to `parseRegion`. 4523 ParseResult parseRegionArgument(OperandType &argument) override { 4524 return parseOperand(argument); 4525 } 4526 4527 /// Parse a region argument if present. 
4528 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4529 if (parser.getToken().isNot(Token::percent_identifier)) 4530 return success(); 4531 return parseRegionArgument(argument); 4532 } 4533 4534 ParseResult 4535 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4536 int requiredOperandCount = -1, 4537 Delimiter delimiter = Delimiter::None) override { 4538 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4539 requiredOperandCount, delimiter); 4540 } 4541 4542 //===--------------------------------------------------------------------===// 4543 // Successor Parsing 4544 //===--------------------------------------------------------------------===// 4545 4546 /// Parse a single operation successor. 4547 ParseResult parseSuccessor(Block *&dest) override { 4548 return parser.parseSuccessor(dest); 4549 } 4550 4551 /// Parse an optional operation successor and its operand list. 4552 OptionalParseResult parseOptionalSuccessor(Block *&dest) override { 4553 if (parser.getToken().isNot(Token::caret_identifier)) 4554 return llvm::None; 4555 return parseSuccessor(dest); 4556 } 4557 4558 /// Parse a single operation successor and its operand list. 4559 ParseResult 4560 parseSuccessorAndUseList(Block *&dest, 4561 SmallVectorImpl<Value> &operands) override { 4562 if (parseSuccessor(dest)) 4563 return failure(); 4564 4565 // Handle optional arguments. 4566 if (succeeded(parseOptionalLParen()) && 4567 (parser.parseOptionalSSAUseAndTypeList(operands) || parseRParen())) { 4568 return failure(); 4569 } 4570 return success(); 4571 } 4572 4573 //===--------------------------------------------------------------------===// 4574 // Type Parsing 4575 //===--------------------------------------------------------------------===// 4576 4577 /// Parse a type. 4578 ParseResult parseType(Type &result) override { 4579 return failure(!(result = parser.parseType())); 4580 } 4581 4582 /// Parse an optional type. 
4583 OptionalParseResult parseOptionalType(Type &result) override { 4584 return parser.parseOptionalType(result); 4585 } 4586 4587 /// Parse an arrow followed by a type list. 4588 ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override { 4589 if (parseArrow() || parser.parseFunctionResultTypes(result)) 4590 return failure(); 4591 return success(); 4592 } 4593 4594 /// Parse an optional arrow followed by a type list. 4595 ParseResult 4596 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4597 if (!parser.consumeIf(Token::arrow)) 4598 return success(); 4599 return parser.parseFunctionResultTypes(result); 4600 } 4601 4602 /// Parse a colon followed by a type. 4603 ParseResult parseColonType(Type &result) override { 4604 return failure(parser.parseToken(Token::colon, "expected ':'") || 4605 !(result = parser.parseType())); 4606 } 4607 4608 /// Parse a colon followed by a type list, which must have at least one type. 4609 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4610 if (parser.parseToken(Token::colon, "expected ':'")) 4611 return failure(); 4612 return parser.parseTypeListNoParens(result); 4613 } 4614 4615 /// Parse an optional colon followed by a type list, which if present must 4616 /// have at least one type. 4617 ParseResult 4618 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4619 if (!parser.consumeIf(Token::colon)) 4620 return success(); 4621 return parser.parseTypeListNoParens(result); 4622 } 4623 4624 /// Parse a list of assignments of the form 4625 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). 
4626 /// The list must contain at least one entry 4627 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 4628 SmallVectorImpl<OperandType> &rhs) override { 4629 auto parseElt = [&]() -> ParseResult { 4630 OperandType regionArg, operand; 4631 if (parseRegionArgument(regionArg) || parseEqual() || 4632 parseOperand(operand)) 4633 return failure(); 4634 lhs.push_back(regionArg); 4635 rhs.push_back(operand); 4636 return success(); 4637 }; 4638 if (parseLParen()) 4639 return failure(); 4640 return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); 4641 } 4642 4643 private: 4644 /// The source location of the operation name. 4645 SMLoc nameLoc; 4646 4647 /// Information about the result name specifiers. 4648 ArrayRef<OperationParser::ResultRecord> resultIDs; 4649 4650 /// The abstract information of the operation. 4651 const AbstractOperation *opDefinition; 4652 4653 /// The main operation parser. 4654 OperationParser &parser; 4655 4656 /// A flag that indicates if any errors were emitted during parsing. 4657 bool emittedError = false; 4658 }; 4659 } // end anonymous namespace. 4660 4661 Operation * 4662 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) { 4663 auto opLoc = getToken().getLoc(); 4664 auto opName = getTokenSpelling(); 4665 4666 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4667 if (!opDefinition && !opName.contains('.')) { 4668 // If the operation name has no namespace prefix we treat it as a standard 4669 // operation and prefix it with "std". 4670 // TODO: Would it be better to just build a mapping of the registered 4671 // operations in the standard dialect? 4672 opDefinition = 4673 AbstractOperation::lookup(Twine("std." 
+ opName).str(), getContext()); 4674 } 4675 4676 if (!opDefinition) { 4677 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4678 return nullptr; 4679 } 4680 4681 consumeToken(); 4682 4683 // If the custom op parser crashes, produce some indication to help 4684 // debugging. 4685 std::string opNameStr = opName.str(); 4686 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4687 opNameStr.c_str()); 4688 4689 // Get location information for the operation. 4690 auto srcLocation = getEncodedSourceLocation(opLoc); 4691 4692 // Have the op implementation take a crack and parsing this. 4693 OperationState opState(srcLocation, opDefinition->name); 4694 CleanupOpStateRegions guard{opState}; 4695 CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this); 4696 if (opAsmParser.parseOperation(opState)) 4697 return nullptr; 4698 4699 // If it emitted an error, we failed. 4700 if (opAsmParser.didEmitError()) 4701 return nullptr; 4702 4703 // Parse a location if one is present. 4704 if (parseOptionalTrailingLocation(opState.location)) 4705 return nullptr; 4706 4707 // Otherwise, we succeeded. Use the state it parsed as our op information. 4708 return opBuilder.createOperation(opState); 4709 } 4710 4711 //===----------------------------------------------------------------------===// 4712 // Region Parsing 4713 //===----------------------------------------------------------------------===// 4714 4715 /// Region. 4716 /// 4717 /// region ::= '{' region-body 4718 /// 4719 ParseResult OperationParser::parseRegion( 4720 Region ®ion, 4721 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4722 bool isIsolatedNameScope) { 4723 // Parse the '{'. 4724 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4725 return failure(); 4726 4727 // Check for an empty region. 
4728 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4729 return success(); 4730 auto currentPt = opBuilder.saveInsertionPoint(); 4731 4732 // Push a new named value scope. 4733 pushSSANameScope(isIsolatedNameScope); 4734 4735 // Parse the first block directly to allow for it to be unnamed. 4736 Block *block = new Block(); 4737 4738 // Add arguments to the entry block. 4739 if (!entryArguments.empty()) { 4740 for (auto &placeholderArgPair : entryArguments) { 4741 auto &argInfo = placeholderArgPair.first; 4742 // Ensure that the argument was not already defined. 4743 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4744 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4745 "' is already in use") 4746 .attachNote(getEncodedSourceLocation(*defLoc)) 4747 << "previously referenced here"; 4748 } 4749 if (addDefinition(placeholderArgPair.first, 4750 block->addArgument(placeholderArgPair.second))) { 4751 delete block; 4752 return failure(); 4753 } 4754 } 4755 4756 // If we had named arguments, then don't allow a block name. 4757 if (getToken().is(Token::caret_identifier)) 4758 return emitError("invalid block name in region with named arguments"); 4759 } 4760 4761 if (parseBlock(block)) { 4762 delete block; 4763 return failure(); 4764 } 4765 4766 // Verify that no other arguments were parsed. 4767 if (!entryArguments.empty() && 4768 block->getNumArguments() > entryArguments.size()) { 4769 delete block; 4770 return emitError("entry block arguments were already defined"); 4771 } 4772 4773 // Parse the rest of the region. 4774 region.push_back(block); 4775 if (parseRegionBody(region)) 4776 return failure(); 4777 4778 // Pop the SSA value scope for this region. 4779 if (popSSANameScope()) 4780 return failure(); 4781 4782 // Reset the original insertion point. 4783 opBuilder.restoreInsertionPoint(currentPt); 4784 return success(); 4785 } 4786 4787 /// Region. 
4788 /// 4789 /// region-body ::= block* '}' 4790 /// 4791 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4792 // Parse the list of blocks. 4793 while (!consumeIf(Token::r_brace)) { 4794 Block *newBlock = nullptr; 4795 if (parseBlock(newBlock)) 4796 return failure(); 4797 region.push_back(newBlock); 4798 } 4799 return success(); 4800 } 4801 4802 //===----------------------------------------------------------------------===// 4803 // Block Parsing 4804 //===----------------------------------------------------------------------===// 4805 4806 /// Block declaration. 4807 /// 4808 /// block ::= block-label? operation* 4809 /// block-label ::= block-id block-arg-list? `:` 4810 /// block-id ::= caret-id 4811 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4812 /// 4813 ParseResult OperationParser::parseBlock(Block *&block) { 4814 // The first block of a region may already exist, if it does the caret 4815 // identifier is optional. 4816 if (block && getToken().isNot(Token::caret_identifier)) 4817 return parseBlockBody(block); 4818 4819 SMLoc nameLoc = getToken().getLoc(); 4820 auto name = getTokenSpelling(); 4821 if (parseToken(Token::caret_identifier, "expected block name")) 4822 return failure(); 4823 4824 block = defineBlockNamed(name, nameLoc, block); 4825 4826 // Fail if the block was already defined. 4827 if (!block) 4828 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4829 4830 // If an argument list is present, parse it. 4831 if (consumeIf(Token::l_paren)) { 4832 SmallVector<BlockArgument, 8> bbArgs; 4833 if (parseOptionalBlockArgList(bbArgs, block) || 4834 parseToken(Token::r_paren, "expected ')' to end argument list")) 4835 return failure(); 4836 } 4837 4838 if (parseToken(Token::colon, "expected ':' after block name")) 4839 return failure(); 4840 4841 return parseBlockBody(block); 4842 } 4843 4844 ParseResult OperationParser::parseBlockBody(Block *block) { 4845 // Set the insertion point to the end of the block to parse. 
4846 opBuilder.setInsertionPointToEnd(block); 4847 4848 // Parse the list of operations that make up the body of the block. 4849 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4850 if (parseOperation()) 4851 return failure(); 4852 4853 return success(); 4854 } 4855 4856 /// Get the block with the specified name, creating it if it doesn't already 4857 /// exist. The location specified is the point of use, which allows 4858 /// us to diagnose references to blocks that are not defined precisely. 4859 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4860 auto &blockAndLoc = getBlockInfoByName(name); 4861 if (!blockAndLoc.first) { 4862 blockAndLoc = {new Block(), loc}; 4863 insertForwardRef(blockAndLoc.first, loc); 4864 } 4865 4866 return blockAndLoc.first; 4867 } 4868 4869 /// Define the block with the specified name. Returns the Block* or nullptr in 4870 /// the case of redefinition. 4871 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4872 Block *existing) { 4873 auto &blockAndLoc = getBlockInfoByName(name); 4874 if (!blockAndLoc.first) { 4875 // If the caller provided a block, use it. Otherwise create a new one. 4876 if (!existing) 4877 existing = new Block(); 4878 blockAndLoc.first = existing; 4879 blockAndLoc.second = loc; 4880 return blockAndLoc.first; 4881 } 4882 4883 // Forward declarations are removed once defined, so if we are defining a 4884 // existing block and it is not a forward declaration, then it is a 4885 // redeclaration. 4886 if (!eraseForwardRef(blockAndLoc.first)) 4887 return nullptr; 4888 return blockAndLoc.first; 4889 } 4890 4891 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 
4892 /// 4893 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4894 /// 4895 ParseResult OperationParser::parseOptionalBlockArgList( 4896 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4897 if (getToken().is(Token::r_brace)) 4898 return success(); 4899 4900 // If the block already has arguments, then we're handling the entry block. 4901 // Parse and register the names for the arguments, but do not add them. 4902 bool definingExistingArgs = owner->getNumArguments() != 0; 4903 unsigned nextArgument = 0; 4904 4905 return parseCommaSeparatedList([&]() -> ParseResult { 4906 return parseSSADefOrUseAndType( 4907 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4908 // If this block did not have existing arguments, define a new one. 4909 if (!definingExistingArgs) 4910 return addDefinition(useInfo, owner->addArgument(type)); 4911 4912 // Otherwise, ensure that this argument has already been created. 4913 if (nextArgument >= owner->getNumArguments()) 4914 return emitError("too many arguments specified in argument list"); 4915 4916 // Finally, make sure the existing argument has the correct type. 4917 auto arg = owner->getArgument(nextArgument++); 4918 if (arg.getType() != type) 4919 return emitError("argument and block argument type mismatch"); 4920 return addDefinition(useInfo, arg); 4921 }); 4922 }); 4923 } 4924 4925 //===----------------------------------------------------------------------===// 4926 // Top-level entity parsing. 4927 //===----------------------------------------------------------------------===// 4928 4929 namespace { 4930 /// This parser handles entities that are only valid at the top level of the 4931 /// file. 4932 class ModuleParser : public Parser { 4933 public: 4934 explicit ModuleParser(ParserState &state) : Parser(state) {} 4935 4936 ParseResult parseModule(ModuleOp module); 4937 4938 private: 4939 /// Parse an attribute alias declaration. 
4940 ParseResult parseAttributeAliasDef(); 4941 4942 /// Parse an attribute alias declaration. 4943 ParseResult parseTypeAliasDef(); 4944 }; 4945 } // end anonymous namespace 4946 4947 /// Parses an attribute alias declaration. 4948 /// 4949 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4950 /// 4951 ParseResult ModuleParser::parseAttributeAliasDef() { 4952 assert(getToken().is(Token::hash_identifier)); 4953 StringRef aliasName = getTokenSpelling().drop_front(); 4954 4955 // Check for redefinitions. 4956 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4957 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4958 4959 // Make sure this isn't invading the dialect attribute namespace. 4960 if (aliasName.contains('.')) 4961 return emitError("attribute names with a '.' are reserved for " 4962 "dialect-defined names"); 4963 4964 consumeToken(Token::hash_identifier); 4965 4966 // Parse the '='. 4967 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4968 return failure(); 4969 4970 // Parse the attribute value. 4971 Attribute attr = parseAttribute(); 4972 if (!attr) 4973 return failure(); 4974 4975 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4976 return success(); 4977 } 4978 4979 /// Parse a type alias declaration. 4980 /// 4981 /// type-alias-def ::= '!' alias-name `=` 'type' type 4982 /// 4983 ParseResult ModuleParser::parseTypeAliasDef() { 4984 assert(getToken().is(Token::exclamation_identifier)); 4985 StringRef aliasName = getTokenSpelling().drop_front(); 4986 4987 // Check for redefinitions. 4988 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4989 return emitError("redefinition of type alias id '" + aliasName + "'"); 4990 4991 // Make sure this isn't invading the dialect type namespace. 4992 if (aliasName.contains('.')) 4993 return emitError("type names with a '.' 
are reserved for " 4994 "dialect-defined names"); 4995 4996 consumeToken(Token::exclamation_identifier); 4997 4998 // Parse the '=' and 'type'. 4999 if (parseToken(Token::equal, "expected '=' in type alias definition") || 5000 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 5001 return failure(); 5002 5003 // Parse the type. 5004 Type aliasedType = parseType(); 5005 if (!aliasedType) 5006 return failure(); 5007 5008 // Register this alias with the parser state. 5009 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 5010 return success(); 5011 } 5012 5013 /// This is the top-level module parser. 5014 ParseResult ModuleParser::parseModule(ModuleOp module) { 5015 OperationParser opParser(getState(), module); 5016 5017 // Module itself is a name scope. 5018 opParser.pushSSANameScope(/*isIsolated=*/true); 5019 5020 while (true) { 5021 switch (getToken().getKind()) { 5022 default: 5023 // Parse a top-level operation. 5024 if (opParser.parseOperation()) 5025 return failure(); 5026 break; 5027 5028 // If we got to the end of the file, then we're done. 5029 case Token::eof: { 5030 if (opParser.finalize()) 5031 return failure(); 5032 5033 // Handle the case where the top level module was explicitly defined. 5034 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 5035 auto &operations = bodyBlocks.front().getOperations(); 5036 assert(!operations.empty() && "expected a valid module terminator"); 5037 5038 // Check that the first operation is a module, and it is the only 5039 // non-terminator operation. 5040 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 5041 if (nested && std::next(operations.begin(), 2) == operations.end()) { 5042 // Merge the data of the nested module operation into 'module'. 
5043 module.setLoc(nested.getLoc()); 5044 module.setAttrs(nested.getOperation()->getMutableAttrDict()); 5045 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 5046 5047 // Erase the original module body. 5048 bodyBlocks.pop_front(); 5049 } 5050 5051 return opParser.popSSANameScope(); 5052 } 5053 5054 // If we got an error token, then the lexer already emitted an error, just 5055 // stop. Someday we could introduce error recovery if there was demand 5056 // for it. 5057 case Token::error: 5058 return failure(); 5059 5060 // Parse an attribute alias. 5061 case Token::hash_identifier: 5062 if (parseAttributeAliasDef()) 5063 return failure(); 5064 break; 5065 5066 // Parse a type alias. 5067 case Token::exclamation_identifier: 5068 if (parseTypeAliasDef()) 5069 return failure(); 5070 break; 5071 } 5072 } 5073 } 5074 5075 //===----------------------------------------------------------------------===// 5076 5077 /// This parses the file specified by the indicated SourceMgr and returns an 5078 /// MLIR module if it was valid. If not, it emits diagnostics and returns 5079 /// null. 5080 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 5081 MLIRContext *context) { 5082 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 5083 5084 // This is the result module we are parsing into. 5085 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 5086 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 5087 5088 SymbolState aliasState; 5089 ParserState state(sourceMgr, context, aliasState); 5090 if (ModuleParser(state).parseModule(*module)) 5091 return nullptr; 5092 5093 // Make sure the parse module has no other structural problems detected by 5094 // the verifier. 5095 if (failed(verify(*module))) 5096 return nullptr; 5097 5098 return module; 5099 } 5100 5101 /// This parses the file specified by the indicated filename and returns an 5102 /// MLIR module if it was valid. 
If not, the error message is emitted through 5103 /// the error handler registered in the context, and a null pointer is returned. 5104 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5105 MLIRContext *context) { 5106 llvm::SourceMgr sourceMgr; 5107 return parseSourceFile(filename, sourceMgr, context); 5108 } 5109 5110 /// This parses the file specified by the indicated filename using the provided 5111 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 5112 /// message is emitted through the error handler registered in the context, and 5113 /// a null pointer is returned. 5114 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5115 llvm::SourceMgr &sourceMgr, 5116 MLIRContext *context) { 5117 if (sourceMgr.getNumBuffers() != 0) { 5118 // TODO(b/136086478): Extend to support multiple buffers. 5119 emitError(mlir::UnknownLoc::get(context), 5120 "only main buffer parsed at the moment"); 5121 return nullptr; 5122 } 5123 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 5124 if (std::error_code error = file_or_err.getError()) { 5125 emitError(mlir::UnknownLoc::get(context), 5126 "could not open input file " + filename); 5127 return nullptr; 5128 } 5129 5130 // Load the MLIR module. 5131 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 5132 return parseSourceFile(sourceMgr, context); 5133 } 5134 5135 /// This parses the program string to a MLIR module if it was valid. If not, 5136 /// it emits diagnostics and returns null. 5137 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 5138 MLIRContext *context) { 5139 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 5140 if (!memBuffer) 5141 return nullptr; 5142 5143 SourceMgr sourceMgr; 5144 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 5145 return parseSourceFile(sourceMgr, context); 5146 } 5147 5148 /// Parses a symbol, of type 'T', and returns it if parsing was successful. 
If 5149 /// parsing failed, nullptr is returned. The number of bytes read from the input 5150 /// string is returned in 'numRead'. 5151 template <typename T, typename ParserFn> 5152 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 5153 ParserFn &&parserFn) { 5154 SymbolState aliasState; 5155 return parseSymbol<T>( 5156 inputStr, context, aliasState, 5157 [&](Parser &parser) { 5158 SourceMgrDiagnosticHandler handler( 5159 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 5160 parser.getContext()); 5161 return parserFn(parser); 5162 }, 5163 &numRead); 5164 } 5165 5166 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 5167 size_t numRead = 0; 5168 return parseAttribute(attrStr, context, numRead); 5169 } 5170 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 5171 size_t numRead = 0; 5172 return parseAttribute(attrStr, type, numRead); 5173 } 5174 5175 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 5176 size_t &numRead) { 5177 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 5178 return parser.parseAttribute(); 5179 }); 5180 } 5181 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 5182 return parseSymbol<Attribute>( 5183 attrStr, type.getContext(), numRead, 5184 [type](Parser &parser) { return parser.parseAttribute(type); }); 5185 } 5186 5187 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 5188 size_t numRead = 0; 5189 return parseType(typeStr, context, numRead); 5190 } 5191 5192 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 5193 return parseSymbol<Type>(typeStr, context, numRead, 5194 [](Parser &parser) { return parser.parseType(); }); 5195 } 5196