1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR textual form. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Parser.h" 14 #include "Lexer.h" 15 #include "mlir/Analysis/Verifier.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/Module.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include "mlir/Support/STLExtras.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/StringExtras.h" 32 #include "llvm/ADT/StringSet.h" 33 #include "llvm/ADT/bit.h" 34 #include "llvm/Support/PrettyStackTrace.h" 35 #include "llvm/Support/SMLoc.h" 36 #include "llvm/Support/SourceMgr.h" 37 #include <algorithm> 38 using namespace mlir; 39 using llvm::MemoryBuffer; 40 using llvm::SMLoc; 41 using llvm::SourceMgr; 42 43 namespace { 44 class Parser; 45 46 //===----------------------------------------------------------------------===// 47 // SymbolState 48 //===----------------------------------------------------------------------===// 49 50 /// This class contains record of any parsed top-level symbols. 51 struct SymbolState { 52 // A map from attribute alias identifier to Attribute. 53 llvm::StringMap<Attribute> attributeAliasDefinitions; 54 55 // A map from type alias identifier to Type. 56 llvm::StringMap<Type> typeAliasDefinitions; 57 58 /// A set of locations into the main parser memory buffer for each of the 59 /// active nested parsers. Given that some nested parsers, i.e. custom dialect 60 /// parsers, operate on a temporary memory buffer, this provides an anchor 61 /// point for emitting diagnostics. 62 SmallVector<llvm::SMLoc, 1> nestedParserLocs; 63 64 /// The top-level lexer that contains the original memory buffer provided by 65 /// the user. This is used by nested parsers to get a properly encoded source 66 /// location. 67 Lexer *topLevelLexer = nullptr; 68 }; 69 70 //===----------------------------------------------------------------------===// 71 // ParserState 72 //===----------------------------------------------------------------------===// 73 74 /// This class refers to all of the state maintained globally by the parser, 75 /// such as the current lexer position etc. 76 struct ParserState { 77 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, 78 SymbolState &symbols) 79 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), 80 symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { 81 // Set the top level lexer for the symbol state if one doesn't exist. 82 if (!symbols.topLevelLexer) 83 symbols.topLevelLexer = &lex; 84 } 85 ~ParserState() { 86 // Reset the top level lexer if it refers the lexer in our state. 87 if (symbols.topLevelLexer == &lex) 88 symbols.topLevelLexer = nullptr; 89 } 90 ParserState(const ParserState &) = delete; 91 void operator=(const ParserState &) = delete; 92 93 /// The context we're parsing into. 94 MLIRContext *const context; 95 96 /// The lexer for the source file we're parsing. 97 Lexer lex; 98 99 /// This is the next token that hasn't been consumed yet. 100 Token curToken; 101 102 /// The current state for symbol parsing. 103 SymbolState &symbols; 104 105 /// The depth of this parser in the nested parsing stack. 106 size_t parserDepth; 107 }; 108 109 //===----------------------------------------------------------------------===// 110 // Parser 111 //===----------------------------------------------------------------------===// 112 113 /// This class implement support for parsing global entities like types and 114 /// shared entities like SSA names. It is intended to be subclassed by 115 /// specialized subparsers that include state, e.g. when a local symbol table. 116 class Parser { 117 public: 118 Builder builder; 119 120 Parser(ParserState &state) : builder(state.context), state(state) {} 121 122 // Helper methods to get stuff from the parser-global state. 123 ParserState &getState() const { return state; } 124 MLIRContext *getContext() const { return state.context; } 125 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 126 127 /// Parse a comma-separated list of elements up until the specified end token. 128 ParseResult 129 parseCommaSeparatedListUntil(Token::Kind rightToken, 130 const std::function<ParseResult()> &parseElement, 131 bool allowEmptyList = true); 132 133 /// Parse a comma separated list of elements that must have at least one entry 134 /// in it. 135 ParseResult 136 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 137 138 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 139 140 // We have two forms of parsing methods - those that return a non-null 141 // pointer on success, and those that return a ParseResult to indicate whether 142 // they returned a failure. The second class fills in by-reference arguments 143 // as the results of their action. 144 145 //===--------------------------------------------------------------------===// 146 // Error Handling 147 //===--------------------------------------------------------------------===// 148 149 /// Emit an error and return failure. 150 InFlightDiagnostic emitError(const Twine &message = {}) { 151 return emitError(state.curToken.getLoc(), message); 152 } 153 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 154 155 /// Encode the specified source location information into an attribute for 156 /// attachment to the IR. 157 Location getEncodedSourceLocation(llvm::SMLoc loc) { 158 // If there are no active nested parsers, we can get the encoded source 159 // location directly. 160 if (state.parserDepth == 0) 161 return state.lex.getEncodedSourceLocation(loc); 162 // Otherwise, we need to re-encode it to point to the top level buffer. 163 return state.symbols.topLevelLexer->getEncodedSourceLocation( 164 remapLocationToTopLevelBuffer(loc)); 165 } 166 167 /// Remaps the given SMLoc to the top level lexer of the parser. This is used 168 /// to adjust locations of potentially nested parsers to ensure that they can 169 /// be emitted properly as diagnostics. 170 llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { 171 // If there are no active nested parsers, we can return location directly. 172 SymbolState &symbols = state.symbols; 173 if (state.parserDepth == 0) 174 return loc; 175 assert(symbols.topLevelLexer && "expected valid top-level lexer"); 176 177 // Otherwise, we need to remap the location to the main parser. This is 178 // simply offseting the location onto the location of the last nested 179 // parser. 180 size_t offset = loc.getPointer() - state.lex.getBufferBegin(); 181 auto *rawLoc = 182 symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; 183 return llvm::SMLoc::getFromPointer(rawLoc); 184 } 185 186 //===--------------------------------------------------------------------===// 187 // Token Parsing 188 //===--------------------------------------------------------------------===// 189 190 /// Return the current token the parser is inspecting. 191 const Token &getToken() const { return state.curToken; } 192 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 193 194 /// If the current token has the specified kind, consume it and return true. 195 /// If not, return false. 196 bool consumeIf(Token::Kind kind) { 197 if (state.curToken.isNot(kind)) 198 return false; 199 consumeToken(kind); 200 return true; 201 } 202 203 /// Advance the current lexer onto the next token. 204 void consumeToken() { 205 assert(state.curToken.isNot(Token::eof, Token::error) && 206 "shouldn't advance past EOF or errors"); 207 state.curToken = state.lex.lexToken(); 208 } 209 210 /// Advance the current lexer onto the next token, asserting what the expected 211 /// current token is. This is preferred to the above method because it leads 212 /// to more self-documenting code with better checking. 213 void consumeToken(Token::Kind kind) { 214 assert(state.curToken.is(kind) && "consumed an unexpected token"); 215 consumeToken(); 216 } 217 218 /// Consume the specified token if present and return success. On failure, 219 /// output a diagnostic and return failure. 220 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 221 222 //===--------------------------------------------------------------------===// 223 // Type Parsing 224 //===--------------------------------------------------------------------===// 225 226 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 227 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 228 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 229 230 /// Parse an arbitrary type. 231 Type parseType(); 232 233 /// Parse a complex type. 234 Type parseComplexType(); 235 236 /// Parse an extended type. 237 Type parseExtendedType(); 238 239 /// Parse a function type. 240 Type parseFunctionType(); 241 242 /// Parse a memref type. 243 Type parseMemRefType(); 244 245 /// Parse a non function type. 246 Type parseNonFunctionType(); 247 248 /// Parse a tensor type. 249 Type parseTensorType(); 250 251 /// Parse a tuple type. 252 Type parseTupleType(); 253 254 /// Parse a vector type. 255 VectorType parseVectorType(); 256 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 257 bool allowDynamic = true); 258 ParseResult parseXInDimensionList(); 259 260 /// Parse strided layout specification. 261 ParseResult parseStridedLayout(int64_t &offset, 262 SmallVectorImpl<int64_t> &strides); 263 264 // Parse a brace-delimiter list of comma-separated integers with `?` as an 265 // unknown marker. 266 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 267 268 //===--------------------------------------------------------------------===// 269 // Attribute Parsing 270 //===--------------------------------------------------------------------===// 271 272 /// Parse an arbitrary attribute with an optional type. 273 Attribute parseAttribute(Type type = {}); 274 275 /// Parse an attribute dictionary. 276 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 277 278 /// Parse an extended attribute. 279 Attribute parseExtendedAttr(Type type); 280 281 /// Parse a float attribute. 282 Attribute parseFloatAttr(Type type, bool isNegative); 283 284 /// Parse a decimal or a hexadecimal literal, which can be either an integer 285 /// or a float attribute. 286 Attribute parseDecOrHexAttr(Type type, bool isNegative); 287 288 /// Parse an opaque elements attribute. 289 Attribute parseOpaqueElementsAttr(Type attrType); 290 291 /// Parse a dense elements attribute. 292 Attribute parseDenseElementsAttr(Type attrType); 293 ShapedType parseElementsLiteralType(Type type); 294 295 /// Parse a sparse elements attribute. 296 Attribute parseSparseElementsAttr(Type attrType); 297 298 //===--------------------------------------------------------------------===// 299 // Location Parsing 300 //===--------------------------------------------------------------------===// 301 302 /// Parse an inline location. 303 ParseResult parseLocation(LocationAttr &loc); 304 305 /// Parse a raw location instance. 306 ParseResult parseLocationInstance(LocationAttr &loc); 307 308 /// Parse a callsite location instance. 309 ParseResult parseCallSiteLocation(LocationAttr &loc); 310 311 /// Parse a fused location instance. 312 ParseResult parseFusedLocation(LocationAttr &loc); 313 314 /// Parse a name or FileLineCol location instance. 315 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 316 317 /// Parse an optional trailing location. 318 /// 319 /// trailing-location ::= (`loc` `(` location `)`)? 320 /// 321 ParseResult parseOptionalTrailingLocation(Location &loc) { 322 // If there is a 'loc' we parse a trailing location. 323 if (!getToken().is(Token::kw_loc)) 324 return success(); 325 326 // Parse the location. 327 LocationAttr directLoc; 328 if (parseLocation(directLoc)) 329 return failure(); 330 loc = directLoc; 331 return success(); 332 } 333 334 //===--------------------------------------------------------------------===// 335 // Affine Parsing 336 //===--------------------------------------------------------------------===// 337 338 /// Parse a reference to either an affine map, or an integer set. 339 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 340 IntegerSet &set); 341 ParseResult parseAffineMapReference(AffineMap &map); 342 ParseResult parseIntegerSetReference(IntegerSet &set); 343 344 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 345 ParseResult 346 parseAffineMapOfSSAIds(AffineMap &map, 347 function_ref<ParseResult(bool)> parseElement, 348 OpAsmParser::Delimiter delimiter); 349 350 private: 351 /// The Parser is subclassed and reinstantiated. Do not add additional 352 /// non-trivial state here, add it to the ParserState class. 353 ParserState &state; 354 }; 355 } // end anonymous namespace 356 357 //===----------------------------------------------------------------------===// 358 // Helper methods. 359 //===----------------------------------------------------------------------===// 360 361 /// Parse a comma separated list of elements that must have at least one entry 362 /// in it. 363 ParseResult Parser::parseCommaSeparatedList( 364 const std::function<ParseResult()> &parseElement) { 365 // Non-empty case starts with an element. 366 if (parseElement()) 367 return failure(); 368 369 // Otherwise we have a list of comma separated elements. 370 while (consumeIf(Token::comma)) { 371 if (parseElement()) 372 return failure(); 373 } 374 return success(); 375 } 376 377 /// Parse a comma-separated list of elements, terminated with an arbitrary 378 /// token. This allows empty lists if allowEmptyList is true. 379 /// 380 /// abstract-list ::= rightToken // if allowEmptyList == true 381 /// abstract-list ::= element (',' element)* rightToken 382 /// 383 ParseResult Parser::parseCommaSeparatedListUntil( 384 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 385 bool allowEmptyList) { 386 // Handle the empty case. 387 if (getToken().is(rightToken)) { 388 if (!allowEmptyList) 389 return emitError("expected list element"); 390 consumeToken(rightToken); 391 return success(); 392 } 393 394 if (parseCommaSeparatedList(parseElement) || 395 parseToken(rightToken, "expected ',' or '" + 396 Token::getTokenSpelling(rightToken) + "'")) 397 return failure(); 398 399 return success(); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // DialectAsmParser 404 //===----------------------------------------------------------------------===// 405 406 namespace { 407 /// This class provides the main implementation of the DialectAsmParser that 408 /// allows for dialects to parse attributes and types. This allows for dialect 409 /// hooking into the main MLIR parsing logic. 410 class CustomDialectAsmParser : public DialectAsmParser { 411 public: 412 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 413 : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), 414 parser(parser) {} 415 ~CustomDialectAsmParser() override {} 416 417 /// Emit a diagnostic at the specified location and return failure. 418 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 419 return parser.emitError(loc, message); 420 } 421 422 /// Return a builder which provides useful access to MLIRContext, global 423 /// objects like types and attributes. 424 Builder &getBuilder() const override { return parser.builder; } 425 426 /// Get the location of the next token and store it into the argument. This 427 /// always succeeds. 428 llvm::SMLoc getCurrentLocation() override { 429 return parser.getToken().getLoc(); 430 } 431 432 /// Return the location of the original name token. 433 llvm::SMLoc getNameLoc() const override { return nameLoc; } 434 435 /// Re-encode the given source location as an MLIR location and return it. 436 Location getEncodedSourceLoc(llvm::SMLoc loc) override { 437 return parser.getEncodedSourceLocation(loc); 438 } 439 440 /// Returns the full specification of the symbol being parsed. This allows 441 /// for using a separate parser if necessary. 442 StringRef getFullSymbolSpec() const override { return fullSpec; } 443 444 /// Parse a floating point value from the stream. 445 ParseResult parseFloat(double &result) override { 446 bool negative = parser.consumeIf(Token::minus); 447 Token curTok = parser.getToken(); 448 449 // Check for a floating point value. 450 if (curTok.is(Token::floatliteral)) { 451 auto val = curTok.getFloatingPointValue(); 452 if (!val.hasValue()) 453 return emitError(curTok.getLoc(), "floating point value too large"); 454 parser.consumeToken(Token::floatliteral); 455 result = negative ? -*val : *val; 456 return success(); 457 } 458 459 // TODO(riverriddle) support hex floating point values. 460 return emitError(getCurrentLocation(), "expected floating point literal"); 461 } 462 463 /// Parse an optional integer value from the stream. 464 OptionalParseResult parseOptionalInteger(uint64_t &result) override { 465 Token curToken = parser.getToken(); 466 if (curToken.isNot(Token::integer, Token::minus)) 467 return llvm::None; 468 469 bool negative = parser.consumeIf(Token::minus); 470 Token curTok = parser.getToken(); 471 if (parser.parseToken(Token::integer, "expected integer value")) 472 return failure(); 473 474 auto val = curTok.getUInt64IntegerValue(); 475 if (!val) 476 return emitError(curTok.getLoc(), "integer value too large"); 477 result = negative ? -*val : *val; 478 return success(); 479 } 480 481 //===--------------------------------------------------------------------===// 482 // Token Parsing 483 //===--------------------------------------------------------------------===// 484 485 /// Parse a `->` token. 486 ParseResult parseArrow() override { 487 return parser.parseToken(Token::arrow, "expected '->'"); 488 } 489 490 /// Parses a `->` if present. 491 ParseResult parseOptionalArrow() override { 492 return success(parser.consumeIf(Token::arrow)); 493 } 494 495 /// Parse a '{' token. 496 ParseResult parseLBrace() override { 497 return parser.parseToken(Token::l_brace, "expected '{'"); 498 } 499 500 /// Parse a '{' token if present 501 ParseResult parseOptionalLBrace() override { 502 return success(parser.consumeIf(Token::l_brace)); 503 } 504 505 /// Parse a `}` token. 506 ParseResult parseRBrace() override { 507 return parser.parseToken(Token::r_brace, "expected '}'"); 508 } 509 510 /// Parse a `}` token if present 511 ParseResult parseOptionalRBrace() override { 512 return success(parser.consumeIf(Token::r_brace)); 513 } 514 515 /// Parse a `:` token. 516 ParseResult parseColon() override { 517 return parser.parseToken(Token::colon, "expected ':'"); 518 } 519 520 /// Parse a `:` token if present. 521 ParseResult parseOptionalColon() override { 522 return success(parser.consumeIf(Token::colon)); 523 } 524 525 /// Parse a `,` token. 526 ParseResult parseComma() override { 527 return parser.parseToken(Token::comma, "expected ','"); 528 } 529 530 /// Parse a `,` token if present. 531 ParseResult parseOptionalComma() override { 532 return success(parser.consumeIf(Token::comma)); 533 } 534 535 /// Parses a `...` if present. 536 ParseResult parseOptionalEllipsis() override { 537 return success(parser.consumeIf(Token::ellipsis)); 538 } 539 540 /// Parse a `=` token. 541 ParseResult parseEqual() override { 542 return parser.parseToken(Token::equal, "expected '='"); 543 } 544 545 /// Parse a '<' token. 546 ParseResult parseLess() override { 547 return parser.parseToken(Token::less, "expected '<'"); 548 } 549 550 /// Parse a `<` token if present. 551 ParseResult parseOptionalLess() override { 552 return success(parser.consumeIf(Token::less)); 553 } 554 555 /// Parse a '>' token. 556 ParseResult parseGreater() override { 557 return parser.parseToken(Token::greater, "expected '>'"); 558 } 559 560 /// Parse a `>` token if present. 561 ParseResult parseOptionalGreater() override { 562 return success(parser.consumeIf(Token::greater)); 563 } 564 565 /// Parse a `(` token. 566 ParseResult parseLParen() override { 567 return parser.parseToken(Token::l_paren, "expected '('"); 568 } 569 570 /// Parses a '(' if present. 571 ParseResult parseOptionalLParen() override { 572 return success(parser.consumeIf(Token::l_paren)); 573 } 574 575 /// Parse a `)` token. 576 ParseResult parseRParen() override { 577 return parser.parseToken(Token::r_paren, "expected ')'"); 578 } 579 580 /// Parses a ')' if present. 581 ParseResult parseOptionalRParen() override { 582 return success(parser.consumeIf(Token::r_paren)); 583 } 584 585 /// Parse a `[` token. 586 ParseResult parseLSquare() override { 587 return parser.parseToken(Token::l_square, "expected '['"); 588 } 589 590 /// Parses a '[' if present. 591 ParseResult parseOptionalLSquare() override { 592 return success(parser.consumeIf(Token::l_square)); 593 } 594 595 /// Parse a `]` token. 596 ParseResult parseRSquare() override { 597 return parser.parseToken(Token::r_square, "expected ']'"); 598 } 599 600 /// Parses a ']' if present. 601 ParseResult parseOptionalRSquare() override { 602 return success(parser.consumeIf(Token::r_square)); 603 } 604 605 /// Parses a '?' if present. 606 ParseResult parseOptionalQuestion() override { 607 return success(parser.consumeIf(Token::question)); 608 } 609 610 /// Parses a '*' if present. 611 ParseResult parseOptionalStar() override { 612 return success(parser.consumeIf(Token::star)); 613 } 614 615 /// Returns if the current token corresponds to a keyword. 616 bool isCurrentTokenAKeyword() const { 617 return parser.getToken().is(Token::bare_identifier) || 618 parser.getToken().isKeyword(); 619 } 620 621 /// Parse the given keyword if present. 622 ParseResult parseOptionalKeyword(StringRef keyword) override { 623 // Check that the current token has the same spelling. 624 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 625 return failure(); 626 parser.consumeToken(); 627 return success(); 628 } 629 630 /// Parse a keyword, if present, into 'keyword'. 631 ParseResult parseOptionalKeyword(StringRef *keyword) override { 632 // Check that the current token is a keyword. 633 if (!isCurrentTokenAKeyword()) 634 return failure(); 635 636 *keyword = parser.getTokenSpelling(); 637 parser.consumeToken(); 638 return success(); 639 } 640 641 //===--------------------------------------------------------------------===// 642 // Attribute Parsing 643 //===--------------------------------------------------------------------===// 644 645 /// Parse an arbitrary attribute and return it in result. 646 ParseResult parseAttribute(Attribute &result, Type type) override { 647 result = parser.parseAttribute(type); 648 return success(static_cast<bool>(result)); 649 } 650 651 /// Parse an affine map instance into 'map'. 652 ParseResult parseAffineMap(AffineMap &map) override { 653 return parser.parseAffineMapReference(map); 654 } 655 656 /// Parse an integer set instance into 'set'. 657 ParseResult printIntegerSet(IntegerSet &set) override { 658 return parser.parseIntegerSetReference(set); 659 } 660 661 //===--------------------------------------------------------------------===// 662 // Type Parsing 663 //===--------------------------------------------------------------------===// 664 665 ParseResult parseType(Type &result) override { 666 result = parser.parseType(); 667 return success(static_cast<bool>(result)); 668 } 669 670 ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 671 bool allowDynamic) override { 672 return parser.parseDimensionListRanked(dimensions, allowDynamic); 673 } 674 675 private: 676 /// The full symbol specification. 677 StringRef fullSpec; 678 679 /// The source location of the dialect symbol. 680 SMLoc nameLoc; 681 682 /// The main parser. 683 Parser &parser; 684 }; 685 } // namespace 686 687 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 688 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 689 /// the entire pretty name. 690 /// 691 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 692 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 693 /// | '(' pretty-dialect-sym-contents+ ')' 694 /// | '[' pretty-dialect-sym-contents+ ']' 695 /// | '{' pretty-dialect-sym-contents+ '}' 696 /// | '[^[<({>\])}\0]+' 697 /// 698 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 699 // Pretty symbol names are a relatively unstructured format that contains a 700 // series of properly nested punctuation, with anything else in the middle. 701 // Scan ahead to find it and consume it if successful, otherwise emit an 702 // error. 703 auto *curPtr = getTokenSpelling().data(); 704 705 SmallVector<char, 8> nestedPunctuation; 706 707 // Scan over the nested punctuation, bailing out on error and consuming until 708 // we find the end. We know that we're currently looking at the '<', so we 709 // can go until we find the matching '>' character. 710 assert(*curPtr == '<'); 711 do { 712 char c = *curPtr++; 713 switch (c) { 714 case '\0': 715 // This also handles the EOF case. 716 return emitError("unexpected nul or EOF in pretty dialect name"); 717 case '<': 718 case '[': 719 case '(': 720 case '{': 721 nestedPunctuation.push_back(c); 722 continue; 723 724 case '-': 725 // The sequence `->` is treated as special token. 726 if (*curPtr == '>') 727 ++curPtr; 728 continue; 729 730 case '>': 731 if (nestedPunctuation.pop_back_val() != '<') 732 return emitError("unbalanced '>' character in pretty dialect name"); 733 break; 734 case ']': 735 if (nestedPunctuation.pop_back_val() != '[') 736 return emitError("unbalanced ']' character in pretty dialect name"); 737 break; 738 case ')': 739 if (nestedPunctuation.pop_back_val() != '(') 740 return emitError("unbalanced ')' character in pretty dialect name"); 741 break; 742 case '}': 743 if (nestedPunctuation.pop_back_val() != '{') 744 return emitError("unbalanced '}' character in pretty dialect name"); 745 break; 746 747 default: 748 continue; 749 } 750 } while (!nestedPunctuation.empty()); 751 752 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 753 // consuming all this stuff, and return. 754 state.lex.resetPointer(curPtr); 755 756 unsigned length = curPtr - prettyName.begin(); 757 prettyName = StringRef(prettyName.begin(), length); 758 consumeToken(); 759 return success(); 760 } 761 762 /// Parse an extended dialect symbol. 763 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 764 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 765 SymbolAliasMap &aliases, 766 CreateFn &&createSymbol) { 767 // Parse the dialect namespace. 768 StringRef identifier = p.getTokenSpelling().drop_front(); 769 auto loc = p.getToken().getLoc(); 770 p.consumeToken(identifierTok); 771 772 // If there is no '<' token following this, and if the typename contains no 773 // dot, then we are parsing a symbol alias. 774 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 775 // Check for an alias for this type. 776 auto aliasIt = aliases.find(identifier); 777 if (aliasIt == aliases.end()) 778 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 779 nullptr); 780 return aliasIt->second; 781 } 782 783 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 784 // a dot, then this is the "pretty" form. If not, it is the verbose form that 785 // looks like <"...">. 786 std::string symbolData; 787 auto dialectName = identifier; 788 789 // Handle the verbose form, where "identifier" is a simple dialect name. 790 if (!identifier.contains('.')) { 791 // Consume the '<'. 792 if (p.parseToken(Token::less, "expected '<' in dialect type")) 793 return nullptr; 794 795 // Parse the symbol specific data. 796 if (p.getToken().isNot(Token::string)) 797 return (p.emitError("expected string literal data in dialect symbol"), 798 nullptr); 799 symbolData = p.getToken().getStringValue(); 800 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); 801 p.consumeToken(Token::string); 802 803 // Consume the '>'. 804 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 805 return nullptr; 806 } else { 807 // Ok, the dialect name is the part of the identifier before the dot, the 808 // part after the dot is the dialect's symbol, or the start thereof. 809 auto dotHalves = identifier.split('.'); 810 dialectName = dotHalves.first; 811 auto prettyName = dotHalves.second; 812 loc = llvm::SMLoc::getFromPointer(prettyName.data()); 813 814 // If the dialect's symbol is followed immediately by a <, then lex the body 815 // of it into prettyName. 816 if (p.getToken().is(Token::less) && 817 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 818 if (p.parsePrettyDialectSymbolName(prettyName)) 819 return nullptr; 820 } 821 822 symbolData = prettyName.str(); 823 } 824 825 // Record the name location of the type remapped to the top level buffer. 826 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); 827 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); 828 829 // Call into the provided symbol construction function. 830 Symbol sym = createSymbol(dialectName, symbolData, loc); 831 832 // Pop the last parser location. 833 p.getState().symbols.nestedParserLocs.pop_back(); 834 return sym; 835 } 836 837 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 838 /// parsing failed, nullptr is returned. The number of bytes read from the input 839 /// string is returned in 'numRead'. 840 template <typename T, typename ParserFn> 841 static T parseSymbol(StringRef inputStr, MLIRContext *context, 842 SymbolState &symbolState, ParserFn &&parserFn, 843 size_t *numRead = nullptr) { 844 SourceMgr sourceMgr; 845 auto memBuffer = MemoryBuffer::getMemBuffer( 846 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 847 /*RequiresNullTerminator=*/false); 848 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 849 ParserState state(sourceMgr, context, symbolState); 850 Parser parser(state); 851 852 Token startTok = parser.getToken(); 853 T symbol = parserFn(parser); 854 if (!symbol) 855 return T(); 856 857 // If 'numRead' is valid, then provide the number of bytes that were read. 858 Token endTok = parser.getToken(); 859 if (numRead) { 860 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 861 startTok.getLoc().getPointer()); 862 863 // Otherwise, ensure that all of the tokens were parsed. 864 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { 865 parser.emitError(endTok.getLoc(), "encountered unexpected token"); 866 return T(); 867 } 868 return symbol; 869 } 870 871 //===----------------------------------------------------------------------===// 872 // Error Handling 873 //===----------------------------------------------------------------------===// 874 875 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 876 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 877 878 // If we hit a parse error in response to a lexer error, then the lexer 879 // already reported the error. 880 if (getToken().is(Token::error)) 881 diag.abandon(); 882 return diag; 883 } 884 885 //===----------------------------------------------------------------------===// 886 // Token Parsing 887 //===----------------------------------------------------------------------===// 888 889 /// Consume the specified token if present and return success. On failure, 890 /// output a diagnostic and return failure. 891 ParseResult Parser::parseToken(Token::Kind expectedToken, 892 const Twine &message) { 893 if (consumeIf(expectedToken)) 894 return success(); 895 return emitError(message); 896 } 897 898 //===----------------------------------------------------------------------===// 899 // Type Parsing 900 //===----------------------------------------------------------------------===// 901 902 /// Parse an arbitrary type. 903 /// 904 /// type ::= function-type 905 /// | non-function-type 906 /// 907 Type Parser::parseType() { 908 if (getToken().is(Token::l_paren)) 909 return parseFunctionType(); 910 return parseNonFunctionType(); 911 } 912 913 /// Parse a function result type. 914 /// 915 /// function-result-type ::= type-list-parens 916 /// | non-function-type 917 /// 918 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 919 if (getToken().is(Token::l_paren)) 920 return parseTypeListParens(elements); 921 922 Type t = parseNonFunctionType(); 923 if (!t) 924 return failure(); 925 elements.push_back(t); 926 return success(); 927 } 928 929 /// Parse a list of types without an enclosing parenthesis. The list must have 930 /// at least one member. 931 /// 932 /// type-list-no-parens ::= type (`,` type)* 933 /// 934 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 935 auto parseElt = [&]() -> ParseResult { 936 auto elt = parseType(); 937 elements.push_back(elt); 938 return elt ? success() : failure(); 939 }; 940 941 return parseCommaSeparatedList(parseElt); 942 } 943 944 /// Parse a parenthesized list of types. 945 /// 946 /// type-list-parens ::= `(` `)` 947 /// | `(` type-list-no-parens `)` 948 /// 949 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 950 if (parseToken(Token::l_paren, "expected '('")) 951 return failure(); 952 953 // Handle empty lists. 954 if (getToken().is(Token::r_paren)) 955 return consumeToken(), success(); 956 957 if (parseTypeListNoParens(elements) || 958 parseToken(Token::r_paren, "expected ')'")) 959 return failure(); 960 return success(); 961 } 962 963 /// Parse a complex type. 964 /// 965 /// complex-type ::= `complex` `<` type `>` 966 /// 967 Type Parser::parseComplexType() { 968 consumeToken(Token::kw_complex); 969 970 // Parse the '<'. 971 if (parseToken(Token::less, "expected '<' in complex type")) 972 return nullptr; 973 974 llvm::SMLoc elementTypeLoc = getToken().getLoc(); 975 auto elementType = parseType(); 976 if (!elementType || 977 parseToken(Token::greater, "expected '>' in complex type")) 978 return nullptr; 979 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) 980 return emitError(elementTypeLoc, "invalid element type for complex"), 981 nullptr; 982 983 return ComplexType::get(elementType); 984 } 985 986 /// Parse an extended type. 987 /// 988 /// extended-type ::= (dialect-type | type-alias) 989 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 990 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 991 /// type-alias ::= `!` alias-name 992 /// 993 Type Parser::parseExtendedType() { 994 return parseExtendedSymbol<Type>( 995 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 996 [&](StringRef dialectName, StringRef symbolData, 997 llvm::SMLoc loc) -> Type { 998 // If we found a registered dialect, then ask it to parse the type. 999 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1000 return parseSymbol<Type>( 1001 symbolData, state.context, state.symbols, [&](Parser &parser) { 1002 CustomDialectAsmParser customParser(symbolData, parser); 1003 return dialect->parseType(customParser); 1004 }); 1005 } 1006 1007 // Otherwise, form a new opaque type. 1008 return OpaqueType::getChecked( 1009 Identifier::get(dialectName, state.context), symbolData, 1010 state.context, getEncodedSourceLocation(loc)); 1011 }); 1012 } 1013 1014 /// Parse a function type. 1015 /// 1016 /// function-type ::= type-list-parens `->` function-result-type 1017 /// 1018 Type Parser::parseFunctionType() { 1019 assert(getToken().is(Token::l_paren)); 1020 1021 SmallVector<Type, 4> arguments, results; 1022 if (parseTypeListParens(arguments) || 1023 parseToken(Token::arrow, "expected '->' in function type") || 1024 parseFunctionResultTypes(results)) 1025 return nullptr; 1026 1027 return builder.getFunctionType(arguments, results); 1028 } 1029 1030 /// Parse the offset and strides from a strided layout specification. 1031 /// 1032 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1033 /// 1034 ParseResult Parser::parseStridedLayout(int64_t &offset, 1035 SmallVectorImpl<int64_t> &strides) { 1036 // Parse offset. 1037 consumeToken(Token::kw_offset); 1038 if (!consumeIf(Token::colon)) 1039 return emitError("expected colon after `offset` keyword"); 1040 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1041 bool question = getToken().is(Token::question); 1042 if (!maybeOffset && !question) 1043 return emitError("invalid offset"); 1044 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1045 : MemRefType::getDynamicStrideOrOffset(); 1046 consumeToken(); 1047 1048 if (!consumeIf(Token::comma)) 1049 return emitError("expected comma after offset value"); 1050 1051 // Parse stride list. 1052 if (!consumeIf(Token::kw_strides)) 1053 return emitError("expected `strides` keyword after offset specification"); 1054 if (!consumeIf(Token::colon)) 1055 return emitError("expected colon after `strides` keyword"); 1056 if (failed(parseStrideList(strides))) 1057 return emitError("invalid braces-enclosed stride list"); 1058 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1059 return emitError("invalid memref stride"); 1060 1061 return success(); 1062 } 1063 1064 /// Parse a memref type. 1065 /// 1066 /// memref-type ::= ranked-memref-type | unranked-memref-type 1067 /// 1068 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1069 /// (`,` semi-affine-map-composition)? (`,` 1070 /// memory-space)? `>` 1071 /// 1072 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1073 /// 1074 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1075 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1076 /// 1077 Type Parser::parseMemRefType() { 1078 consumeToken(Token::kw_memref); 1079 1080 if (parseToken(Token::less, "expected '<' in memref type")) 1081 return nullptr; 1082 1083 bool isUnranked; 1084 SmallVector<int64_t, 4> dimensions; 1085 1086 if (consumeIf(Token::star)) { 1087 // This is an unranked memref type. 1088 isUnranked = true; 1089 if (parseXInDimensionList()) 1090 return nullptr; 1091 1092 } else { 1093 isUnranked = false; 1094 if (parseDimensionListRanked(dimensions)) 1095 return nullptr; 1096 } 1097 1098 // Parse the element type. 1099 auto typeLoc = getToken().getLoc(); 1100 auto elementType = parseType(); 1101 if (!elementType) 1102 return nullptr; 1103 1104 // Check that memref is formed from allowed types. 1105 if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && 1106 !elementType.isa<ComplexType>()) 1107 return emitError(typeLoc, "invalid memref element type"), nullptr; 1108 1109 // Parse semi-affine-map-composition. 1110 SmallVector<AffineMap, 2> affineMapComposition; 1111 Optional<unsigned> memorySpace; 1112 unsigned numDims = dimensions.size(); 1113 1114 auto parseElt = [&]() -> ParseResult { 1115 // Check for the memory space. 1116 if (getToken().is(Token::integer)) { 1117 if (memorySpace) 1118 return emitError("multiple memory spaces specified in memref type"); 1119 memorySpace = getToken().getUnsignedIntegerValue(); 1120 if (!memorySpace.hasValue()) 1121 return emitError("invalid memory space in memref type"); 1122 consumeToken(Token::integer); 1123 return success(); 1124 } 1125 if (isUnranked) 1126 return emitError("cannot have affine map for unranked memref type"); 1127 if (memorySpace) 1128 return emitError("expected memory space to be last in memref type"); 1129 1130 AffineMap map; 1131 llvm::SMLoc mapLoc = getToken().getLoc(); 1132 if (getToken().is(Token::kw_offset)) { 1133 int64_t offset; 1134 SmallVector<int64_t, 4> strides; 1135 if (failed(parseStridedLayout(offset, strides))) 1136 return failure(); 1137 // Construct strided affine map. 1138 map = makeStridedLinearLayoutMap(strides, offset, state.context); 1139 } else { 1140 // Parse an affine map attribute. 1141 auto affineMap = parseAttribute(); 1142 if (!affineMap) 1143 return failure(); 1144 auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>(); 1145 if (!affineMapAttr) 1146 return emitError("expected affine map in memref type"); 1147 map = affineMapAttr.getValue(); 1148 } 1149 1150 if (map.getNumDims() != numDims) { 1151 size_t i = affineMapComposition.size(); 1152 return emitError(mapLoc, "memref affine map dimension mismatch between ") 1153 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 1154 << " and affine map" << i + 1 << ": " << numDims 1155 << " != " << map.getNumDims(); 1156 } 1157 numDims = map.getNumResults(); 1158 affineMapComposition.push_back(map); 1159 return success(); 1160 }; 1161 1162 // Parse a list of mappings and address space if present. 1163 if (!consumeIf(Token::greater)) { 1164 // Parse comma separated list of affine maps, followed by memory space. 1165 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 1166 parseCommaSeparatedListUntil(Token::greater, parseElt, 1167 /*allowEmptyList=*/false)) { 1168 return nullptr; 1169 } 1170 } 1171 1172 if (isUnranked) 1173 return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); 1174 1175 return MemRefType::get(dimensions, elementType, affineMapComposition, 1176 memorySpace.getValueOr(0)); 1177 } 1178 1179 /// Parse any type except the function type. 1180 /// 1181 /// non-function-type ::= integer-type 1182 /// | index-type 1183 /// | float-type 1184 /// | extended-type 1185 /// | vector-type 1186 /// | tensor-type 1187 /// | memref-type 1188 /// | complex-type 1189 /// | tuple-type 1190 /// | none-type 1191 /// 1192 /// index-type ::= `index` 1193 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1194 /// none-type ::= `none` 1195 /// 1196 Type Parser::parseNonFunctionType() { 1197 switch (getToken().getKind()) { 1198 default: 1199 return (emitError("expected non-function type"), nullptr); 1200 case Token::kw_memref: 1201 return parseMemRefType(); 1202 case Token::kw_tensor: 1203 return parseTensorType(); 1204 case Token::kw_complex: 1205 return parseComplexType(); 1206 case Token::kw_tuple: 1207 return parseTupleType(); 1208 case Token::kw_vector: 1209 return parseVectorType(); 1210 // integer-type 1211 case Token::inttype: { 1212 auto width = getToken().getIntTypeBitwidth(); 1213 if (!width.hasValue()) 1214 return (emitError("invalid integer width"), nullptr); 1215 if (width.getValue() > IntegerType::kMaxWidth) { 1216 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 1217 << IntegerType::kMaxWidth << " bits"; 1218 return nullptr; 1219 } 1220 1221 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 1222 if (Optional<bool> signedness = getToken().getIntTypeSignedness()) 1223 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; 1224 1225 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1226 consumeToken(Token::inttype); 1227 return IntegerType::getChecked(width.getValue(), signSemantics, loc); 1228 } 1229 1230 // float-type 1231 case Token::kw_bf16: 1232 consumeToken(Token::kw_bf16); 1233 return builder.getBF16Type(); 1234 case Token::kw_f16: 1235 consumeToken(Token::kw_f16); 1236 return builder.getF16Type(); 1237 case Token::kw_f32: 1238 consumeToken(Token::kw_f32); 1239 return builder.getF32Type(); 1240 case Token::kw_f64: 1241 consumeToken(Token::kw_f64); 1242 return builder.getF64Type(); 1243 1244 // index-type 1245 case Token::kw_index: 1246 consumeToken(Token::kw_index); 1247 return builder.getIndexType(); 1248 1249 // none-type 1250 case Token::kw_none: 1251 consumeToken(Token::kw_none); 1252 return builder.getNoneType(); 1253 1254 // extended type 1255 case Token::exclamation_identifier: 1256 return parseExtendedType(); 1257 } 1258 } 1259 1260 /// Parse a tensor type. 1261 /// 1262 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1263 /// dimension-list ::= dimension-list-ranked | `*x` 1264 /// 1265 Type Parser::parseTensorType() { 1266 consumeToken(Token::kw_tensor); 1267 1268 if (parseToken(Token::less, "expected '<' in tensor type")) 1269 return nullptr; 1270 1271 bool isUnranked; 1272 SmallVector<int64_t, 4> dimensions; 1273 1274 if (consumeIf(Token::star)) { 1275 // This is an unranked tensor type. 1276 isUnranked = true; 1277 1278 if (parseXInDimensionList()) 1279 return nullptr; 1280 1281 } else { 1282 isUnranked = false; 1283 if (parseDimensionListRanked(dimensions)) 1284 return nullptr; 1285 } 1286 1287 // Parse the element type. 1288 auto elementTypeLoc = getToken().getLoc(); 1289 auto elementType = parseType(); 1290 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1291 return nullptr; 1292 if (!TensorType::isValidElementType(elementType)) 1293 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 1294 1295 if (isUnranked) 1296 return UnrankedTensorType::get(elementType); 1297 return RankedTensorType::get(dimensions, elementType); 1298 } 1299 1300 /// Parse a tuple type. 1301 /// 1302 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1303 /// 1304 Type Parser::parseTupleType() { 1305 consumeToken(Token::kw_tuple); 1306 1307 // Parse the '<'. 1308 if (parseToken(Token::less, "expected '<' in tuple type")) 1309 return nullptr; 1310 1311 // Check for an empty tuple by directly parsing '>'. 1312 if (consumeIf(Token::greater)) 1313 return TupleType::get(getContext()); 1314 1315 // Parse the element types and the '>'. 1316 SmallVector<Type, 4> types; 1317 if (parseTypeListNoParens(types) || 1318 parseToken(Token::greater, "expected '>' in tuple type")) 1319 return nullptr; 1320 1321 return TupleType::get(types, getContext()); 1322 } 1323 1324 /// Parse a vector type. 1325 /// 1326 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1327 /// non-empty-static-dimension-list ::= decimal-literal `x` 1328 /// static-dimension-list 1329 /// static-dimension-list ::= (decimal-literal `x`)* 1330 /// 1331 VectorType Parser::parseVectorType() { 1332 consumeToken(Token::kw_vector); 1333 1334 if (parseToken(Token::less, "expected '<' in vector type")) 1335 return nullptr; 1336 1337 SmallVector<int64_t, 4> dimensions; 1338 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1339 return nullptr; 1340 if (dimensions.empty()) 1341 return (emitError("expected dimension size in vector type"), nullptr); 1342 if (any_of(dimensions, [](int64_t i) { return i <= 0; })) 1343 return emitError(getToken().getLoc(), 1344 "vector types must have positive constant sizes"), 1345 nullptr; 1346 1347 // Parse the element type. 1348 auto typeLoc = getToken().getLoc(); 1349 auto elementType = parseType(); 1350 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1351 return nullptr; 1352 if (!VectorType::isValidElementType(elementType)) 1353 return emitError(typeLoc, "vector elements must be int or float type"), 1354 nullptr; 1355 1356 return VectorType::get(dimensions, elementType); 1357 } 1358 1359 /// Parse a dimension list of a tensor or memref type. This populates the 1360 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1361 /// errors out on `?` otherwise. 1362 /// 1363 /// dimension-list-ranked ::= (dimension `x`)* 1364 /// dimension ::= `?` | decimal-literal 1365 /// 1366 /// When `allowDynamic` is not set, this is used to parse: 1367 /// 1368 /// static-dimension-list ::= (decimal-literal `x`)* 1369 ParseResult 1370 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1371 bool allowDynamic) { 1372 while (getToken().isAny(Token::integer, Token::question)) { 1373 if (consumeIf(Token::question)) { 1374 if (!allowDynamic) 1375 return emitError("expected static shape"); 1376 dimensions.push_back(-1); 1377 } else { 1378 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1379 // aggregate type declarations. Therefore, `0xf32` should be processed as 1380 // a sequence of separate elements `0`, `x`, `f32`. 1381 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1382 // We can get here only if the token is an integer literal. Hexadecimal 1383 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1384 // literal, just `1` would, at which point we don't get into this 1385 // branch). 1386 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1387 dimensions.push_back(0); 1388 state.lex.resetPointer(getTokenSpelling().data() + 1); 1389 consumeToken(); 1390 } else { 1391 // Make sure this integer value is in bound and valid. 1392 auto dimension = getToken().getUnsignedIntegerValue(); 1393 if (!dimension.hasValue()) 1394 return emitError("invalid dimension"); 1395 dimensions.push_back((int64_t)dimension.getValue()); 1396 consumeToken(Token::integer); 1397 } 1398 } 1399 1400 // Make sure we have an 'x' or something like 'xbf32'. 1401 if (parseXInDimensionList()) 1402 return failure(); 1403 } 1404 1405 return success(); 1406 } 1407 1408 /// Parse an 'x' token in a dimension list, handling the case where the x is 1409 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1410 /// token. 1411 ParseResult Parser::parseXInDimensionList() { 1412 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1413 return emitError("expected 'x' in dimension list"); 1414 1415 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1416 if (getTokenSpelling().size() != 1) 1417 state.lex.resetPointer(getTokenSpelling().data() + 1); 1418 1419 // Consume the 'x'. 1420 consumeToken(Token::bare_identifier); 1421 1422 return success(); 1423 } 1424 1425 // Parse a comma-separated list of dimensions, possibly empty: 1426 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1427 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1428 if (!consumeIf(Token::l_square)) 1429 return failure(); 1430 // Empty list early exit. 1431 if (consumeIf(Token::r_square)) 1432 return success(); 1433 while (true) { 1434 if (consumeIf(Token::question)) { 1435 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1436 } else { 1437 // This must be an integer value. 1438 int64_t val; 1439 if (getToken().getSpelling().getAsInteger(10, val)) 1440 return emitError("invalid integer value: ") << getToken().getSpelling(); 1441 // Make sure it is not the one value for `?`. 1442 if (ShapedType::isDynamic(val)) 1443 return emitError("invalid integer value: ") 1444 << getToken().getSpelling() 1445 << ", use `?` to specify a dynamic dimension"; 1446 dimensions.push_back(val); 1447 consumeToken(Token::integer); 1448 } 1449 if (!consumeIf(Token::comma)) 1450 break; 1451 } 1452 if (!consumeIf(Token::r_square)) 1453 return failure(); 1454 return success(); 1455 } 1456 1457 //===----------------------------------------------------------------------===// 1458 // Attribute parsing. 1459 //===----------------------------------------------------------------------===// 1460 1461 /// Return the symbol reference referred to by the given token, that is known to 1462 /// be an @-identifier. 1463 static std::string extractSymbolReference(Token tok) { 1464 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1465 StringRef nameStr = tok.getSpelling().drop_front(); 1466 1467 // Check to see if the reference is a string literal, or a bare identifier. 1468 if (nameStr.front() == '"') 1469 return tok.getStringValue(); 1470 return std::string(nameStr); 1471 } 1472 1473 /// Parse an arbitrary attribute. 1474 /// 1475 /// attribute-value ::= `unit` 1476 /// | bool-literal 1477 /// | integer-literal (`:` (index-type | integer-type))? 1478 /// | float-literal (`:` float-type)? 1479 /// | string-literal (`:` type)? 1480 /// | type 1481 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1482 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1483 /// | symbol-ref-id (`::` symbol-ref-id)* 1484 /// | `dense` `<` attribute-value `>` `:` 1485 /// (tensor-type | vector-type) 1486 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1487 /// `:` (tensor-type | vector-type) 1488 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1489 /// `>` `:` (tensor-type | vector-type) 1490 /// | extended-attribute 1491 /// 1492 Attribute Parser::parseAttribute(Type type) { 1493 switch (getToken().getKind()) { 1494 // Parse an AffineMap or IntegerSet attribute. 1495 case Token::kw_affine_map: { 1496 consumeToken(Token::kw_affine_map); 1497 1498 AffineMap map; 1499 if (parseToken(Token::less, "expected '<' in affine map") || 1500 parseAffineMapReference(map) || 1501 parseToken(Token::greater, "expected '>' in affine map")) 1502 return Attribute(); 1503 return AffineMapAttr::get(map); 1504 } 1505 case Token::kw_affine_set: { 1506 consumeToken(Token::kw_affine_set); 1507 1508 IntegerSet set; 1509 if (parseToken(Token::less, "expected '<' in integer set") || 1510 parseIntegerSetReference(set) || 1511 parseToken(Token::greater, "expected '>' in integer set")) 1512 return Attribute(); 1513 return IntegerSetAttr::get(set); 1514 } 1515 1516 // Parse an array attribute. 1517 case Token::l_square: { 1518 consumeToken(Token::l_square); 1519 1520 SmallVector<Attribute, 4> elements; 1521 auto parseElt = [&]() -> ParseResult { 1522 elements.push_back(parseAttribute()); 1523 return elements.back() ? success() : failure(); 1524 }; 1525 1526 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1527 return nullptr; 1528 return builder.getArrayAttr(elements); 1529 } 1530 1531 // Parse a boolean attribute. 1532 case Token::kw_false: 1533 consumeToken(Token::kw_false); 1534 return builder.getBoolAttr(false); 1535 case Token::kw_true: 1536 consumeToken(Token::kw_true); 1537 return builder.getBoolAttr(true); 1538 1539 // Parse a dense elements attribute. 1540 case Token::kw_dense: 1541 return parseDenseElementsAttr(type); 1542 1543 // Parse a dictionary attribute. 1544 case Token::l_brace: { 1545 SmallVector<NamedAttribute, 4> elements; 1546 if (parseAttributeDict(elements)) 1547 return nullptr; 1548 return builder.getDictionaryAttr(elements); 1549 } 1550 1551 // Parse an extended attribute, i.e. alias or dialect attribute. 1552 case Token::hash_identifier: 1553 return parseExtendedAttr(type); 1554 1555 // Parse floating point and integer attributes. 1556 case Token::floatliteral: 1557 return parseFloatAttr(type, /*isNegative=*/false); 1558 case Token::integer: 1559 return parseDecOrHexAttr(type, /*isNegative=*/false); 1560 case Token::minus: { 1561 consumeToken(Token::minus); 1562 if (getToken().is(Token::integer)) 1563 return parseDecOrHexAttr(type, /*isNegative=*/true); 1564 if (getToken().is(Token::floatliteral)) 1565 return parseFloatAttr(type, /*isNegative=*/true); 1566 1567 return (emitError("expected constant integer or floating point value"), 1568 nullptr); 1569 } 1570 1571 // Parse a location attribute. 1572 case Token::kw_loc: { 1573 LocationAttr attr; 1574 return failed(parseLocation(attr)) ? Attribute() : attr; 1575 } 1576 1577 // Parse an opaque elements attribute. 1578 case Token::kw_opaque: 1579 return parseOpaqueElementsAttr(type); 1580 1581 // Parse a sparse elements attribute. 1582 case Token::kw_sparse: 1583 return parseSparseElementsAttr(type); 1584 1585 // Parse a string attribute. 1586 case Token::string: { 1587 auto val = getToken().getStringValue(); 1588 consumeToken(Token::string); 1589 // Parse the optional trailing colon type if one wasn't explicitly provided. 1590 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1591 return Attribute(); 1592 1593 return type ? StringAttr::get(val, type) 1594 : StringAttr::get(val, getContext()); 1595 } 1596 1597 // Parse a symbol reference attribute. 1598 case Token::at_identifier: { 1599 std::string nameStr = extractSymbolReference(getToken()); 1600 consumeToken(Token::at_identifier); 1601 1602 // Parse any nested references. 1603 std::vector<FlatSymbolRefAttr> nestedRefs; 1604 while (getToken().is(Token::colon)) { 1605 // Check for the '::' prefix. 1606 const char *curPointer = getToken().getLoc().getPointer(); 1607 consumeToken(Token::colon); 1608 if (!consumeIf(Token::colon)) { 1609 state.lex.resetPointer(curPointer); 1610 consumeToken(); 1611 break; 1612 } 1613 // Parse the reference itself. 1614 auto curLoc = getToken().getLoc(); 1615 if (getToken().isNot(Token::at_identifier)) { 1616 emitError(curLoc, "expected nested symbol reference identifier"); 1617 return Attribute(); 1618 } 1619 1620 std::string nameStr = extractSymbolReference(getToken()); 1621 consumeToken(Token::at_identifier); 1622 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1623 } 1624 1625 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1626 } 1627 1628 // Parse a 'unit' attribute. 1629 case Token::kw_unit: 1630 consumeToken(Token::kw_unit); 1631 return builder.getUnitAttr(); 1632 1633 default: 1634 // Parse a type attribute. 1635 if (Type type = parseType()) 1636 return TypeAttr::get(type); 1637 return nullptr; 1638 } 1639 } 1640 1641 /// Attribute dictionary. 1642 /// 1643 /// attribute-dict ::= `{` `}` 1644 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1645 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 1646 /// 1647 ParseResult 1648 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1649 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1650 return failure(); 1651 1652 auto parseElt = [&]() -> ParseResult { 1653 // The name of an attribute can either be a bare identifier, or a string. 1654 Optional<Identifier> nameId; 1655 if (getToken().is(Token::string)) 1656 nameId = builder.getIdentifier(getToken().getStringValue()); 1657 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 1658 getToken().isKeyword()) 1659 nameId = builder.getIdentifier(getTokenSpelling()); 1660 else 1661 return emitError("expected attribute name"); 1662 consumeToken(); 1663 1664 // Try to parse the '=' for the attribute value. 1665 if (!consumeIf(Token::equal)) { 1666 // If there is no '=', we treat this as a unit attribute. 1667 attributes.push_back({*nameId, builder.getUnitAttr()}); 1668 return success(); 1669 } 1670 1671 auto attr = parseAttribute(); 1672 if (!attr) 1673 return failure(); 1674 1675 attributes.push_back({*nameId, attr}); 1676 return success(); 1677 }; 1678 1679 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1680 return failure(); 1681 1682 return success(); 1683 } 1684 1685 /// Parse an extended attribute. 1686 /// 1687 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1688 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1689 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1690 /// attribute-alias ::= `#` alias-name 1691 /// 1692 Attribute Parser::parseExtendedAttr(Type type) { 1693 Attribute attr = parseExtendedSymbol<Attribute>( 1694 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1695 [&](StringRef dialectName, StringRef symbolData, 1696 llvm::SMLoc loc) -> Attribute { 1697 // Parse an optional trailing colon type. 1698 Type attrType = type; 1699 if (consumeIf(Token::colon) && !(attrType = parseType())) 1700 return Attribute(); 1701 1702 // If we found a registered dialect, then ask it to parse the attribute. 1703 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1704 return parseSymbol<Attribute>( 1705 symbolData, state.context, state.symbols, [&](Parser &parser) { 1706 CustomDialectAsmParser customParser(symbolData, parser); 1707 return dialect->parseAttribute(customParser, attrType); 1708 }); 1709 } 1710 1711 // Otherwise, form a new opaque attribute. 1712 return OpaqueAttr::getChecked( 1713 Identifier::get(dialectName, state.context), symbolData, 1714 attrType ? attrType : NoneType::get(state.context), 1715 getEncodedSourceLocation(loc)); 1716 }); 1717 1718 // Ensure that the attribute has the same type as requested. 1719 if (attr && type && attr.getType() != type) { 1720 emitError("attribute type different than expected: expected ") 1721 << type << ", but got " << attr.getType(); 1722 return nullptr; 1723 } 1724 return attr; 1725 } 1726 1727 /// Parse a float attribute. 1728 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1729 auto val = getToken().getFloatingPointValue(); 1730 if (!val.hasValue()) 1731 return (emitError("floating point value too large for attribute"), nullptr); 1732 consumeToken(Token::floatliteral); 1733 if (!type) { 1734 // Default to F64 when no type is specified. 1735 if (!consumeIf(Token::colon)) 1736 type = builder.getF64Type(); 1737 else if (!(type = parseType())) 1738 return nullptr; 1739 } 1740 if (!type.isa<FloatType>()) 1741 return (emitError("floating point value not valid for specified type"), 1742 nullptr); 1743 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1744 } 1745 1746 /// Construct a float attribute bitwise equivalent to the integer literal. 1747 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1748 uint64_t value) { 1749 // FIXME: bfloat is currently stored as a double internally because it doesn't 1750 // have valid APFloat semantics. 1751 if (type.isF64() || type.isBF16()) 1752 return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1753 1754 APInt apInt(type.getWidth(), value); 1755 if (apInt != value) { 1756 p->emitError("hexadecimal float constant out of range for type"); 1757 return llvm::None; 1758 } 1759 return APFloat(type.getFloatSemantics(), apInt); 1760 } 1761 1762 /// Construct an APint from a parsed value, a known attribute type and 1763 /// sign. 1764 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 1765 uint64_t value) { 1766 // We have the integer literal as an uint64_t in val, now convert it into an 1767 // APInt and check that we don't overflow. 1768 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1769 APInt apInt(width, value, isNegative); 1770 if (apInt != value) 1771 return llvm::None; 1772 1773 if (isNegative) { 1774 // The value is negative, we have an overflow if the sign bit is not set 1775 // in the negated apInt. 1776 apInt.negate(); 1777 if (!apInt.isSignBitSet()) 1778 return llvm::None; 1779 } else if ((type.isSignedInteger() || type.isIndex()) && 1780 apInt.isSignBitSet()) { 1781 // The value is a positive signed integer or index, 1782 // we have an overflow if the sign bit is set. 1783 return llvm::None; 1784 } 1785 1786 return apInt; 1787 } 1788 1789 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1790 /// or a float attribute. 1791 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1792 auto val = getToken().getUInt64IntegerValue(); 1793 if (!val.hasValue()) 1794 return (emitError("integer constant out of range for attribute"), nullptr); 1795 1796 // Remember if the literal is hexadecimal. 1797 StringRef spelling = getToken().getSpelling(); 1798 auto loc = state.curToken.getLoc(); 1799 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1800 1801 consumeToken(Token::integer); 1802 if (!type) { 1803 // Default to i64 if not type is specified. 1804 if (!consumeIf(Token::colon)) 1805 type = builder.getIntegerType(64); 1806 else if (!(type = parseType())) 1807 return nullptr; 1808 } 1809 1810 if (auto floatType = type.dyn_cast<FloatType>()) { 1811 if (isNegative) 1812 return emitError( 1813 loc, 1814 "hexadecimal float literal should not have a leading minus"), 1815 nullptr; 1816 if (!isHex) { 1817 emitError(loc, "unexpected decimal integer literal for a float attribute") 1818 .attachNote() 1819 << "add a trailing dot to make the literal a float"; 1820 return nullptr; 1821 } 1822 1823 // Construct a float attribute bitwise equivalent to the integer literal. 1824 Optional<APFloat> apVal = 1825 buildHexadecimalFloatLiteral(this, floatType, *val); 1826 return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); 1827 } 1828 1829 if (!type.isa<IntegerType>() && !type.isa<IndexType>()) 1830 return emitError(loc, "integer literal not valid for specified type"), 1831 nullptr; 1832 1833 if (isNegative && type.isUnsignedInteger()) { 1834 emitError(loc, 1835 "negative integer literal not valid for unsigned integer type"); 1836 return nullptr; 1837 } 1838 1839 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, *val); 1840 1841 if (!apInt) 1842 return emitError(loc, "integer constant out of range for attribute"), 1843 nullptr; 1844 return builder.getIntegerAttr(type, *apInt); 1845 } 1846 1847 /// Parse elements values stored within a hex etring. On success, the values are 1848 /// stored into 'result'. 1849 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 1850 std::string &result) { 1851 std::string val = tok.getStringValue(); 1852 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1853 return parser.emitError(tok.getLoc(), 1854 "elements hex string should start with '0x'"); 1855 1856 StringRef hexValues = StringRef(val).drop_front(2); 1857 if (!llvm::all_of(hexValues, llvm::isHexDigit)) 1858 return parser.emitError(tok.getLoc(), 1859 "elements hex string only contains hex digits"); 1860 1861 result = llvm::fromHex(hexValues); 1862 return success(); 1863 } 1864 1865 /// Parse an opaque elements attribute. 1866 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1867 consumeToken(Token::kw_opaque); 1868 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1869 return nullptr; 1870 1871 if (getToken().isNot(Token::string)) 1872 return (emitError("expected dialect namespace"), nullptr); 1873 1874 auto name = getToken().getStringValue(); 1875 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1876 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1877 // attribute. Otherwise, it can't be roundtripped without having the dialect 1878 // registered. 1879 if (!dialect) 1880 return (emitError("no registered dialect with namespace '" + name + "'"), 1881 nullptr); 1882 consumeToken(Token::string); 1883 1884 if (parseToken(Token::comma, "expected ','")) 1885 return nullptr; 1886 1887 Token hexTok = getToken(); 1888 if (parseToken(Token::string, "elements hex string should start with '0x'") || 1889 parseToken(Token::greater, "expected '>'")) 1890 return nullptr; 1891 auto type = parseElementsLiteralType(attrType); 1892 if (!type) 1893 return nullptr; 1894 1895 std::string data; 1896 if (parseElementAttrHexValues(*this, hexTok, data)) 1897 return nullptr; 1898 return OpaqueElementsAttr::get(dialect, type, data); 1899 } 1900 1901 namespace { 1902 class TensorLiteralParser { 1903 public: 1904 TensorLiteralParser(Parser &p) : p(p) {} 1905 1906 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1907 /// may also parse a tensor literal that is store as a hex string. 1908 ParseResult parse(bool allowHex); 1909 1910 /// Build a dense attribute instance with the parsed elements and the given 1911 /// shaped type. 1912 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1913 1914 ArrayRef<int64_t> getShape() const { return shape; } 1915 1916 private: 1917 enum class ElementKind { Boolean, Integer, Float }; 1918 1919 /// Return a string to represent the given element kind. 1920 const char *getElementKindStr(ElementKind kind) { 1921 switch (kind) { 1922 case ElementKind::Boolean: 1923 return "'boolean'"; 1924 case ElementKind::Integer: 1925 return "'integer'"; 1926 case ElementKind::Float: 1927 return "'float'"; 1928 } 1929 llvm_unreachable("unknown element kind"); 1930 } 1931 1932 /// Build a Dense Integer attribute for the given type. 1933 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1934 IntegerType eltTy); 1935 1936 /// Build a Dense Float attribute for the given type. 1937 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1938 FloatType eltTy); 1939 1940 /// Build a Dense attribute with hex data for the given type. 1941 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); 1942 1943 /// Parse a single element, returning failure if it isn't a valid element 1944 /// literal. For example: 1945 /// parseElement(1) -> Success, 1 1946 /// parseElement([1]) -> Failure 1947 ParseResult parseElement(); 1948 1949 /// Parse a list of either lists or elements, returning the dimensions of the 1950 /// parsed sub-tensors in dims. For example: 1951 /// parseList([1, 2, 3]) -> Success, [3] 1952 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1953 /// parseList([[1, 2], 3]) -> Failure 1954 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1955 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1956 1957 /// Parse a literal that was printed as a hex string. 1958 ParseResult parseHexElements(); 1959 1960 Parser &p; 1961 1962 /// The shape inferred from the parsed elements. 1963 SmallVector<int64_t, 4> shape; 1964 1965 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1966 std::vector<std::pair<bool, Token>> storage; 1967 1968 /// A flag that indicates the type of elements that have been parsed. 1969 Optional<ElementKind> knownEltKind; 1970 1971 /// Storage used when parsing elements that were stored as hex values. 1972 Optional<Token> hexStorage; 1973 }; 1974 } // namespace 1975 1976 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1977 /// may also parse a tensor literal that is store as a hex string. 1978 ParseResult TensorLiteralParser::parse(bool allowHex) { 1979 // If hex is allowed, check for a string literal. 1980 if (allowHex && p.getToken().is(Token::string)) { 1981 hexStorage = p.getToken(); 1982 p.consumeToken(Token::string); 1983 return success(); 1984 } 1985 // Otherwise, parse a list or an individual element. 1986 if (p.getToken().is(Token::l_square)) 1987 return parseList(shape); 1988 return parseElement(); 1989 } 1990 1991 /// Build a dense attribute instance with the parsed elements and the given 1992 /// shaped type. 1993 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1994 ShapedType type) { 1995 // Check to see if we parsed the literal from a hex string. 1996 if (hexStorage.hasValue()) 1997 return getHexAttr(loc, type); 1998 1999 // Check that the parsed storage size has the same number of elements to the 2000 // type, or is a known splat. 2001 if (!shape.empty() && getShape() != type.getShape()) { 2002 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 2003 << "]) does not match type ([" << type.getShape() << "])"; 2004 return nullptr; 2005 } 2006 2007 // If the type is an integer, build a set of APInt values from the storage 2008 // with the correct bitwidth. 2009 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 2010 return getIntAttr(loc, type, intTy); 2011 2012 // Otherwise, this must be a floating point type. 2013 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 2014 if (!floatTy) { 2015 p.emitError(loc) << "expected floating-point or integer element type, got " 2016 << type.getElementType(); 2017 return nullptr; 2018 } 2019 return getFloatAttr(loc, type, floatTy); 2020 } 2021 2022 /// Build a Dense Integer attribute for the given type. 2023 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 2024 ShapedType type, 2025 IntegerType eltTy) { 2026 std::vector<APInt> intElements; 2027 intElements.reserve(storage.size()); 2028 auto isUintType = type.getElementType().isUnsignedInteger(); 2029 for (const auto &signAndToken : storage) { 2030 bool isNegative = signAndToken.first; 2031 const Token &token = signAndToken.second; 2032 auto tokenLoc = token.getLoc(); 2033 2034 if (isNegative && isUintType) { 2035 p.emitError(tokenLoc) 2036 << "expected unsigned integer elements, but parsed negative value"; 2037 return nullptr; 2038 } 2039 2040 // Check to see if floating point values were parsed. 2041 if (token.is(Token::floatliteral)) { 2042 p.emitError(tokenLoc) 2043 << "expected integer elements, but parsed floating-point"; 2044 return nullptr; 2045 } 2046 2047 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 2048 "unexpected token type"); 2049 if (token.isAny(Token::kw_true, Token::kw_false)) { 2050 if (!eltTy.isInteger(1)) 2051 p.emitError(tokenLoc) 2052 << "expected i1 type for 'true' or 'false' values"; 2053 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 2054 /*isSigned=*/false); 2055 intElements.push_back(apInt); 2056 continue; 2057 } 2058 2059 // Create APInt values for each element with the correct bitwidth. 2060 auto val = token.getUInt64IntegerValue(); 2061 if (!val.hasValue()) { 2062 p.emitError(tokenLoc, "integer constant out of range for attribute"); 2063 return nullptr; 2064 } 2065 Optional<APInt> apInt = buildAttributeAPInt(eltTy, isNegative, *val); 2066 if (!apInt) 2067 return (p.emitError(tokenLoc, "integer constant out of range for type"), 2068 nullptr); 2069 intElements.push_back(*apInt); 2070 } 2071 2072 return DenseElementsAttr::get(type, intElements); 2073 } 2074 2075 /// Build a Dense Float attribute for the given type. 2076 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 2077 ShapedType type, 2078 FloatType eltTy) { 2079 std::vector<APFloat> floatValues; 2080 floatValues.reserve(storage.size()); 2081 for (const auto &signAndToken : storage) { 2082 bool isNegative = signAndToken.first; 2083 const Token &token = signAndToken.second; 2084 2085 // Handle hexadecimal float literals. 2086 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 2087 if (isNegative) { 2088 p.emitError(token.getLoc()) 2089 << "hexadecimal float literal should not have a leading minus"; 2090 return nullptr; 2091 } 2092 auto val = token.getUInt64IntegerValue(); 2093 if (!val.hasValue()) { 2094 p.emitError("hexadecimal float constant out of range for attribute"); 2095 return nullptr; 2096 } 2097 Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); 2098 if (!apVal) 2099 return nullptr; 2100 floatValues.push_back(*apVal); 2101 continue; 2102 } 2103 2104 // Check to see if any decimal integers or booleans were parsed. 2105 if (!token.is(Token::floatliteral)) { 2106 p.emitError() << "expected floating-point elements, but parsed integer"; 2107 return nullptr; 2108 } 2109 2110 // Build the float values from tokens. 2111 auto val = token.getFloatingPointValue(); 2112 if (!val.hasValue()) { 2113 p.emitError("floating point value too large for attribute"); 2114 return nullptr; 2115 } 2116 // Treat BF16 as double because it is not supported in LLVM's APFloat. 2117 APFloat apVal(isNegative ? -*val : *val); 2118 if (!eltTy.isBF16() && !eltTy.isF64()) { 2119 bool unused; 2120 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 2121 &unused); 2122 } 2123 floatValues.push_back(apVal); 2124 } 2125 2126 return DenseElementsAttr::get(type, floatValues); 2127 } 2128 2129 /// Build a Dense attribute with hex data for the given type. 2130 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, 2131 ShapedType type) { 2132 Type elementType = type.getElementType(); 2133 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { 2134 p.emitError(loc) << "expected floating-point or integer element type, got " 2135 << elementType; 2136 return nullptr; 2137 } 2138 2139 std::string data; 2140 if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) 2141 return nullptr; 2142 2143 // Check that the size of the hex data corresponds to the size of the type, or 2144 // a splat of the type. 2145 // TODO: bf16 is currently stored as a double, this should be removed when 2146 // APFloat properly supports it. 2147 int64_t elementWidth = 2148 elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth(); 2149 if (static_cast<int64_t>(data.size() * CHAR_BIT) != 2150 (type.getNumElements() * elementWidth)) { 2151 p.emitError(loc) << "elements hex data size is invalid for provided type: " 2152 << type; 2153 return nullptr; 2154 } 2155 2156 return DenseElementsAttr::getFromRawBuffer( 2157 type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false); 2158 } 2159 2160 ParseResult TensorLiteralParser::parseElement() { 2161 switch (p.getToken().getKind()) { 2162 // Parse a boolean element. 2163 case Token::kw_true: 2164 case Token::kw_false: 2165 case Token::floatliteral: 2166 case Token::integer: 2167 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2168 p.consumeToken(); 2169 break; 2170 2171 // Parse a signed integer or a negative floating-point element. 2172 case Token::minus: 2173 p.consumeToken(Token::minus); 2174 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2175 return p.emitError("expected integer or floating point literal"); 2176 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2177 p.consumeToken(); 2178 break; 2179 2180 default: 2181 return p.emitError("expected element literal of primitive type"); 2182 } 2183 2184 return success(); 2185 } 2186 2187 /// Parse a list of either lists or elements, returning the dimensions of the 2188 /// parsed sub-tensors in dims. For example: 2189 /// parseList([1, 2, 3]) -> Success, [3] 2190 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2191 /// parseList([[1, 2], 3]) -> Failure 2192 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2193 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2194 p.consumeToken(Token::l_square); 2195 2196 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2197 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2198 if (prevDims == newDims) 2199 return success(); 2200 return p.emitError("tensor literal is invalid; ranks are not consistent " 2201 "between elements"); 2202 }; 2203 2204 bool first = true; 2205 SmallVector<int64_t, 4> newDims; 2206 unsigned size = 0; 2207 auto parseCommaSeparatedList = [&]() -> ParseResult { 2208 SmallVector<int64_t, 4> thisDims; 2209 if (p.getToken().getKind() == Token::l_square) { 2210 if (parseList(thisDims)) 2211 return failure(); 2212 } else if (parseElement()) { 2213 return failure(); 2214 } 2215 ++size; 2216 if (!first) 2217 return checkDims(newDims, thisDims); 2218 newDims = thisDims; 2219 first = false; 2220 return success(); 2221 }; 2222 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2223 return failure(); 2224 2225 // Return the sublists' dimensions with 'size' prepended. 2226 dims.clear(); 2227 dims.push_back(size); 2228 dims.append(newDims.begin(), newDims.end()); 2229 return success(); 2230 } 2231 2232 /// Parse a dense elements attribute. 2233 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2234 consumeToken(Token::kw_dense); 2235 if (parseToken(Token::less, "expected '<' after 'dense'")) 2236 return nullptr; 2237 2238 // Parse the literal data. 2239 TensorLiteralParser literalParser(*this); 2240 if (literalParser.parse(/*allowHex=*/true)) 2241 return nullptr; 2242 2243 if (parseToken(Token::greater, "expected '>'")) 2244 return nullptr; 2245 2246 auto typeLoc = getToken().getLoc(); 2247 auto type = parseElementsLiteralType(attrType); 2248 if (!type) 2249 return nullptr; 2250 return literalParser.getAttr(typeLoc, type); 2251 } 2252 2253 /// Shaped type for elements attribute. 2254 /// 2255 /// elements-literal-type ::= vector-type | ranked-tensor-type 2256 /// 2257 /// This method also checks the type has static shape. 2258 ShapedType Parser::parseElementsLiteralType(Type type) { 2259 // If the user didn't provide a type, parse the colon type for the literal. 2260 if (!type) { 2261 if (parseToken(Token::colon, "expected ':'")) 2262 return nullptr; 2263 if (!(type = parseType())) 2264 return nullptr; 2265 } 2266 2267 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2268 emitError("elements literal must be a ranked tensor or vector type"); 2269 return nullptr; 2270 } 2271 2272 auto sType = type.cast<ShapedType>(); 2273 if (!sType.hasStaticShape()) 2274 return (emitError("elements literal type must have static shape"), nullptr); 2275 2276 return sType; 2277 } 2278 2279 /// Parse a sparse elements attribute. 2280 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2281 consumeToken(Token::kw_sparse); 2282 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2283 return nullptr; 2284 2285 /// Parse the indices. We don't allow hex values here as we may need to use 2286 /// the inferred shape. 2287 auto indicesLoc = getToken().getLoc(); 2288 TensorLiteralParser indiceParser(*this); 2289 if (indiceParser.parse(/*allowHex=*/false)) 2290 return nullptr; 2291 2292 if (parseToken(Token::comma, "expected ','")) 2293 return nullptr; 2294 2295 /// Parse the values. 2296 auto valuesLoc = getToken().getLoc(); 2297 TensorLiteralParser valuesParser(*this); 2298 if (valuesParser.parse(/*allowHex=*/true)) 2299 return nullptr; 2300 2301 if (parseToken(Token::greater, "expected '>'")) 2302 return nullptr; 2303 2304 auto type = parseElementsLiteralType(attrType); 2305 if (!type) 2306 return nullptr; 2307 2308 // If the indices are a splat, i.e. the literal parser parsed an element and 2309 // not a list, we set the shape explicitly. The indices are represented by a 2310 // 2-dimensional shape where the second dimension is the rank of the type. 2311 // Given that the parsed indices is a splat, we know that we only have one 2312 // indice and thus one for the first dimension. 2313 auto indiceEltType = builder.getIntegerType(64); 2314 ShapedType indicesType; 2315 if (indiceParser.getShape().empty()) { 2316 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2317 } else { 2318 // Otherwise, set the shape to the one parsed by the literal parser. 2319 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2320 } 2321 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2322 2323 // If the values are a splat, set the shape explicitly based on the number of 2324 // indices. The number of indices is encoded in the first dimension of the 2325 // indice shape type. 2326 auto valuesEltType = type.getElementType(); 2327 ShapedType valuesType = 2328 valuesParser.getShape().empty() 2329 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2330 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2331 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2332 2333 /// Sanity check. 2334 if (valuesType.getRank() != 1) 2335 return (emitError("expected 1-d tensor for values"), nullptr); 2336 2337 auto sameShape = (indicesType.getRank() == 1) || 2338 (type.getRank() == indicesType.getDimSize(1)); 2339 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2340 if (!sameShape || !sameElementNum) { 2341 emitError() << "expected shape ([" << type.getShape() 2342 << "]); inferred shape of indices literal ([" 2343 << indicesType.getShape() 2344 << "]); inferred shape of values literal ([" 2345 << valuesType.getShape() << "])"; 2346 return nullptr; 2347 } 2348 2349 // Build the sparse elements attribute by the indices and values. 2350 return SparseElementsAttr::get(type, indices, values); 2351 } 2352 2353 //===----------------------------------------------------------------------===// 2354 // Location parsing. 2355 //===----------------------------------------------------------------------===// 2356 2357 /// Parse a location. 2358 /// 2359 /// location ::= `loc` inline-location 2360 /// inline-location ::= '(' location-inst ')' 2361 /// 2362 ParseResult Parser::parseLocation(LocationAttr &loc) { 2363 // Check for 'loc' identifier. 2364 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2365 return emitError(); 2366 2367 // Parse the inline-location. 2368 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2369 parseLocationInstance(loc) || 2370 parseToken(Token::r_paren, "expected ')' in inline location")) 2371 return failure(); 2372 return success(); 2373 } 2374 2375 /// Specific location instances. 2376 /// 2377 /// location-inst ::= filelinecol-location | 2378 /// name-location | 2379 /// callsite-location | 2380 /// fused-location | 2381 /// unknown-location 2382 /// filelinecol-location ::= string-literal ':' integer-literal 2383 /// ':' integer-literal 2384 /// name-location ::= string-literal 2385 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2386 /// fused-location ::= fused ('<' attribute-value '>')? 2387 /// '[' location-inst (location-inst ',')* ']' 2388 /// unknown-location ::= 'unknown' 2389 /// 2390 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2391 consumeToken(Token::bare_identifier); 2392 2393 // Parse the '('. 2394 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2395 return failure(); 2396 2397 // Parse the callee location. 2398 LocationAttr calleeLoc; 2399 if (parseLocationInstance(calleeLoc)) 2400 return failure(); 2401 2402 // Parse the 'at'. 2403 if (getToken().isNot(Token::bare_identifier) || 2404 getToken().getSpelling() != "at") 2405 return emitError("expected 'at' in callsite location"); 2406 consumeToken(Token::bare_identifier); 2407 2408 // Parse the caller location. 2409 LocationAttr callerLoc; 2410 if (parseLocationInstance(callerLoc)) 2411 return failure(); 2412 2413 // Parse the ')'. 2414 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2415 return failure(); 2416 2417 // Return the callsite location. 2418 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2419 return success(); 2420 } 2421 2422 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2423 consumeToken(Token::bare_identifier); 2424 2425 // Try to parse the optional metadata. 2426 Attribute metadata; 2427 if (consumeIf(Token::less)) { 2428 metadata = parseAttribute(); 2429 if (!metadata) 2430 return emitError("expected valid attribute metadata"); 2431 // Parse the '>' token. 2432 if (parseToken(Token::greater, 2433 "expected '>' after fused location metadata")) 2434 return failure(); 2435 } 2436 2437 SmallVector<Location, 4> locations; 2438 auto parseElt = [&] { 2439 LocationAttr newLoc; 2440 if (parseLocationInstance(newLoc)) 2441 return failure(); 2442 locations.push_back(newLoc); 2443 return success(); 2444 }; 2445 2446 if (parseToken(Token::l_square, "expected '[' in fused location") || 2447 parseCommaSeparatedList(parseElt) || 2448 parseToken(Token::r_square, "expected ']' in fused location")) 2449 return failure(); 2450 2451 // Return the fused location. 2452 loc = FusedLoc::get(locations, metadata, getContext()); 2453 return success(); 2454 } 2455 2456 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2457 auto *ctx = getContext(); 2458 auto str = getToken().getStringValue(); 2459 consumeToken(Token::string); 2460 2461 // If the next token is ':' this is a filelinecol location. 2462 if (consumeIf(Token::colon)) { 2463 // Parse the line number. 2464 if (getToken().isNot(Token::integer)) 2465 return emitError("expected integer line number in FileLineColLoc"); 2466 auto line = getToken().getUnsignedIntegerValue(); 2467 if (!line.hasValue()) 2468 return emitError("expected integer line number in FileLineColLoc"); 2469 consumeToken(Token::integer); 2470 2471 // Parse the ':'. 2472 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2473 return failure(); 2474 2475 // Parse the column number. 2476 if (getToken().isNot(Token::integer)) 2477 return emitError("expected integer column number in FileLineColLoc"); 2478 auto column = getToken().getUnsignedIntegerValue(); 2479 if (!column.hasValue()) 2480 return emitError("expected integer column number in FileLineColLoc"); 2481 consumeToken(Token::integer); 2482 2483 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2484 return success(); 2485 } 2486 2487 // Otherwise, this is a NameLoc. 2488 2489 // Check for a child location. 2490 if (consumeIf(Token::l_paren)) { 2491 auto childSourceLoc = getToken().getLoc(); 2492 2493 // Parse the child location. 2494 LocationAttr childLoc; 2495 if (parseLocationInstance(childLoc)) 2496 return failure(); 2497 2498 // The child must not be another NameLoc. 2499 if (childLoc.isa<NameLoc>()) 2500 return emitError(childSourceLoc, 2501 "child of NameLoc cannot be another NameLoc"); 2502 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2503 2504 // Parse the closing ')'. 2505 if (parseToken(Token::r_paren, 2506 "expected ')' after child location of NameLoc")) 2507 return failure(); 2508 } else { 2509 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2510 } 2511 2512 return success(); 2513 } 2514 2515 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2516 // Handle either name or filelinecol locations. 2517 if (getToken().is(Token::string)) 2518 return parseNameOrFileLineColLocation(loc); 2519 2520 // Bare tokens required for other cases. 2521 if (!getToken().is(Token::bare_identifier)) 2522 return emitError("expected location instance"); 2523 2524 // Check for the 'callsite' signifying a callsite location. 2525 if (getToken().getSpelling() == "callsite") 2526 return parseCallSiteLocation(loc); 2527 2528 // If the token is 'fused', then this is a fused location. 2529 if (getToken().getSpelling() == "fused") 2530 return parseFusedLocation(loc); 2531 2532 // Check for a 'unknown' for an unknown location. 2533 if (getToken().getSpelling() == "unknown") { 2534 consumeToken(Token::bare_identifier); 2535 loc = UnknownLoc::get(getContext()); 2536 return success(); 2537 } 2538 2539 return emitError("expected location instance"); 2540 } 2541 2542 //===----------------------------------------------------------------------===// 2543 // Affine parsing. 2544 //===----------------------------------------------------------------------===// 2545 2546 /// Lower precedence ops (all at the same precedence level). LNoOp is false in 2547 /// the boolean sense. 2548 enum AffineLowPrecOp { 2549 /// Null value. 2550 LNoOp, 2551 Add, 2552 Sub 2553 }; 2554 2555 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2556 /// in the boolean sense. 2557 enum AffineHighPrecOp { 2558 /// Null value. 2559 HNoOp, 2560 Mul, 2561 FloorDiv, 2562 CeilDiv, 2563 Mod 2564 }; 2565 2566 namespace { 2567 /// This is a specialized parser for affine structures (affine maps, affine 2568 /// expressions, and integer sets), maintaining the state transient to their 2569 /// bodies. 2570 class AffineParser : public Parser { 2571 public: 2572 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2573 function_ref<ParseResult(bool)> parseElement = nullptr) 2574 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2575 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2576 2577 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2578 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2579 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2580 ParseResult parseAffineMapOfSSAIds(AffineMap &map, 2581 OpAsmParser::Delimiter delimiter); 2582 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2583 unsigned &numDims); 2584 2585 private: 2586 // Binary affine op parsing. 2587 AffineLowPrecOp consumeIfLowPrecOp(); 2588 AffineHighPrecOp consumeIfHighPrecOp(); 2589 2590 // Identifier lists for polyhedral structures. 2591 ParseResult parseDimIdList(unsigned &numDims); 2592 ParseResult parseSymbolIdList(unsigned &numSymbols); 2593 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2594 unsigned &numSymbols); 2595 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2596 2597 AffineExpr parseAffineExpr(); 2598 AffineExpr parseParentheticalExpr(); 2599 AffineExpr parseNegateExpression(AffineExpr lhs); 2600 AffineExpr parseIntegerExpr(); 2601 AffineExpr parseBareIdExpr(); 2602 AffineExpr parseSSAIdExpr(bool isSymbol); 2603 AffineExpr parseSymbolSSAIdExpr(); 2604 2605 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2606 AffineExpr rhs, SMLoc opLoc); 2607 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2608 AffineExpr rhs); 2609 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2610 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2611 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2612 SMLoc llhsOpLoc); 2613 AffineExpr parseAffineConstraint(bool *isEq); 2614 2615 private: 2616 bool allowParsingSSAIds; 2617 function_ref<ParseResult(bool)> parseElement; 2618 unsigned numDimOperands; 2619 unsigned numSymbolOperands; 2620 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2621 }; 2622 } // end anonymous namespace 2623 2624 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2625 /// opLoc is the location of the op token to be used to report errors 2626 /// for non-conforming expressions. 2627 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2628 AffineExpr lhs, AffineExpr rhs, 2629 SMLoc opLoc) { 2630 // TODO: make the error location info accurate. 2631 switch (op) { 2632 case Mul: 2633 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2634 emitError(opLoc, "non-affine expression: at least one of the multiply " 2635 "operands has to be either a constant or symbolic"); 2636 return nullptr; 2637 } 2638 return lhs * rhs; 2639 case FloorDiv: 2640 if (!rhs.isSymbolicOrConstant()) { 2641 emitError(opLoc, "non-affine expression: right operand of floordiv " 2642 "has to be either a constant or symbolic"); 2643 return nullptr; 2644 } 2645 return lhs.floorDiv(rhs); 2646 case CeilDiv: 2647 if (!rhs.isSymbolicOrConstant()) { 2648 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2649 "has to be either a constant or symbolic"); 2650 return nullptr; 2651 } 2652 return lhs.ceilDiv(rhs); 2653 case Mod: 2654 if (!rhs.isSymbolicOrConstant()) { 2655 emitError(opLoc, "non-affine expression: right operand of mod " 2656 "has to be either a constant or symbolic"); 2657 return nullptr; 2658 } 2659 return lhs % rhs; 2660 case HNoOp: 2661 llvm_unreachable("can't create affine expression for null high prec op"); 2662 return nullptr; 2663 } 2664 llvm_unreachable("Unknown AffineHighPrecOp"); 2665 } 2666 2667 /// Create an affine binary low precedence op expression (add, sub). 2668 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2669 AffineExpr lhs, AffineExpr rhs) { 2670 switch (op) { 2671 case AffineLowPrecOp::Add: 2672 return lhs + rhs; 2673 case AffineLowPrecOp::Sub: 2674 return lhs - rhs; 2675 case AffineLowPrecOp::LNoOp: 2676 llvm_unreachable("can't create affine expression for null low prec op"); 2677 return nullptr; 2678 } 2679 llvm_unreachable("Unknown AffineLowPrecOp"); 2680 } 2681 2682 /// Consume this token if it is a lower precedence affine op (there are only 2683 /// two precedence levels). 2684 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2685 switch (getToken().getKind()) { 2686 case Token::plus: 2687 consumeToken(Token::plus); 2688 return AffineLowPrecOp::Add; 2689 case Token::minus: 2690 consumeToken(Token::minus); 2691 return AffineLowPrecOp::Sub; 2692 default: 2693 return AffineLowPrecOp::LNoOp; 2694 } 2695 } 2696 2697 /// Consume this token if it is a higher precedence affine op (there are only 2698 /// two precedence levels) 2699 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2700 switch (getToken().getKind()) { 2701 case Token::star: 2702 consumeToken(Token::star); 2703 return Mul; 2704 case Token::kw_floordiv: 2705 consumeToken(Token::kw_floordiv); 2706 return FloorDiv; 2707 case Token::kw_ceildiv: 2708 consumeToken(Token::kw_ceildiv); 2709 return CeilDiv; 2710 case Token::kw_mod: 2711 consumeToken(Token::kw_mod); 2712 return Mod; 2713 default: 2714 return HNoOp; 2715 } 2716 } 2717 2718 /// Parse a high precedence op expression list: mul, div, and mod are high 2719 /// precedence binary ops, i.e., parse a 2720 /// expr_1 op_1 expr_2 op_2 ... expr_n 2721 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2722 /// All affine binary ops are left associative. 2723 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2724 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2725 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2726 /// report an error for non-conforming expressions. 2727 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2728 AffineHighPrecOp llhsOp, 2729 SMLoc llhsOpLoc) { 2730 AffineExpr lhs = parseAffineOperandExpr(llhs); 2731 if (!lhs) 2732 return nullptr; 2733 2734 // Found an LHS. Parse the remaining expression. 2735 auto opLoc = getToken().getLoc(); 2736 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2737 if (llhs) { 2738 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2739 if (!expr) 2740 return nullptr; 2741 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2742 } 2743 // No LLHS, get RHS 2744 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2745 } 2746 2747 // This is the last operand in this expression. 2748 if (llhs) 2749 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2750 2751 // No llhs, 'lhs' itself is the expression. 2752 return lhs; 2753 } 2754 2755 /// Parse an affine expression inside parentheses. 2756 /// 2757 /// affine-expr ::= `(` affine-expr `)` 2758 AffineExpr AffineParser::parseParentheticalExpr() { 2759 if (parseToken(Token::l_paren, "expected '('")) 2760 return nullptr; 2761 if (getToken().is(Token::r_paren)) 2762 return (emitError("no expression inside parentheses"), nullptr); 2763 2764 auto expr = parseAffineExpr(); 2765 if (!expr) 2766 return nullptr; 2767 if (parseToken(Token::r_paren, "expected ')'")) 2768 return nullptr; 2769 2770 return expr; 2771 } 2772 2773 /// Parse the negation expression. 2774 /// 2775 /// affine-expr ::= `-` affine-expr 2776 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2777 if (parseToken(Token::minus, "expected '-'")) 2778 return nullptr; 2779 2780 AffineExpr operand = parseAffineOperandExpr(lhs); 2781 // Since negation has the highest precedence of all ops (including high 2782 // precedence ops) but lower than parentheses, we are only going to use 2783 // parseAffineOperandExpr instead of parseAffineExpr here. 2784 if (!operand) 2785 // Extra error message although parseAffineOperandExpr would have 2786 // complained. Leads to a better diagnostic. 2787 return (emitError("missing operand of negation"), nullptr); 2788 return (-1) * operand; 2789 } 2790 2791 /// Parse a bare id that may appear in an affine expression. 2792 /// 2793 /// affine-expr ::= bare-id 2794 AffineExpr AffineParser::parseBareIdExpr() { 2795 if (getToken().isNot(Token::bare_identifier)) 2796 return (emitError("expected bare identifier"), nullptr); 2797 2798 StringRef sRef = getTokenSpelling(); 2799 for (auto entry : dimsAndSymbols) { 2800 if (entry.first == sRef) { 2801 consumeToken(Token::bare_identifier); 2802 return entry.second; 2803 } 2804 } 2805 2806 return (emitError("use of undeclared identifier"), nullptr); 2807 } 2808 2809 /// Parse an SSA id which may appear in an affine expression. 2810 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2811 if (!allowParsingSSAIds) 2812 return (emitError("unexpected ssa identifier"), nullptr); 2813 if (getToken().isNot(Token::percent_identifier)) 2814 return (emitError("expected ssa identifier"), nullptr); 2815 auto name = getTokenSpelling(); 2816 // Check if we already parsed this SSA id. 2817 for (auto entry : dimsAndSymbols) { 2818 if (entry.first == name) { 2819 consumeToken(Token::percent_identifier); 2820 return entry.second; 2821 } 2822 } 2823 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2824 if (parseElement(isSymbol)) 2825 return (emitError("failed to parse ssa identifier"), nullptr); 2826 auto idExpr = isSymbol 2827 ? getAffineSymbolExpr(numSymbolOperands++, getContext()) 2828 : getAffineDimExpr(numDimOperands++, getContext()); 2829 dimsAndSymbols.push_back({name, idExpr}); 2830 return idExpr; 2831 } 2832 2833 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2834 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2835 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2836 return nullptr; 2837 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2838 if (!symbolExpr) 2839 return nullptr; 2840 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2841 return nullptr; 2842 return symbolExpr; 2843 } 2844 2845 /// Parse a positive integral constant appearing in an affine expression. 2846 /// 2847 /// affine-expr ::= integer-literal 2848 AffineExpr AffineParser::parseIntegerExpr() { 2849 auto val = getToken().getUInt64IntegerValue(); 2850 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2851 return (emitError("constant too large for index"), nullptr); 2852 2853 consumeToken(Token::integer); 2854 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2855 } 2856 2857 /// Parses an expression that can be a valid operand of an affine expression. 2858 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2859 /// operator, the rhs of which is being parsed. This is used to determine 2860 /// whether an error should be emitted for a missing right operand. 2861 // Eg: for an expression without parentheses (like i + j + k + l), each 2862 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2863 // operand expression, it's an op expression and will be parsed via 2864 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2865 // -l are valid operands that will be parsed by this function. 2866 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2867 switch (getToken().getKind()) { 2868 case Token::bare_identifier: 2869 return parseBareIdExpr(); 2870 case Token::kw_symbol: 2871 return parseSymbolSSAIdExpr(); 2872 case Token::percent_identifier: 2873 return parseSSAIdExpr(/*isSymbol=*/false); 2874 case Token::integer: 2875 return parseIntegerExpr(); 2876 case Token::l_paren: 2877 return parseParentheticalExpr(); 2878 case Token::minus: 2879 return parseNegateExpression(lhs); 2880 case Token::kw_ceildiv: 2881 case Token::kw_floordiv: 2882 case Token::kw_mod: 2883 case Token::plus: 2884 case Token::star: 2885 if (lhs) 2886 emitError("missing right operand of binary operator"); 2887 else 2888 emitError("missing left operand of binary operator"); 2889 return nullptr; 2890 default: 2891 if (lhs) 2892 emitError("missing right operand of binary operator"); 2893 else 2894 emitError("expected affine expression"); 2895 return nullptr; 2896 } 2897 } 2898 2899 /// Parse affine expressions that are bare-id's, integer constants, 2900 /// parenthetical affine expressions, and affine op expressions that are a 2901 /// composition of those. 2902 /// 2903 /// All binary op's associate from left to right. 2904 /// 2905 /// {add, sub} have lower precedence than {mul, div, and mod}. 2906 /// 2907 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2908 /// ceildiv, and mod are at the same higher precedence level. Negation has 2909 /// higher precedence than any binary op. 2910 /// 2911 /// llhs: the affine expression appearing on the left of the one being parsed. 2912 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2913 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2914 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2915 /// associativity. 2916 /// 2917 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2918 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2919 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2920 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2921 AffineLowPrecOp llhsOp) { 2922 AffineExpr lhs; 2923 if (!(lhs = parseAffineOperandExpr(llhs))) 2924 return nullptr; 2925 2926 // Found an LHS. Deal with the ops. 2927 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2928 if (llhs) { 2929 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2930 return parseAffineLowPrecOpExpr(sum, lOp); 2931 } 2932 // No LLHS, get RHS and form the expression. 2933 return parseAffineLowPrecOpExpr(lhs, lOp); 2934 } 2935 auto opLoc = getToken().getLoc(); 2936 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2937 // We have a higher precedence op here. Get the rhs operand for the llhs 2938 // through parseAffineHighPrecOpExpr. 2939 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2940 if (!highRes) 2941 return nullptr; 2942 2943 // If llhs is null, the product forms the first operand of the yet to be 2944 // found expression. If non-null, the op to associate with llhs is llhsOp. 2945 AffineExpr expr = 2946 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2947 2948 // Recurse for subsequent low prec op's after the affine high prec op 2949 // expression. 2950 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2951 return parseAffineLowPrecOpExpr(expr, nextOp); 2952 return expr; 2953 } 2954 // Last operand in the expression list. 2955 if (llhs) 2956 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2957 // No llhs, 'lhs' itself is the expression. 2958 return lhs; 2959 } 2960 2961 /// Parse an affine expression. 2962 /// affine-expr ::= `(` affine-expr `)` 2963 /// | `-` affine-expr 2964 /// | affine-expr `+` affine-expr 2965 /// | affine-expr `-` affine-expr 2966 /// | affine-expr `*` affine-expr 2967 /// | affine-expr `floordiv` affine-expr 2968 /// | affine-expr `ceildiv` affine-expr 2969 /// | affine-expr `mod` affine-expr 2970 /// | bare-id 2971 /// | integer-literal 2972 /// 2973 /// Additional conditions are checked depending on the production. For eg., 2974 /// one of the operands for `*` has to be either constant/symbolic; the second 2975 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2976 AffineExpr AffineParser::parseAffineExpr() { 2977 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2978 } 2979 2980 /// Parse a dim or symbol from the lists appearing before the actual 2981 /// expressions of the affine map. Update our state to store the 2982 /// dimensional/symbolic identifier. 2983 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2984 if (getToken().isNot(Token::bare_identifier)) 2985 return emitError("expected bare identifier"); 2986 2987 auto name = getTokenSpelling(); 2988 for (auto entry : dimsAndSymbols) { 2989 if (entry.first == name) 2990 return emitError("redefinition of identifier '" + name + "'"); 2991 } 2992 consumeToken(Token::bare_identifier); 2993 2994 dimsAndSymbols.push_back({name, idExpr}); 2995 return success(); 2996 } 2997 2998 /// Parse the list of dimensional identifiers to an affine map. 2999 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 3000 if (parseToken(Token::l_paren, 3001 "expected '(' at start of dimensional identifiers list")) { 3002 return failure(); 3003 } 3004 3005 auto parseElt = [&]() -> ParseResult { 3006 auto dimension = getAffineDimExpr(numDims++, getContext()); 3007 return parseIdentifierDefinition(dimension); 3008 }; 3009 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 3010 } 3011 3012 /// Parse the list of symbolic identifiers to an affine map. 3013 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 3014 consumeToken(Token::l_square); 3015 auto parseElt = [&]() -> ParseResult { 3016 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 3017 return parseIdentifierDefinition(symbol); 3018 }; 3019 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 3020 } 3021 3022 /// Parse the list of symbolic identifiers to an affine map. 3023 ParseResult 3024 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 3025 unsigned &numSymbols) { 3026 if (parseDimIdList(numDims)) { 3027 return failure(); 3028 } 3029 if (!getToken().is(Token::l_square)) { 3030 numSymbols = 0; 3031 return success(); 3032 } 3033 return parseSymbolIdList(numSymbols); 3034 } 3035 3036 /// Parses an ambiguous affine map or integer set definition inline. 3037 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 3038 IntegerSet &set) { 3039 unsigned numDims = 0, numSymbols = 0; 3040 3041 // List of dimensional and optional symbol identifiers. 3042 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 3043 return failure(); 3044 } 3045 3046 // This is needed for parsing attributes as we wouldn't know whether we would 3047 // be parsing an integer set attribute or an affine map attribute. 3048 bool isArrow = getToken().is(Token::arrow); 3049 bool isColon = getToken().is(Token::colon); 3050 if (!isArrow && !isColon) { 3051 return emitError("expected '->' or ':'"); 3052 } else if (isArrow) { 3053 parseToken(Token::arrow, "expected '->' or '['"); 3054 map = parseAffineMapRange(numDims, numSymbols); 3055 return map ? success() : failure(); 3056 } else if (parseToken(Token::colon, "expected ':' or '['")) { 3057 return failure(); 3058 } 3059 3060 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 3061 return success(); 3062 3063 return failure(); 3064 } 3065 3066 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 3067 ParseResult 3068 AffineParser::parseAffineMapOfSSAIds(AffineMap &map, 3069 OpAsmParser::Delimiter delimiter) { 3070 Token::Kind rightToken; 3071 switch (delimiter) { 3072 case OpAsmParser::Delimiter::Square: 3073 if (parseToken(Token::l_square, "expected '['")) 3074 return failure(); 3075 rightToken = Token::r_square; 3076 break; 3077 case OpAsmParser::Delimiter::Paren: 3078 if (parseToken(Token::l_paren, "expected '('")) 3079 return failure(); 3080 rightToken = Token::r_paren; 3081 break; 3082 default: 3083 return emitError("unexpected delimiter"); 3084 } 3085 3086 SmallVector<AffineExpr, 4> exprs; 3087 auto parseElt = [&]() -> ParseResult { 3088 auto elt = parseAffineExpr(); 3089 exprs.push_back(elt); 3090 return elt ? success() : failure(); 3091 }; 3092 3093 // Parse a multi-dimensional affine expression (a comma-separated list of 3094 // 1-d affine expressions); the list can be empty. Grammar: 3095 // multi-dim-affine-expr ::= `(` `)` 3096 // | `(` affine-expr (`,` affine-expr)* `)` 3097 if (parseCommaSeparatedListUntil(rightToken, parseElt, 3098 /*allowEmptyList=*/true)) 3099 return failure(); 3100 // Parsed a valid affine map. 3101 if (exprs.empty()) 3102 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3103 getContext()); 3104 else 3105 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3106 exprs); 3107 return success(); 3108 } 3109 3110 /// Parse the range and sizes affine map definition inline. 3111 /// 3112 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 3113 /// 3114 /// multi-dim-affine-expr ::= `(` `)` 3115 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 3116 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 3117 unsigned numSymbols) { 3118 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 3119 3120 SmallVector<AffineExpr, 4> exprs; 3121 auto parseElt = [&]() -> ParseResult { 3122 auto elt = parseAffineExpr(); 3123 ParseResult res = elt ? success() : failure(); 3124 exprs.push_back(elt); 3125 return res; 3126 }; 3127 3128 // Parse a multi-dimensional affine expression (a comma-separated list of 3129 // 1-d affine expressions). Grammar: 3130 // multi-dim-affine-expr ::= `(` `)` 3131 // | `(` affine-expr (`,` affine-expr)* `)` 3132 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3133 return AffineMap(); 3134 3135 if (exprs.empty()) 3136 return AffineMap::get(numDims, numSymbols, getContext()); 3137 3138 // Parsed a valid affine map. 3139 return AffineMap::get(numDims, numSymbols, exprs); 3140 } 3141 3142 /// Parse an affine constraint. 3143 /// affine-constraint ::= affine-expr `>=` `0` 3144 /// | affine-expr `==` `0` 3145 /// 3146 /// isEq is set to true if the parsed constraint is an equality, false if it 3147 /// is an inequality (greater than or equal). 3148 /// 3149 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 3150 AffineExpr expr = parseAffineExpr(); 3151 if (!expr) 3152 return nullptr; 3153 3154 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 3155 getToken().is(Token::integer)) { 3156 auto dim = getToken().getUnsignedIntegerValue(); 3157 if (dim.hasValue() && dim.getValue() == 0) { 3158 consumeToken(Token::integer); 3159 *isEq = false; 3160 return expr; 3161 } 3162 return (emitError("expected '0' after '>='"), nullptr); 3163 } 3164 3165 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3166 getToken().is(Token::integer)) { 3167 auto dim = getToken().getUnsignedIntegerValue(); 3168 if (dim.hasValue() && dim.getValue() == 0) { 3169 consumeToken(Token::integer); 3170 *isEq = true; 3171 return expr; 3172 } 3173 return (emitError("expected '0' after '=='"), nullptr); 3174 } 3175 3176 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3177 nullptr); 3178 } 3179 3180 /// Parse the constraints that are part of an integer set definition. 3181 /// integer-set-inline 3182 /// ::= dim-and-symbol-id-lists `:` 3183 /// '(' affine-constraint-conjunction? ')' 3184 /// affine-constraint-conjunction ::= affine-constraint (`,` 3185 /// affine-constraint)* 3186 /// 3187 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3188 unsigned numSymbols) { 3189 if (parseToken(Token::l_paren, 3190 "expected '(' at start of integer set constraint list")) 3191 return IntegerSet(); 3192 3193 SmallVector<AffineExpr, 4> constraints; 3194 SmallVector<bool, 4> isEqs; 3195 auto parseElt = [&]() -> ParseResult { 3196 bool isEq; 3197 auto elt = parseAffineConstraint(&isEq); 3198 ParseResult res = elt ? success() : failure(); 3199 if (elt) { 3200 constraints.push_back(elt); 3201 isEqs.push_back(isEq); 3202 } 3203 return res; 3204 }; 3205 3206 // Parse a list of affine constraints (comma-separated). 3207 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3208 return IntegerSet(); 3209 3210 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3211 if (constraints.empty()) { 3212 /* 0 == 0 */ 3213 auto zero = getAffineConstantExpr(0, getContext()); 3214 return IntegerSet::get(numDims, numSymbols, zero, true); 3215 } 3216 3217 // Parsed a valid integer set. 3218 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3219 } 3220 3221 /// Parse an ambiguous reference to either and affine map or an integer set. 3222 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3223 IntegerSet &set) { 3224 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3225 } 3226 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3227 llvm::SMLoc curLoc = getToken().getLoc(); 3228 IntegerSet set; 3229 if (parseAffineMapOrIntegerSetReference(map, set)) 3230 return failure(); 3231 if (set) 3232 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3233 return success(); 3234 } 3235 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3236 llvm::SMLoc curLoc = getToken().getLoc(); 3237 AffineMap map; 3238 if (parseAffineMapOrIntegerSetReference(map, set)) 3239 return failure(); 3240 if (map) 3241 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3242 return success(); 3243 } 3244 3245 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3246 /// parse SSA value uses encountered while parsing affine expressions. 3247 ParseResult 3248 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3249 function_ref<ParseResult(bool)> parseElement, 3250 OpAsmParser::Delimiter delimiter) { 3251 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3252 .parseAffineMapOfSSAIds(map, delimiter); 3253 } 3254 3255 //===----------------------------------------------------------------------===// 3256 // OperationParser 3257 //===----------------------------------------------------------------------===// 3258 3259 namespace { 3260 /// This class provides support for parsing operations and regions of 3261 /// operations. 3262 class OperationParser : public Parser { 3263 public: 3264 OperationParser(ParserState &state, ModuleOp moduleOp) 3265 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3266 } 3267 3268 ~OperationParser(); 3269 3270 /// After parsing is finished, this function must be called to see if there 3271 /// are any remaining issues. 3272 ParseResult finalize(); 3273 3274 //===--------------------------------------------------------------------===// 3275 // SSA Value Handling 3276 //===--------------------------------------------------------------------===// 3277 3278 /// This represents a use of an SSA value in the program. The first two 3279 /// entries in the tuple are the name and result number of a reference. The 3280 /// third is the location of the reference, which is used in case this ends 3281 /// up being a use of an undefined value. 3282 struct SSAUseInfo { 3283 StringRef name; // Value name, e.g. %42 or %abc 3284 unsigned number; // Number, specified with #12 3285 SMLoc loc; // Location of first definition or use. 3286 }; 3287 3288 /// Push a new SSA name scope to the parser. 3289 void pushSSANameScope(bool isIsolated); 3290 3291 /// Pop the last SSA name scope from the parser. 3292 ParseResult popSSANameScope(); 3293 3294 /// Register a definition of a value with the symbol table. 3295 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3296 3297 /// Parse an optional list of SSA uses into 'results'. 3298 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3299 3300 /// Parse a single SSA use into 'result'. 3301 ParseResult parseSSAUse(SSAUseInfo &result); 3302 3303 /// Given a reference to an SSA value and its type, return a reference. This 3304 /// returns null on failure. 3305 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3306 3307 ParseResult parseSSADefOrUseAndType( 3308 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3309 3310 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3311 3312 /// Return the location of the value identified by its name and number if it 3313 /// has been already reference. 3314 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3315 auto &values = isolatedNameScopes.back().values; 3316 if (!values.count(name) || number >= values[name].size()) 3317 return {}; 3318 if (values[name][number].first) 3319 return values[name][number].second; 3320 return {}; 3321 } 3322 3323 //===--------------------------------------------------------------------===// 3324 // Operation Parsing 3325 //===--------------------------------------------------------------------===// 3326 3327 /// Parse an operation instance. 3328 ParseResult parseOperation(); 3329 3330 /// Parse a single operation successor. 3331 ParseResult parseSuccessor(Block *&dest); 3332 3333 /// Parse a comma-separated list of operation successors in brackets. 3334 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations); 3335 3336 /// Parse an operation instance that is in the generic form. 3337 Operation *parseGenericOperation(); 3338 3339 /// Parse an operation instance that is in the generic form and insert it at 3340 /// the provided insertion point. 3341 Operation *parseGenericOperation(Block *insertBlock, 3342 Block::iterator insertPt); 3343 3344 /// This is the structure of a result specifier in the assembly syntax, 3345 /// including the name, number of results, and location. 3346 typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord; 3347 3348 /// Parse an operation instance that is in the op-defined custom form. 3349 /// resultInfo specifies information about the "%name =" specifiers. 3350 Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo); 3351 3352 //===--------------------------------------------------------------------===// 3353 // Region Parsing 3354 //===--------------------------------------------------------------------===// 3355 3356 /// Parse a region into 'region' with the provided entry block arguments. 3357 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3358 /// isolated from those above. 3359 ParseResult parseRegion(Region ®ion, 3360 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3361 bool isIsolatedNameScope = false); 3362 3363 /// Parse a region body into 'region'. 3364 ParseResult parseRegionBody(Region ®ion); 3365 3366 //===--------------------------------------------------------------------===// 3367 // Block Parsing 3368 //===--------------------------------------------------------------------===// 3369 3370 /// Parse a new block into 'block'. 3371 ParseResult parseBlock(Block *&block); 3372 3373 /// Parse a list of operations into 'block'. 3374 ParseResult parseBlockBody(Block *block); 3375 3376 /// Parse a (possibly empty) list of block arguments. 3377 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3378 Block *owner); 3379 3380 /// Get the block with the specified name, creating it if it doesn't 3381 /// already exist. The location specified is the point of use, which allows 3382 /// us to diagnose references to blocks that are not defined precisely. 3383 Block *getBlockNamed(StringRef name, SMLoc loc); 3384 3385 /// Define the block with the specified name. Returns the Block* or nullptr in 3386 /// the case of redefinition. 3387 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3388 3389 private: 3390 /// Returns the info for a block at the current scope for the given name. 3391 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3392 return blocksByName.back()[name]; 3393 } 3394 3395 /// Insert a new forward reference to the given block. 3396 void insertForwardRef(Block *block, SMLoc loc) { 3397 forwardRef.back().try_emplace(block, loc); 3398 } 3399 3400 /// Erase any forward reference to the given block. 3401 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3402 3403 /// Record that a definition was added at the current scope. 3404 void recordDefinition(StringRef def); 3405 3406 /// Get the value entry for the given SSA name. 3407 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3408 3409 /// Create a forward reference placeholder value with the given location and 3410 /// result type. 3411 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3412 3413 /// Return true if this is a forward reference. 3414 bool isForwardRefPlaceholder(Value value) { 3415 return forwardRefPlaceholders.count(value); 3416 } 3417 3418 /// This struct represents an isolated SSA name scope. This scope may contain 3419 /// other nested non-isolated scopes. These scopes are used for operations 3420 /// that are known to be isolated to allow for reusing names within their 3421 /// regions, even if those names are used above. 3422 struct IsolatedSSANameScope { 3423 /// Record that a definition was added at the current scope. 3424 void recordDefinition(StringRef def) { 3425 definitionsPerScope.back().insert(def); 3426 } 3427 3428 /// Push a nested name scope. 3429 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3430 3431 /// Pop a nested name scope. 3432 void popSSANameScope() { 3433 for (auto &def : definitionsPerScope.pop_back_val()) 3434 values.erase(def.getKey()); 3435 } 3436 3437 /// This keeps track of all of the SSA values we are tracking for each name 3438 /// scope, indexed by their name. This has one entry per result number. 3439 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3440 3441 /// This keeps track of all of the values defined by a specific name scope. 3442 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3443 }; 3444 3445 /// A list of isolated name scopes. 3446 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3447 3448 /// This keeps track of the block names as well as the location of the first 3449 /// reference for each nested name scope. This is used to diagnose invalid 3450 /// block references and memorize them. 3451 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3452 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3453 3454 /// These are all of the placeholders we've made along with the location of 3455 /// their first reference, to allow checking for use of undefined values. 3456 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3457 3458 /// The builder used when creating parsed operation instances. 3459 OpBuilder opBuilder; 3460 3461 /// The top level module operation. 3462 ModuleOp moduleOp; 3463 }; 3464 } // end anonymous namespace 3465 3466 OperationParser::~OperationParser() { 3467 for (auto &fwd : forwardRefPlaceholders) { 3468 // Drop all uses of undefined forward declared reference and destroy 3469 // defining operation. 3470 fwd.first.dropAllUses(); 3471 fwd.first.getDefiningOp()->destroy(); 3472 } 3473 } 3474 3475 /// After parsing is finished, this function must be called to see if there are 3476 /// any remaining issues. 3477 ParseResult OperationParser::finalize() { 3478 // Check for any forward references that are left. If we find any, error 3479 // out. 3480 if (!forwardRefPlaceholders.empty()) { 3481 SmallVector<std::pair<const char *, Value>, 4> errors; 3482 // Iteration over the map isn't deterministic, so sort by source location. 3483 for (auto entry : forwardRefPlaceholders) 3484 errors.push_back({entry.second.getPointer(), entry.first}); 3485 llvm::array_pod_sort(errors.begin(), errors.end()); 3486 3487 for (auto entry : errors) { 3488 auto loc = SMLoc::getFromPointer(entry.first); 3489 emitError(loc, "use of undeclared SSA value name"); 3490 } 3491 return failure(); 3492 } 3493 3494 return success(); 3495 } 3496 3497 //===----------------------------------------------------------------------===// 3498 // SSA Value Handling 3499 //===----------------------------------------------------------------------===// 3500 3501 void OperationParser::pushSSANameScope(bool isIsolated) { 3502 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3503 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3504 3505 // Push back a new name definition scope. 3506 if (isIsolated) 3507 isolatedNameScopes.push_back({}); 3508 isolatedNameScopes.back().pushSSANameScope(); 3509 } 3510 3511 ParseResult OperationParser::popSSANameScope() { 3512 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3513 3514 // Verify that all referenced blocks were defined. 3515 if (!forwardRefInCurrentScope.empty()) { 3516 SmallVector<std::pair<const char *, Block *>, 4> errors; 3517 // Iteration over the map isn't deterministic, so sort by source location. 3518 for (auto entry : forwardRefInCurrentScope) { 3519 errors.push_back({entry.second.getPointer(), entry.first}); 3520 // Add this block to the top-level region to allow for automatic cleanup. 3521 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3522 } 3523 llvm::array_pod_sort(errors.begin(), errors.end()); 3524 3525 for (auto entry : errors) { 3526 auto loc = SMLoc::getFromPointer(entry.first); 3527 emitError(loc, "reference to an undefined block"); 3528 } 3529 return failure(); 3530 } 3531 3532 // Pop the next nested namescope. If there is only one internal namescope, 3533 // just pop the isolated scope. 3534 auto ¤tNameScope = isolatedNameScopes.back(); 3535 if (currentNameScope.definitionsPerScope.size() == 1) 3536 isolatedNameScopes.pop_back(); 3537 else 3538 currentNameScope.popSSANameScope(); 3539 3540 blocksByName.pop_back(); 3541 return success(); 3542 } 3543 3544 /// Register a definition of a value with the symbol table. 3545 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3546 auto &entries = getSSAValueEntry(useInfo.name); 3547 3548 // Make sure there is a slot for this value. 3549 if (entries.size() <= useInfo.number) 3550 entries.resize(useInfo.number + 1); 3551 3552 // If we already have an entry for this, check to see if it was a definition 3553 // or a forward reference. 3554 if (auto existing = entries[useInfo.number].first) { 3555 if (!isForwardRefPlaceholder(existing)) { 3556 return emitError(useInfo.loc) 3557 .append("redefinition of SSA value '", useInfo.name, "'") 3558 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3559 .append("previously defined here"); 3560 } 3561 3562 // If it was a forward reference, update everything that used it to use 3563 // the actual definition instead, delete the forward ref, and remove it 3564 // from our set of forward references we track. 3565 existing.replaceAllUsesWith(value); 3566 existing.getDefiningOp()->destroy(); 3567 forwardRefPlaceholders.erase(existing); 3568 } 3569 3570 /// Record this definition for the current scope. 3571 entries[useInfo.number] = {value, useInfo.loc}; 3572 recordDefinition(useInfo.name); 3573 return success(); 3574 } 3575 3576 /// Parse a (possibly empty) list of SSA operands. 3577 /// 3578 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3579 /// ssa-use-list-opt ::= ssa-use-list? 3580 /// 3581 ParseResult 3582 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3583 if (getToken().isNot(Token::percent_identifier)) 3584 return success(); 3585 return parseCommaSeparatedList([&]() -> ParseResult { 3586 SSAUseInfo result; 3587 if (parseSSAUse(result)) 3588 return failure(); 3589 results.push_back(result); 3590 return success(); 3591 }); 3592 } 3593 3594 /// Parse a SSA operand for an operation. 3595 /// 3596 /// ssa-use ::= ssa-id 3597 /// 3598 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3599 result.name = getTokenSpelling(); 3600 result.number = 0; 3601 result.loc = getToken().getLoc(); 3602 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3603 return failure(); 3604 3605 // If we have an attribute ID, it is a result number. 3606 if (getToken().is(Token::hash_identifier)) { 3607 if (auto value = getToken().getHashIdentifierNumber()) 3608 result.number = value.getValue(); 3609 else 3610 return emitError("invalid SSA value result number"); 3611 consumeToken(Token::hash_identifier); 3612 } 3613 3614 return success(); 3615 } 3616 3617 /// Given an unbound reference to an SSA value and its type, return the value 3618 /// it specifies. This returns null on failure. 3619 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3620 auto &entries = getSSAValueEntry(useInfo.name); 3621 3622 // If we have already seen a value of this name, return it. 3623 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3624 auto result = entries[useInfo.number].first; 3625 // Check that the type matches the other uses. 3626 if (result.getType() == type) 3627 return result; 3628 3629 emitError(useInfo.loc, "use of value '") 3630 .append(useInfo.name, 3631 "' expects different type than prior uses: ", type, " vs ", 3632 result.getType()) 3633 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3634 .append("prior use here"); 3635 return nullptr; 3636 } 3637 3638 // Make sure we have enough slots for this. 3639 if (entries.size() <= useInfo.number) 3640 entries.resize(useInfo.number + 1); 3641 3642 // If the value has already been defined and this is an overly large result 3643 // number, diagnose that. 3644 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3645 return (emitError(useInfo.loc, "reference to invalid result number"), 3646 nullptr); 3647 3648 // Otherwise, this is a forward reference. Create a placeholder and remember 3649 // that we did so. 3650 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3651 entries[useInfo.number].first = result; 3652 entries[useInfo.number].second = useInfo.loc; 3653 return result; 3654 } 3655 3656 /// Parse an SSA use with an associated type. 3657 /// 3658 /// ssa-use-and-type ::= ssa-use `:` type 3659 ParseResult OperationParser::parseSSADefOrUseAndType( 3660 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3661 SSAUseInfo useInfo; 3662 if (parseSSAUse(useInfo) || 3663 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3664 return failure(); 3665 3666 auto type = parseType(); 3667 if (!type) 3668 return failure(); 3669 3670 return action(useInfo, type); 3671 } 3672 3673 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3674 /// followed by a type list. 3675 /// 3676 /// ssa-use-and-type-list 3677 /// ::= ssa-use-list ':' type-list-no-parens 3678 /// 3679 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3680 SmallVectorImpl<Value> &results) { 3681 SmallVector<SSAUseInfo, 4> valueIDs; 3682 if (parseOptionalSSAUseList(valueIDs)) 3683 return failure(); 3684 3685 // If there were no operands, then there is no colon or type lists. 3686 if (valueIDs.empty()) 3687 return success(); 3688 3689 SmallVector<Type, 4> types; 3690 if (parseToken(Token::colon, "expected ':' in operand list") || 3691 parseTypeListNoParens(types)) 3692 return failure(); 3693 3694 if (valueIDs.size() != types.size()) 3695 return emitError("expected ") 3696 << valueIDs.size() << " types to match operand list"; 3697 3698 results.reserve(valueIDs.size()); 3699 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3700 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3701 results.push_back(value); 3702 else 3703 return failure(); 3704 } 3705 3706 return success(); 3707 } 3708 3709 /// Record that a definition was added at the current scope. 3710 void OperationParser::recordDefinition(StringRef def) { 3711 isolatedNameScopes.back().recordDefinition(def); 3712 } 3713 3714 /// Get the value entry for the given SSA name. 3715 SmallVectorImpl<std::pair<Value, SMLoc>> & 3716 OperationParser::getSSAValueEntry(StringRef name) { 3717 return isolatedNameScopes.back().values[name]; 3718 } 3719 3720 /// Create and remember a new placeholder for a forward reference. 3721 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3722 // Forward references are always created as operations, because we just need 3723 // something with a def/use chain. 3724 // 3725 // We create these placeholders as having an empty name, which we know 3726 // cannot be created through normal user input, allowing us to distinguish 3727 // them. 3728 auto name = OperationName("placeholder", getContext()); 3729 auto *op = Operation::create( 3730 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3731 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3732 /*resizableOperandList=*/false); 3733 forwardRefPlaceholders[op->getResult(0)] = loc; 3734 return op->getResult(0); 3735 } 3736 3737 //===----------------------------------------------------------------------===// 3738 // Operation Parsing 3739 //===----------------------------------------------------------------------===// 3740 3741 /// Parse an operation. 3742 /// 3743 /// operation ::= op-result-list? 3744 /// (generic-operation | custom-operation) 3745 /// trailing-location? 3746 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3747 /// successor-list? (`(` region-list `)`)? 3748 /// attribute-dict? `:` function-type 3749 /// custom-operation ::= bare-id custom-operation-format 3750 /// op-result-list ::= op-result (`,` op-result)* `=` 3751 /// op-result ::= ssa-id (`:` integer-literal) 3752 /// 3753 ParseResult OperationParser::parseOperation() { 3754 auto loc = getToken().getLoc(); 3755 SmallVector<ResultRecord, 1> resultIDs; 3756 size_t numExpectedResults = 0; 3757 if (getToken().is(Token::percent_identifier)) { 3758 // Parse the group of result ids. 3759 auto parseNextResult = [&]() -> ParseResult { 3760 // Parse the next result id. 3761 if (!getToken().is(Token::percent_identifier)) 3762 return emitError("expected valid ssa identifier"); 3763 3764 Token nameTok = getToken(); 3765 consumeToken(Token::percent_identifier); 3766 3767 // If the next token is a ':', we parse the expected result count. 3768 size_t expectedSubResults = 1; 3769 if (consumeIf(Token::colon)) { 3770 // Check that the next token is an integer. 3771 if (!getToken().is(Token::integer)) 3772 return emitError("expected integer number of results"); 3773 3774 // Check that number of results is > 0. 3775 auto val = getToken().getUInt64IntegerValue(); 3776 if (!val.hasValue() || val.getValue() < 1) 3777 return emitError("expected named operation to have atleast 1 result"); 3778 consumeToken(Token::integer); 3779 expectedSubResults = *val; 3780 } 3781 3782 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3783 nameTok.getLoc()); 3784 numExpectedResults += expectedSubResults; 3785 return success(); 3786 }; 3787 if (parseCommaSeparatedList(parseNextResult)) 3788 return failure(); 3789 3790 if (parseToken(Token::equal, "expected '=' after SSA name")) 3791 return failure(); 3792 } 3793 3794 Operation *op; 3795 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3796 op = parseCustomOperation(resultIDs); 3797 else if (getToken().is(Token::string)) 3798 op = parseGenericOperation(); 3799 else 3800 return emitError("expected operation name in quotes"); 3801 3802 // If parsing of the basic operation failed, then this whole thing fails. 3803 if (!op) 3804 return failure(); 3805 3806 // If the operation had a name, register it. 3807 if (!resultIDs.empty()) { 3808 if (op->getNumResults() == 0) 3809 return emitError(loc, "cannot name an operation with no results"); 3810 if (numExpectedResults != op->getNumResults()) 3811 return emitError(loc, "operation defines ") 3812 << op->getNumResults() << " results but was provided " 3813 << numExpectedResults << " to bind"; 3814 3815 // Add definitions for each of the result groups. 3816 unsigned opResI = 0; 3817 for (ResultRecord &resIt : resultIDs) { 3818 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3819 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3820 op->getResult(opResI++))) 3821 return failure(); 3822 } 3823 } 3824 } 3825 3826 return success(); 3827 } 3828 3829 /// Parse a single operation successor. 3830 /// 3831 /// successor ::= block-id 3832 /// 3833 ParseResult OperationParser::parseSuccessor(Block *&dest) { 3834 // Verify branch is identifier and get the matching block. 3835 if (!getToken().is(Token::caret_identifier)) 3836 return emitError("expected block name"); 3837 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3838 consumeToken(); 3839 return success(); 3840 } 3841 3842 /// Parse a comma-separated list of operation successors in brackets. 3843 /// 3844 /// successor-list ::= `[` successor (`,` successor )* `]` 3845 /// 3846 ParseResult 3847 OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) { 3848 if (parseToken(Token::l_square, "expected '['")) 3849 return failure(); 3850 3851 auto parseElt = [this, &destinations] { 3852 Block *dest; 3853 ParseResult res = parseSuccessor(dest); 3854 destinations.push_back(dest); 3855 return res; 3856 }; 3857 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3858 /*allowEmptyList=*/false); 3859 } 3860 3861 namespace { 3862 // RAII-style guard for cleaning up the regions in the operation state before 3863 // deleting them. Within the parser, regions may get deleted if parsing failed, 3864 // and other errors may be present, in particular undominated uses. This makes 3865 // sure such uses are deleted. 3866 struct CleanupOpStateRegions { 3867 ~CleanupOpStateRegions() { 3868 SmallVector<Region *, 4> regionsToClean; 3869 regionsToClean.reserve(state.regions.size()); 3870 for (auto ®ion : state.regions) 3871 if (region) 3872 for (auto &block : *region) 3873 block.dropAllDefinedValueUses(); 3874 } 3875 OperationState &state; 3876 }; 3877 } // namespace 3878 3879 Operation *OperationParser::parseGenericOperation() { 3880 // Get location information for the operation. 3881 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3882 3883 auto name = getToken().getStringValue(); 3884 if (name.empty()) 3885 return (emitError("empty operation name is invalid"), nullptr); 3886 if (name.find('\0') != StringRef::npos) 3887 return (emitError("null character not allowed in operation name"), nullptr); 3888 3889 consumeToken(Token::string); 3890 3891 OperationState result(srcLocation, name); 3892 3893 // Generic operations have a resizable operation list. 3894 result.setOperandListToResizable(); 3895 3896 // Parse the operand list. 3897 SmallVector<SSAUseInfo, 8> operandInfos; 3898 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3899 parseOptionalSSAUseList(operandInfos) || 3900 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3901 return nullptr; 3902 } 3903 3904 // Parse the successor list. 3905 if (getToken().is(Token::l_square)) { 3906 // Check if the operation is a known terminator. 3907 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3908 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3909 return emitError("successors in non-terminator"), nullptr; 3910 3911 SmallVector<Block *, 2> successors; 3912 if (parseSuccessors(successors)) 3913 return nullptr; 3914 result.addSuccessors(successors); 3915 } 3916 3917 // Parse the region list. 3918 CleanupOpStateRegions guard{result}; 3919 if (consumeIf(Token::l_paren)) { 3920 do { 3921 // Create temporary regions with the top level region as parent. 3922 result.regions.emplace_back(new Region(moduleOp)); 3923 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3924 return nullptr; 3925 } while (consumeIf(Token::comma)); 3926 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3927 return nullptr; 3928 } 3929 3930 if (getToken().is(Token::l_brace)) { 3931 if (parseAttributeDict(result.attributes)) 3932 return nullptr; 3933 } 3934 3935 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3936 return nullptr; 3937 3938 auto typeLoc = getToken().getLoc(); 3939 auto type = parseType(); 3940 if (!type) 3941 return nullptr; 3942 auto fnType = type.dyn_cast<FunctionType>(); 3943 if (!fnType) 3944 return (emitError(typeLoc, "expected function type"), nullptr); 3945 3946 result.addTypes(fnType.getResults()); 3947 3948 // Check that we have the right number of types for the operands. 3949 auto operandTypes = fnType.getInputs(); 3950 if (operandTypes.size() != operandInfos.size()) { 3951 auto plural = "s"[operandInfos.size() == 1]; 3952 return (emitError(typeLoc, "expected ") 3953 << operandInfos.size() << " operand type" << plural 3954 << " but had " << operandTypes.size(), 3955 nullptr); 3956 } 3957 3958 // Resolve all of the operands. 3959 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3960 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3961 if (!result.operands.back()) 3962 return nullptr; 3963 } 3964 3965 // Parse a location if one is present. 3966 if (parseOptionalTrailingLocation(result.location)) 3967 return nullptr; 3968 3969 return opBuilder.createOperation(result); 3970 } 3971 3972 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3973 Block::iterator insertPt) { 3974 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3975 opBuilder.setInsertionPoint(insertBlock, insertPt); 3976 return parseGenericOperation(); 3977 } 3978 3979 namespace { 3980 class CustomOpAsmParser : public OpAsmParser { 3981 public: 3982 CustomOpAsmParser(SMLoc nameLoc, 3983 ArrayRef<OperationParser::ResultRecord> resultIDs, 3984 const AbstractOperation *opDefinition, 3985 OperationParser &parser) 3986 : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition), 3987 parser(parser) {} 3988 3989 /// Parse an instance of the operation described by 'opDefinition' into the 3990 /// provided operation state. 3991 ParseResult parseOperation(OperationState &opState) { 3992 if (opDefinition->parseAssembly(*this, opState)) 3993 return failure(); 3994 return success(); 3995 } 3996 3997 Operation *parseGenericOperation(Block *insertBlock, 3998 Block::iterator insertPt) final { 3999 return parser.parseGenericOperation(insertBlock, insertPt); 4000 } 4001 4002 //===--------------------------------------------------------------------===// 4003 // Utilities 4004 //===--------------------------------------------------------------------===// 4005 4006 /// Return if any errors were emitted during parsing. 4007 bool didEmitError() const { return emittedError; } 4008 4009 /// Emit a diagnostic at the specified location and return failure. 4010 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 4011 emittedError = true; 4012 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 4013 message); 4014 } 4015 4016 llvm::SMLoc getCurrentLocation() override { 4017 return parser.getToken().getLoc(); 4018 } 4019 4020 Builder &getBuilder() const override { return parser.builder; } 4021 4022 /// Return the name of the specified result in the specified syntax, as well 4023 /// as the subelement in the name. For example, in this operation: 4024 /// 4025 /// %x, %y:2, %z = foo.op 4026 /// 4027 /// getResultName(0) == {"x", 0 } 4028 /// getResultName(1) == {"y", 0 } 4029 /// getResultName(2) == {"y", 1 } 4030 /// getResultName(3) == {"z", 0 } 4031 std::pair<StringRef, unsigned> 4032 getResultName(unsigned resultNo) const override { 4033 // Scan for the resultID that contains this result number. 4034 for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) { 4035 const auto &entry = resultIDs[nameID]; 4036 if (resultNo < std::get<1>(entry)) { 4037 // Don't pass on the leading %. 4038 StringRef name = std::get<0>(entry).drop_front(); 4039 return {name, resultNo}; 4040 } 4041 resultNo -= std::get<1>(entry); 4042 } 4043 4044 // Invalid result number. 4045 return {"", ~0U}; 4046 } 4047 4048 /// Return the number of declared SSA results. This returns 4 for the foo.op 4049 /// example in the comment for getResultName. 4050 size_t getNumResults() const override { 4051 size_t count = 0; 4052 for (auto &entry : resultIDs) 4053 count += std::get<1>(entry); 4054 return count; 4055 } 4056 4057 llvm::SMLoc getNameLoc() const override { return nameLoc; } 4058 4059 //===--------------------------------------------------------------------===// 4060 // Token Parsing 4061 //===--------------------------------------------------------------------===// 4062 4063 /// Parse a `->` token. 4064 ParseResult parseArrow() override { 4065 return parser.parseToken(Token::arrow, "expected '->'"); 4066 } 4067 4068 /// Parses a `->` if present. 4069 ParseResult parseOptionalArrow() override { 4070 return success(parser.consumeIf(Token::arrow)); 4071 } 4072 4073 /// Parse a `:` token. 4074 ParseResult parseColon() override { 4075 return parser.parseToken(Token::colon, "expected ':'"); 4076 } 4077 4078 /// Parse a `:` token if present. 4079 ParseResult parseOptionalColon() override { 4080 return success(parser.consumeIf(Token::colon)); 4081 } 4082 4083 /// Parse a `,` token. 4084 ParseResult parseComma() override { 4085 return parser.parseToken(Token::comma, "expected ','"); 4086 } 4087 4088 /// Parse a `,` token if present. 4089 ParseResult parseOptionalComma() override { 4090 return success(parser.consumeIf(Token::comma)); 4091 } 4092 4093 /// Parses a `...` if present. 4094 ParseResult parseOptionalEllipsis() override { 4095 return success(parser.consumeIf(Token::ellipsis)); 4096 } 4097 4098 /// Parse a `=` token. 4099 ParseResult parseEqual() override { 4100 return parser.parseToken(Token::equal, "expected '='"); 4101 } 4102 4103 /// Parse a '<' token. 4104 ParseResult parseLess() override { 4105 return parser.parseToken(Token::less, "expected '<'"); 4106 } 4107 4108 /// Parse a '>' token. 4109 ParseResult parseGreater() override { 4110 return parser.parseToken(Token::greater, "expected '>'"); 4111 } 4112 4113 /// Parse a `(` token. 4114 ParseResult parseLParen() override { 4115 return parser.parseToken(Token::l_paren, "expected '('"); 4116 } 4117 4118 /// Parses a '(' if present. 4119 ParseResult parseOptionalLParen() override { 4120 return success(parser.consumeIf(Token::l_paren)); 4121 } 4122 4123 /// Parse a `)` token. 4124 ParseResult parseRParen() override { 4125 return parser.parseToken(Token::r_paren, "expected ')'"); 4126 } 4127 4128 /// Parses a ')' if present. 4129 ParseResult parseOptionalRParen() override { 4130 return success(parser.consumeIf(Token::r_paren)); 4131 } 4132 4133 /// Parse a `[` token. 4134 ParseResult parseLSquare() override { 4135 return parser.parseToken(Token::l_square, "expected '['"); 4136 } 4137 4138 /// Parses a '[' if present. 4139 ParseResult parseOptionalLSquare() override { 4140 return success(parser.consumeIf(Token::l_square)); 4141 } 4142 4143 /// Parse a `]` token. 4144 ParseResult parseRSquare() override { 4145 return parser.parseToken(Token::r_square, "expected ']'"); 4146 } 4147 4148 /// Parses a ']' if present. 4149 ParseResult parseOptionalRSquare() override { 4150 return success(parser.consumeIf(Token::r_square)); 4151 } 4152 4153 //===--------------------------------------------------------------------===// 4154 // Attribute Parsing 4155 //===--------------------------------------------------------------------===// 4156 4157 /// Parse an arbitrary attribute of a given type and return it in result. This 4158 /// also adds the attribute to the specified attribute list with the specified 4159 /// name. 4160 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 4161 SmallVectorImpl<NamedAttribute> &attrs) override { 4162 result = parser.parseAttribute(type); 4163 if (!result) 4164 return failure(); 4165 4166 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 4167 return success(); 4168 } 4169 4170 /// Parse a named dictionary into 'result' if it is present. 4171 ParseResult 4172 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 4173 if (parser.getToken().isNot(Token::l_brace)) 4174 return success(); 4175 return parser.parseAttributeDict(result); 4176 } 4177 4178 /// Parse a named dictionary into 'result' if the `attributes` keyword is 4179 /// present. 4180 ParseResult parseOptionalAttrDictWithKeyword( 4181 SmallVectorImpl<NamedAttribute> &result) override { 4182 if (failed(parseOptionalKeyword("attributes"))) 4183 return success(); 4184 return parser.parseAttributeDict(result); 4185 } 4186 4187 /// Parse an affine map instance into 'map'. 4188 ParseResult parseAffineMap(AffineMap &map) override { 4189 return parser.parseAffineMapReference(map); 4190 } 4191 4192 /// Parse an integer set instance into 'set'. 4193 ParseResult printIntegerSet(IntegerSet &set) override { 4194 return parser.parseIntegerSetReference(set); 4195 } 4196 4197 //===--------------------------------------------------------------------===// 4198 // Identifier Parsing 4199 //===--------------------------------------------------------------------===// 4200 4201 /// Returns if the current token corresponds to a keyword. 4202 bool isCurrentTokenAKeyword() const { 4203 return parser.getToken().is(Token::bare_identifier) || 4204 parser.getToken().isKeyword(); 4205 } 4206 4207 /// Parse the given keyword if present. 4208 ParseResult parseOptionalKeyword(StringRef keyword) override { 4209 // Check that the current token has the same spelling. 4210 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4211 return failure(); 4212 parser.consumeToken(); 4213 return success(); 4214 } 4215 4216 /// Parse a keyword, if present, into 'keyword'. 4217 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4218 // Check that the current token is a keyword. 4219 if (!isCurrentTokenAKeyword()) 4220 return failure(); 4221 4222 *keyword = parser.getTokenSpelling(); 4223 parser.consumeToken(); 4224 return success(); 4225 } 4226 4227 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4228 /// string attribute named 'attrName'. 4229 ParseResult 4230 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4231 SmallVectorImpl<NamedAttribute> &attrs) override { 4232 Token atToken = parser.getToken(); 4233 if (atToken.isNot(Token::at_identifier)) 4234 return failure(); 4235 4236 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4237 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4238 parser.consumeToken(); 4239 return success(); 4240 } 4241 4242 //===--------------------------------------------------------------------===// 4243 // Operand Parsing 4244 //===--------------------------------------------------------------------===// 4245 4246 /// Parse a single operand. 4247 ParseResult parseOperand(OperandType &result) override { 4248 OperationParser::SSAUseInfo useInfo; 4249 if (parser.parseSSAUse(useInfo)) 4250 return failure(); 4251 4252 result = {useInfo.loc, useInfo.name, useInfo.number}; 4253 return success(); 4254 } 4255 4256 /// Parse a single operand if present. 4257 OptionalParseResult parseOptionalOperand(OperandType &result) override { 4258 if (parser.getToken().is(Token::percent_identifier)) 4259 return parseOperand(result); 4260 return llvm::None; 4261 } 4262 4263 /// Parse zero or more SSA comma-separated operand references with a specified 4264 /// surrounding delimiter, and an optional required operand count. 4265 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4266 int requiredOperandCount = -1, 4267 Delimiter delimiter = Delimiter::None) override { 4268 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4269 requiredOperandCount, delimiter); 4270 } 4271 4272 /// Parse zero or more SSA comma-separated operand or region arguments with 4273 /// optional surrounding delimiter and required operand count. 4274 ParseResult 4275 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4276 bool isOperandList, int requiredOperandCount = -1, 4277 Delimiter delimiter = Delimiter::None) { 4278 auto startLoc = parser.getToken().getLoc(); 4279 4280 // Handle delimiters. 4281 switch (delimiter) { 4282 case Delimiter::None: 4283 // Don't check for the absence of a delimiter if the number of operands 4284 // is unknown (and hence the operand list could be empty). 4285 if (requiredOperandCount == -1) 4286 break; 4287 // Token already matches an identifier and so can't be a delimiter. 4288 if (parser.getToken().is(Token::percent_identifier)) 4289 break; 4290 // Test against known delimiters. 4291 if (parser.getToken().is(Token::l_paren) || 4292 parser.getToken().is(Token::l_square)) 4293 return emitError(startLoc, "unexpected delimiter"); 4294 return emitError(startLoc, "invalid operand"); 4295 case Delimiter::OptionalParen: 4296 if (parser.getToken().isNot(Token::l_paren)) 4297 return success(); 4298 LLVM_FALLTHROUGH; 4299 case Delimiter::Paren: 4300 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4301 return failure(); 4302 break; 4303 case Delimiter::OptionalSquare: 4304 if (parser.getToken().isNot(Token::l_square)) 4305 return success(); 4306 LLVM_FALLTHROUGH; 4307 case Delimiter::Square: 4308 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4309 return failure(); 4310 break; 4311 } 4312 4313 // Check for zero operands. 4314 if (parser.getToken().is(Token::percent_identifier)) { 4315 do { 4316 OperandType operandOrArg; 4317 if (isOperandList ? parseOperand(operandOrArg) 4318 : parseRegionArgument(operandOrArg)) 4319 return failure(); 4320 result.push_back(operandOrArg); 4321 } while (parser.consumeIf(Token::comma)); 4322 } 4323 4324 // Handle delimiters. If we reach here, the optional delimiters were 4325 // present, so we need to parse their closing one. 4326 switch (delimiter) { 4327 case Delimiter::None: 4328 break; 4329 case Delimiter::OptionalParen: 4330 case Delimiter::Paren: 4331 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4332 return failure(); 4333 break; 4334 case Delimiter::OptionalSquare: 4335 case Delimiter::Square: 4336 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4337 return failure(); 4338 break; 4339 } 4340 4341 if (requiredOperandCount != -1 && 4342 result.size() != static_cast<size_t>(requiredOperandCount)) 4343 return emitError(startLoc, "expected ") 4344 << requiredOperandCount << " operands"; 4345 return success(); 4346 } 4347 4348 /// Parse zero or more trailing SSA comma-separated trailing operand 4349 /// references with a specified surrounding delimiter, and an optional 4350 /// required operand count. A leading comma is expected before the operands. 4351 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4352 int requiredOperandCount, 4353 Delimiter delimiter) override { 4354 if (parser.getToken().is(Token::comma)) { 4355 parseComma(); 4356 return parseOperandList(result, requiredOperandCount, delimiter); 4357 } 4358 if (requiredOperandCount != -1) 4359 return emitError(parser.getToken().getLoc(), "expected ") 4360 << requiredOperandCount << " operands"; 4361 return success(); 4362 } 4363 4364 /// Resolve an operand to an SSA value, emitting an error on failure. 4365 ParseResult resolveOperand(const OperandType &operand, Type type, 4366 SmallVectorImpl<Value> &result) override { 4367 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4368 operand.location}; 4369 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4370 result.push_back(value); 4371 return success(); 4372 } 4373 return failure(); 4374 } 4375 4376 /// Parse an AffineMap of SSA ids. 4377 ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4378 Attribute &mapAttr, StringRef attrName, 4379 SmallVectorImpl<NamedAttribute> &attrs, 4380 Delimiter delimiter) override { 4381 SmallVector<OperandType, 2> dimOperands; 4382 SmallVector<OperandType, 1> symOperands; 4383 4384 auto parseElement = [&](bool isSymbol) -> ParseResult { 4385 OperandType operand; 4386 if (parseOperand(operand)) 4387 return failure(); 4388 if (isSymbol) 4389 symOperands.push_back(operand); 4390 else 4391 dimOperands.push_back(operand); 4392 return success(); 4393 }; 4394 4395 AffineMap map; 4396 if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) 4397 return failure(); 4398 // Add AffineMap attribute. 4399 if (map) { 4400 mapAttr = AffineMapAttr::get(map); 4401 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4402 } 4403 4404 // Add dim operands before symbol operands in 'operands'. 4405 operands.assign(dimOperands.begin(), dimOperands.end()); 4406 operands.append(symOperands.begin(), symOperands.end()); 4407 return success(); 4408 } 4409 4410 //===--------------------------------------------------------------------===// 4411 // Region Parsing 4412 //===--------------------------------------------------------------------===// 4413 4414 /// Parse a region that takes `arguments` of `argTypes` types. This 4415 /// effectively defines the SSA values of `arguments` and assigns their type. 4416 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4417 ArrayRef<Type> argTypes, 4418 bool enableNameShadowing) override { 4419 assert(arguments.size() == argTypes.size() && 4420 "mismatching number of arguments and types"); 4421 4422 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4423 regionArguments; 4424 for (auto pair : llvm::zip(arguments, argTypes)) { 4425 const OperandType &operand = std::get<0>(pair); 4426 Type type = std::get<1>(pair); 4427 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4428 operand.location}; 4429 regionArguments.emplace_back(operandInfo, type); 4430 } 4431 4432 // Try to parse the region. 4433 assert((!enableNameShadowing || 4434 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4435 "name shadowing is only allowed on isolated regions"); 4436 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4437 return failure(); 4438 return success(); 4439 } 4440 4441 /// Parses a region if present. 4442 ParseResult parseOptionalRegion(Region ®ion, 4443 ArrayRef<OperandType> arguments, 4444 ArrayRef<Type> argTypes, 4445 bool enableNameShadowing) override { 4446 if (parser.getToken().isNot(Token::l_brace)) 4447 return success(); 4448 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4449 } 4450 4451 /// Parse a region argument. The type of the argument will be resolved later 4452 /// by a call to `parseRegion`. 4453 ParseResult parseRegionArgument(OperandType &argument) override { 4454 return parseOperand(argument); 4455 } 4456 4457 /// Parse a region argument if present. 4458 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4459 if (parser.getToken().isNot(Token::percent_identifier)) 4460 return success(); 4461 return parseRegionArgument(argument); 4462 } 4463 4464 ParseResult 4465 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4466 int requiredOperandCount = -1, 4467 Delimiter delimiter = Delimiter::None) override { 4468 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4469 requiredOperandCount, delimiter); 4470 } 4471 4472 //===--------------------------------------------------------------------===// 4473 // Successor Parsing 4474 //===--------------------------------------------------------------------===// 4475 4476 /// Parse a single operation successor. 4477 ParseResult parseSuccessor(Block *&dest) override { 4478 return parser.parseSuccessor(dest); 4479 } 4480 4481 /// Parse an optional operation successor and its operand list. 4482 OptionalParseResult parseOptionalSuccessor(Block *&dest) override { 4483 if (parser.getToken().isNot(Token::caret_identifier)) 4484 return llvm::None; 4485 return parseSuccessor(dest); 4486 } 4487 4488 /// Parse a single operation successor and its operand list. 4489 ParseResult 4490 parseSuccessorAndUseList(Block *&dest, 4491 SmallVectorImpl<Value> &operands) override { 4492 if (parseSuccessor(dest)) 4493 return failure(); 4494 4495 // Handle optional arguments. 4496 if (succeeded(parseOptionalLParen()) && 4497 (parser.parseOptionalSSAUseAndTypeList(operands) || parseRParen())) { 4498 return failure(); 4499 } 4500 return success(); 4501 } 4502 4503 //===--------------------------------------------------------------------===// 4504 // Type Parsing 4505 //===--------------------------------------------------------------------===// 4506 4507 /// Parse a type. 4508 ParseResult parseType(Type &result) override { 4509 return failure(!(result = parser.parseType())); 4510 } 4511 4512 /// Parse an arrow followed by a type list. 4513 ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override { 4514 if (parseArrow() || parser.parseFunctionResultTypes(result)) 4515 return failure(); 4516 return success(); 4517 } 4518 4519 /// Parse an optional arrow followed by a type list. 4520 ParseResult 4521 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4522 if (!parser.consumeIf(Token::arrow)) 4523 return success(); 4524 return parser.parseFunctionResultTypes(result); 4525 } 4526 4527 /// Parse a colon followed by a type. 4528 ParseResult parseColonType(Type &result) override { 4529 return failure(parser.parseToken(Token::colon, "expected ':'") || 4530 !(result = parser.parseType())); 4531 } 4532 4533 /// Parse a colon followed by a type list, which must have at least one type. 4534 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4535 if (parser.parseToken(Token::colon, "expected ':'")) 4536 return failure(); 4537 return parser.parseTypeListNoParens(result); 4538 } 4539 4540 /// Parse an optional colon followed by a type list, which if present must 4541 /// have at least one type. 4542 ParseResult 4543 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4544 if (!parser.consumeIf(Token::colon)) 4545 return success(); 4546 return parser.parseTypeListNoParens(result); 4547 } 4548 4549 /// Parse a list of assignments of the form 4550 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). 4551 /// The list must contain at least one entry 4552 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 4553 SmallVectorImpl<OperandType> &rhs) override { 4554 auto parseElt = [&]() -> ParseResult { 4555 OperandType regionArg, operand; 4556 if (parseRegionArgument(regionArg) || parseEqual() || 4557 parseOperand(operand)) 4558 return failure(); 4559 lhs.push_back(regionArg); 4560 rhs.push_back(operand); 4561 return success(); 4562 }; 4563 if (parseLParen()) 4564 return failure(); 4565 return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); 4566 } 4567 4568 private: 4569 /// The source location of the operation name. 4570 SMLoc nameLoc; 4571 4572 /// Information about the result name specifiers. 4573 ArrayRef<OperationParser::ResultRecord> resultIDs; 4574 4575 /// The abstract information of the operation. 4576 const AbstractOperation *opDefinition; 4577 4578 /// The main operation parser. 4579 OperationParser &parser; 4580 4581 /// A flag that indicates if any errors were emitted during parsing. 4582 bool emittedError = false; 4583 }; 4584 } // end anonymous namespace. 4585 4586 Operation * 4587 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) { 4588 auto opLoc = getToken().getLoc(); 4589 auto opName = getTokenSpelling(); 4590 4591 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4592 if (!opDefinition && !opName.contains('.')) { 4593 // If the operation name has no namespace prefix we treat it as a standard 4594 // operation and prefix it with "std". 4595 // TODO: Would it be better to just build a mapping of the registered 4596 // operations in the standard dialect? 4597 opDefinition = 4598 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4599 } 4600 4601 if (!opDefinition) { 4602 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4603 return nullptr; 4604 } 4605 4606 consumeToken(); 4607 4608 // If the custom op parser crashes, produce some indication to help 4609 // debugging. 4610 std::string opNameStr = opName.str(); 4611 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4612 opNameStr.c_str()); 4613 4614 // Get location information for the operation. 4615 auto srcLocation = getEncodedSourceLocation(opLoc); 4616 4617 // Have the op implementation take a crack and parsing this. 4618 OperationState opState(srcLocation, opDefinition->name); 4619 CleanupOpStateRegions guard{opState}; 4620 CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this); 4621 if (opAsmParser.parseOperation(opState)) 4622 return nullptr; 4623 4624 // If it emitted an error, we failed. 4625 if (opAsmParser.didEmitError()) 4626 return nullptr; 4627 4628 // Parse a location if one is present. 4629 if (parseOptionalTrailingLocation(opState.location)) 4630 return nullptr; 4631 4632 // Otherwise, we succeeded. Use the state it parsed as our op information. 4633 return opBuilder.createOperation(opState); 4634 } 4635 4636 //===----------------------------------------------------------------------===// 4637 // Region Parsing 4638 //===----------------------------------------------------------------------===// 4639 4640 /// Region. 4641 /// 4642 /// region ::= '{' region-body 4643 /// 4644 ParseResult OperationParser::parseRegion( 4645 Region ®ion, 4646 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4647 bool isIsolatedNameScope) { 4648 // Parse the '{'. 4649 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4650 return failure(); 4651 4652 // Check for an empty region. 4653 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4654 return success(); 4655 auto currentPt = opBuilder.saveInsertionPoint(); 4656 4657 // Push a new named value scope. 4658 pushSSANameScope(isIsolatedNameScope); 4659 4660 // Parse the first block directly to allow for it to be unnamed. 4661 Block *block = new Block(); 4662 4663 // Add arguments to the entry block. 4664 if (!entryArguments.empty()) { 4665 for (auto &placeholderArgPair : entryArguments) { 4666 auto &argInfo = placeholderArgPair.first; 4667 // Ensure that the argument was not already defined. 4668 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4669 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4670 "' is already in use") 4671 .attachNote(getEncodedSourceLocation(*defLoc)) 4672 << "previously referenced here"; 4673 } 4674 if (addDefinition(placeholderArgPair.first, 4675 block->addArgument(placeholderArgPair.second))) { 4676 delete block; 4677 return failure(); 4678 } 4679 } 4680 4681 // If we had named arguments, then don't allow a block name. 4682 if (getToken().is(Token::caret_identifier)) 4683 return emitError("invalid block name in region with named arguments"); 4684 } 4685 4686 if (parseBlock(block)) { 4687 delete block; 4688 return failure(); 4689 } 4690 4691 // Verify that no other arguments were parsed. 4692 if (!entryArguments.empty() && 4693 block->getNumArguments() > entryArguments.size()) { 4694 delete block; 4695 return emitError("entry block arguments were already defined"); 4696 } 4697 4698 // Parse the rest of the region. 4699 region.push_back(block); 4700 if (parseRegionBody(region)) 4701 return failure(); 4702 4703 // Pop the SSA value scope for this region. 4704 if (popSSANameScope()) 4705 return failure(); 4706 4707 // Reset the original insertion point. 4708 opBuilder.restoreInsertionPoint(currentPt); 4709 return success(); 4710 } 4711 4712 /// Region. 4713 /// 4714 /// region-body ::= block* '}' 4715 /// 4716 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4717 // Parse the list of blocks. 4718 while (!consumeIf(Token::r_brace)) { 4719 Block *newBlock = nullptr; 4720 if (parseBlock(newBlock)) 4721 return failure(); 4722 region.push_back(newBlock); 4723 } 4724 return success(); 4725 } 4726 4727 //===----------------------------------------------------------------------===// 4728 // Block Parsing 4729 //===----------------------------------------------------------------------===// 4730 4731 /// Block declaration. 4732 /// 4733 /// block ::= block-label? operation* 4734 /// block-label ::= block-id block-arg-list? `:` 4735 /// block-id ::= caret-id 4736 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4737 /// 4738 ParseResult OperationParser::parseBlock(Block *&block) { 4739 // The first block of a region may already exist, if it does the caret 4740 // identifier is optional. 4741 if (block && getToken().isNot(Token::caret_identifier)) 4742 return parseBlockBody(block); 4743 4744 SMLoc nameLoc = getToken().getLoc(); 4745 auto name = getTokenSpelling(); 4746 if (parseToken(Token::caret_identifier, "expected block name")) 4747 return failure(); 4748 4749 block = defineBlockNamed(name, nameLoc, block); 4750 4751 // Fail if the block was already defined. 4752 if (!block) 4753 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4754 4755 // If an argument list is present, parse it. 4756 if (consumeIf(Token::l_paren)) { 4757 SmallVector<BlockArgument, 8> bbArgs; 4758 if (parseOptionalBlockArgList(bbArgs, block) || 4759 parseToken(Token::r_paren, "expected ')' to end argument list")) 4760 return failure(); 4761 } 4762 4763 if (parseToken(Token::colon, "expected ':' after block name")) 4764 return failure(); 4765 4766 return parseBlockBody(block); 4767 } 4768 4769 ParseResult OperationParser::parseBlockBody(Block *block) { 4770 // Set the insertion point to the end of the block to parse. 4771 opBuilder.setInsertionPointToEnd(block); 4772 4773 // Parse the list of operations that make up the body of the block. 4774 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4775 if (parseOperation()) 4776 return failure(); 4777 4778 return success(); 4779 } 4780 4781 /// Get the block with the specified name, creating it if it doesn't already 4782 /// exist. The location specified is the point of use, which allows 4783 /// us to diagnose references to blocks that are not defined precisely. 4784 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4785 auto &blockAndLoc = getBlockInfoByName(name); 4786 if (!blockAndLoc.first) { 4787 blockAndLoc = {new Block(), loc}; 4788 insertForwardRef(blockAndLoc.first, loc); 4789 } 4790 4791 return blockAndLoc.first; 4792 } 4793 4794 /// Define the block with the specified name. Returns the Block* or nullptr in 4795 /// the case of redefinition. 4796 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4797 Block *existing) { 4798 auto &blockAndLoc = getBlockInfoByName(name); 4799 if (!blockAndLoc.first) { 4800 // If the caller provided a block, use it. Otherwise create a new one. 4801 if (!existing) 4802 existing = new Block(); 4803 blockAndLoc.first = existing; 4804 blockAndLoc.second = loc; 4805 return blockAndLoc.first; 4806 } 4807 4808 // Forward declarations are removed once defined, so if we are defining a 4809 // existing block and it is not a forward declaration, then it is a 4810 // redeclaration. 4811 if (!eraseForwardRef(blockAndLoc.first)) 4812 return nullptr; 4813 return blockAndLoc.first; 4814 } 4815 4816 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4817 /// 4818 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4819 /// 4820 ParseResult OperationParser::parseOptionalBlockArgList( 4821 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4822 if (getToken().is(Token::r_brace)) 4823 return success(); 4824 4825 // If the block already has arguments, then we're handling the entry block. 4826 // Parse and register the names for the arguments, but do not add them. 4827 bool definingExistingArgs = owner->getNumArguments() != 0; 4828 unsigned nextArgument = 0; 4829 4830 return parseCommaSeparatedList([&]() -> ParseResult { 4831 return parseSSADefOrUseAndType( 4832 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4833 // If this block did not have existing arguments, define a new one. 4834 if (!definingExistingArgs) 4835 return addDefinition(useInfo, owner->addArgument(type)); 4836 4837 // Otherwise, ensure that this argument has already been created. 4838 if (nextArgument >= owner->getNumArguments()) 4839 return emitError("too many arguments specified in argument list"); 4840 4841 // Finally, make sure the existing argument has the correct type. 4842 auto arg = owner->getArgument(nextArgument++); 4843 if (arg.getType() != type) 4844 return emitError("argument and block argument type mismatch"); 4845 return addDefinition(useInfo, arg); 4846 }); 4847 }); 4848 } 4849 4850 //===----------------------------------------------------------------------===// 4851 // Top-level entity parsing. 4852 //===----------------------------------------------------------------------===// 4853 4854 namespace { 4855 /// This parser handles entities that are only valid at the top level of the 4856 /// file. 4857 class ModuleParser : public Parser { 4858 public: 4859 explicit ModuleParser(ParserState &state) : Parser(state) {} 4860 4861 ParseResult parseModule(ModuleOp module); 4862 4863 private: 4864 /// Parse an attribute alias declaration. 4865 ParseResult parseAttributeAliasDef(); 4866 4867 /// Parse an attribute alias declaration. 4868 ParseResult parseTypeAliasDef(); 4869 }; 4870 } // end anonymous namespace 4871 4872 /// Parses an attribute alias declaration. 4873 /// 4874 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4875 /// 4876 ParseResult ModuleParser::parseAttributeAliasDef() { 4877 assert(getToken().is(Token::hash_identifier)); 4878 StringRef aliasName = getTokenSpelling().drop_front(); 4879 4880 // Check for redefinitions. 4881 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4882 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4883 4884 // Make sure this isn't invading the dialect attribute namespace. 4885 if (aliasName.contains('.')) 4886 return emitError("attribute names with a '.' are reserved for " 4887 "dialect-defined names"); 4888 4889 consumeToken(Token::hash_identifier); 4890 4891 // Parse the '='. 4892 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4893 return failure(); 4894 4895 // Parse the attribute value. 4896 Attribute attr = parseAttribute(); 4897 if (!attr) 4898 return failure(); 4899 4900 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4901 return success(); 4902 } 4903 4904 /// Parse a type alias declaration. 4905 /// 4906 /// type-alias-def ::= '!' alias-name `=` 'type' type 4907 /// 4908 ParseResult ModuleParser::parseTypeAliasDef() { 4909 assert(getToken().is(Token::exclamation_identifier)); 4910 StringRef aliasName = getTokenSpelling().drop_front(); 4911 4912 // Check for redefinitions. 4913 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4914 return emitError("redefinition of type alias id '" + aliasName + "'"); 4915 4916 // Make sure this isn't invading the dialect type namespace. 4917 if (aliasName.contains('.')) 4918 return emitError("type names with a '.' are reserved for " 4919 "dialect-defined names"); 4920 4921 consumeToken(Token::exclamation_identifier); 4922 4923 // Parse the '=' and 'type'. 4924 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4925 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4926 return failure(); 4927 4928 // Parse the type. 4929 Type aliasedType = parseType(); 4930 if (!aliasedType) 4931 return failure(); 4932 4933 // Register this alias with the parser state. 4934 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4935 return success(); 4936 } 4937 4938 /// This is the top-level module parser. 4939 ParseResult ModuleParser::parseModule(ModuleOp module) { 4940 OperationParser opParser(getState(), module); 4941 4942 // Module itself is a name scope. 4943 opParser.pushSSANameScope(/*isIsolated=*/true); 4944 4945 while (true) { 4946 switch (getToken().getKind()) { 4947 default: 4948 // Parse a top-level operation. 4949 if (opParser.parseOperation()) 4950 return failure(); 4951 break; 4952 4953 // If we got to the end of the file, then we're done. 4954 case Token::eof: { 4955 if (opParser.finalize()) 4956 return failure(); 4957 4958 // Handle the case where the top level module was explicitly defined. 4959 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4960 auto &operations = bodyBlocks.front().getOperations(); 4961 assert(!operations.empty() && "expected a valid module terminator"); 4962 4963 // Check that the first operation is a module, and it is the only 4964 // non-terminator operation. 4965 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4966 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4967 // Merge the data of the nested module operation into 'module'. 4968 module.setLoc(nested.getLoc()); 4969 module.setAttrs(nested.getOperation()->getAttrList()); 4970 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4971 4972 // Erase the original module body. 4973 bodyBlocks.pop_front(); 4974 } 4975 4976 return opParser.popSSANameScope(); 4977 } 4978 4979 // If we got an error token, then the lexer already emitted an error, just 4980 // stop. Someday we could introduce error recovery if there was demand 4981 // for it. 4982 case Token::error: 4983 return failure(); 4984 4985 // Parse an attribute alias. 4986 case Token::hash_identifier: 4987 if (parseAttributeAliasDef()) 4988 return failure(); 4989 break; 4990 4991 // Parse a type alias. 4992 case Token::exclamation_identifier: 4993 if (parseTypeAliasDef()) 4994 return failure(); 4995 break; 4996 } 4997 } 4998 } 4999 5000 //===----------------------------------------------------------------------===// 5001 5002 /// This parses the file specified by the indicated SourceMgr and returns an 5003 /// MLIR module if it was valid. If not, it emits diagnostics and returns 5004 /// null. 5005 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 5006 MLIRContext *context) { 5007 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 5008 5009 // This is the result module we are parsing into. 5010 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 5011 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 5012 5013 SymbolState aliasState; 5014 ParserState state(sourceMgr, context, aliasState); 5015 if (ModuleParser(state).parseModule(*module)) 5016 return nullptr; 5017 5018 // Make sure the parse module has no other structural problems detected by 5019 // the verifier. 5020 if (failed(verify(*module))) 5021 return nullptr; 5022 5023 return module; 5024 } 5025 5026 /// This parses the file specified by the indicated filename and returns an 5027 /// MLIR module if it was valid. If not, the error message is emitted through 5028 /// the error handler registered in the context, and a null pointer is returned. 5029 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5030 MLIRContext *context) { 5031 llvm::SourceMgr sourceMgr; 5032 return parseSourceFile(filename, sourceMgr, context); 5033 } 5034 5035 /// This parses the file specified by the indicated filename using the provided 5036 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 5037 /// message is emitted through the error handler registered in the context, and 5038 /// a null pointer is returned. 5039 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5040 llvm::SourceMgr &sourceMgr, 5041 MLIRContext *context) { 5042 if (sourceMgr.getNumBuffers() != 0) { 5043 // TODO(b/136086478): Extend to support multiple buffers. 5044 emitError(mlir::UnknownLoc::get(context), 5045 "only main buffer parsed at the moment"); 5046 return nullptr; 5047 } 5048 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 5049 if (std::error_code error = file_or_err.getError()) { 5050 emitError(mlir::UnknownLoc::get(context), 5051 "could not open input file " + filename); 5052 return nullptr; 5053 } 5054 5055 // Load the MLIR module. 5056 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 5057 return parseSourceFile(sourceMgr, context); 5058 } 5059 5060 /// This parses the program string to a MLIR module if it was valid. If not, 5061 /// it emits diagnostics and returns null. 5062 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 5063 MLIRContext *context) { 5064 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 5065 if (!memBuffer) 5066 return nullptr; 5067 5068 SourceMgr sourceMgr; 5069 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 5070 return parseSourceFile(sourceMgr, context); 5071 } 5072 5073 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 5074 /// parsing failed, nullptr is returned. The number of bytes read from the input 5075 /// string is returned in 'numRead'. 5076 template <typename T, typename ParserFn> 5077 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 5078 ParserFn &&parserFn) { 5079 SymbolState aliasState; 5080 return parseSymbol<T>( 5081 inputStr, context, aliasState, 5082 [&](Parser &parser) { 5083 SourceMgrDiagnosticHandler handler( 5084 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 5085 parser.getContext()); 5086 return parserFn(parser); 5087 }, 5088 &numRead); 5089 } 5090 5091 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 5092 size_t numRead = 0; 5093 return parseAttribute(attrStr, context, numRead); 5094 } 5095 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 5096 size_t numRead = 0; 5097 return parseAttribute(attrStr, type, numRead); 5098 } 5099 5100 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 5101 size_t &numRead) { 5102 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 5103 return parser.parseAttribute(); 5104 }); 5105 } 5106 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 5107 return parseSymbol<Attribute>( 5108 attrStr, type.getContext(), numRead, 5109 [type](Parser &parser) { return parser.parseAttribute(type); }); 5110 } 5111 5112 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 5113 size_t numRead = 0; 5114 return parseType(typeStr, context, numRead); 5115 } 5116 5117 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 5118 return parseSymbol<Type>(typeStr, context, numRead, 5119 [](Parser &parser) { return parser.parseType(); }); 5120 } 5121