//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR textual form.
//
//===----------------------------------------------------------------------===//

#include "mlir/Parser.h"
#include "Lexer.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include <algorithm>
using namespace mlir;
using llvm::MemoryBuffer;
using llvm::SMLoc;
using llvm::SourceMgr;

namespace {
class Parser;

//===----------------------------------------------------------------------===//
// SymbolState
//===----------------------------------------------------------------------===//

/// This class contains a record of any parsed top-level symbols. One instance
/// is shared by reference (through ParserState::symbols) between a top-level
/// parser and any nested parsers created on its behalf.
struct SymbolState {
  // A map from attribute alias identifier to Attribute.
  llvm::StringMap<Attribute> attributeAliasDefinitions;

  // A map from type alias identifier to Type.
  llvm::StringMap<Type> typeAliasDefinitions;

  /// A set of locations into the main parser memory buffer for each of the
  /// active nested parsers. Given that some nested parsers, e.g. custom dialect
  /// parsers, operate on a temporary memory buffer, this provides an anchor
  /// point for emitting diagnostics.
  SmallVector<llvm::SMLoc, 1> nestedParserLocs;

  /// The top-level lexer that contains the original memory buffer provided by
  /// the user. This is used by nested parsers to get a properly encoded source
  /// location.
  Lexer *topLevelLexer = nullptr;
};

//===----------------------------------------------------------------------===//
// ParserState
//===----------------------------------------------------------------------===//

/// This class refers to all of the state maintained globally by the parser,
/// such as the current lexer position etc.
struct ParserState {
  /// Create the state for parsing `sourceMgr` into `ctx`. The first token is
  /// lexed immediately, and the nesting depth of this parser is the number of
  /// nested parser locations already recorded in `symbols`.
  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
              SymbolState &symbols)
      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
    // Set the top level lexer for the symbol state if one doesn't exist.
    if (!symbols.topLevelLexer)
      symbols.topLevelLexer = &lex;
  }
  ~ParserState() {
    // Reset the top level lexer if it refers to the lexer in our state.
    if (symbols.topLevelLexer == &lex)
      symbols.topLevelLexer = nullptr;
  }
  // Non-copyable: other state (e.g. SymbolState::topLevelLexer) may hold a
  // pointer to the embedded lexer.
  ParserState(const ParserState &) = delete;
  void operator=(const ParserState &) = delete;

  /// The context we're parsing into.
  MLIRContext *const context;

  /// The lexer for the source file we're parsing.
  Lexer lex;

  /// This is the next token that hasn't been consumed yet.
  Token curToken;

  /// The current state for symbol parsing.
  SymbolState &symbols;

  /// The depth of this parser in the nested parsing stack. Zero means this is
  /// the top-level parser.
  size_t parserDepth;
};

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// This class implements support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
/// specialized subparsers that include state, e.g. a local symbol table.
class Parser {
public:
  Builder builder;

  Parser(ParserState &state) : builder(state.context), state(state) {}

  // Helper methods to get stuff from the parser-global state.
  ParserState &getState() const { return state; }
  MLIRContext *getContext() const { return state.context; }
  const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }

  /// Parse a comma-separated list of elements up until the specified end token.
  ParseResult
  parseCommaSeparatedListUntil(Token::Kind rightToken,
                               const std::function<ParseResult()> &parseElement,
                               bool allowEmptyList = true);

  /// Parse a comma separated list of elements that must have at least one entry
  /// in it.
  ParseResult
  parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);

  /// Scan the balanced `<...>` body of a pretty dialect symbol starting at the
  /// current '<' token and extend `prettyName` to cover it.
  ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);

  // We have two forms of parsing methods - those that return a non-null
  // pointer on success, and those that return a ParseResult to indicate whether
  // they returned a failure. The second class fills in by-reference arguments
  // as the results of their action.

  //===--------------------------------------------------------------------===//
  // Error Handling
  //===--------------------------------------------------------------------===//

  /// Emit an error at the location of the current token and return failure.
  InFlightDiagnostic emitError(const Twine &message = {}) {
    return emitError(state.curToken.getLoc(), message);
  }
  /// Emit an error at `loc`. The location is re-encoded against the top-level
  /// buffer when this parser is nested.
  InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});

  /// Encode the specified source location information into an attribute for
  /// attachment to the IR.
  Location getEncodedSourceLocation(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can get the encoded source
    // location directly.
    if (state.parserDepth == 0)
      return state.lex.getEncodedSourceLocation(loc);
    // Otherwise, we need to re-encode it to point to the top level buffer.
    return state.symbols.topLevelLexer->getEncodedSourceLocation(
        remapLocationToTopLevelBuffer(loc));
  }

  /// Remaps the given SMLoc to the top level lexer of the parser. This is used
  /// to adjust locations of potentially nested parsers to ensure that they can
  /// be emitted properly as diagnostics.
  llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can return location directly.
    SymbolState &symbols = state.symbols;
    if (state.parserDepth == 0)
      return loc;
    assert(symbols.topLevelLexer && "expected valid top-level lexer");

    // Otherwise, we need to remap the location to the main parser. This is
    // simply offsetting the location onto the location of the last nested
    // parser.
    size_t offset = loc.getPointer() - state.lex.getBufferBegin();
    auto *rawLoc =
        symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
    return llvm::SMLoc::getFromPointer(rawLoc);
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//

  /// Return the current token the parser is inspecting.
  const Token &getToken() const { return state.curToken; }
  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (state.curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(state.curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    state.curToken = state.lex.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is. This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(state.curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Consume the specified token if present and return success. On failure,
  /// output a diagnostic and return failure.
  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);

  /// Parse an arbitrary type.
  Type parseType();

  /// Parse a complex type.
  Type parseComplexType();

  /// Parse an extended type.
  Type parseExtendedType();

  /// Parse a function type.
  Type parseFunctionType();

  /// Parse a memref type.
  Type parseMemRefType();

  /// Parse a non function type.
  Type parseNonFunctionType();

  /// Parse a tensor type.
  Type parseTensorType();

  /// Parse a tuple type.
  Type parseTupleType();

  /// Parse a vector type.
  VectorType parseVectorType();
  ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true);
  ParseResult parseXInDimensionList();

  /// Parse strided layout specification.
  ParseResult parseStridedLayout(int64_t &offset,
                                 SmallVectorImpl<int64_t> &strides);

  // Parse a brace-delimited list of comma-separated integers with `?` as an
  // unknown marker.
  ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute with an optional type.
  Attribute parseAttribute(Type type = {});

  /// Parse an attribute dictionary.
  ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);

  /// Parse an extended attribute.
  Attribute parseExtendedAttr(Type type);

  /// Parse a float attribute.
  Attribute parseFloatAttr(Type type, bool isNegative);

  /// Parse a decimal or a hexadecimal literal, which can be either an integer
  /// or a float attribute.
  Attribute parseDecOrHexAttr(Type type, bool isNegative);

  /// Parse an opaque elements attribute.
  Attribute parseOpaqueElementsAttr(Type attrType);

  /// Parse a dense elements attribute.
  Attribute parseDenseElementsAttr(Type attrType);
  ShapedType parseElementsLiteralType(Type type);

  /// Parse a sparse elements attribute.
  Attribute parseSparseElementsAttr(Type attrType);

  //===--------------------------------------------------------------------===//
  // Location Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an inline location.
  ParseResult parseLocation(LocationAttr &loc);

  /// Parse a raw location instance.
  ParseResult parseLocationInstance(LocationAttr &loc);

  /// Parse a callsite location instance.
  ParseResult parseCallSiteLocation(LocationAttr &loc);

  /// Parse a fused location instance.
  ParseResult parseFusedLocation(LocationAttr &loc);

  /// Parse a name or FileLineCol location instance.
  ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);

  /// Parse an optional trailing location. `loc` is only overwritten when a
  /// trailing location is present and parses successfully.
  ///
  ///   trailing-location     ::= (`loc` `(` location `)`)?
  ///
  ParseResult parseOptionalTrailingLocation(Location &loc) {
    // If there is a 'loc' we parse a trailing location.
    if (!getToken().is(Token::kw_loc))
      return success();

    // Parse the location.
    LocationAttr directLoc;
    if (parseLocation(directLoc))
      return failure();
    loc = directLoc;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Affine Parsing
  //===--------------------------------------------------------------------===//

  /// Parse a reference to either an affine map, or an integer set.
  ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
                                                  IntegerSet &set);
  ParseResult parseAffineMapReference(AffineMap &map);
  ParseResult parseIntegerSetReference(IntegerSet &set);

  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
  ParseResult
  parseAffineMapOfSSAIds(AffineMap &map,
                         function_ref<ParseResult(bool)> parseElement,
                         OpAsmParser::Delimiter delimiter);

private:
  /// The Parser is subclassed and reinstantiated. Do not add additional
  /// non-trivial state here, add it to the ParserState class.
  ParserState &state;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Parse a comma separated list of elements that must have at least one entry
/// in it. Stops at the first element whose callback reports failure.
ParseResult Parser::parseCommaSeparatedList(
    const std::function<ParseResult()> &parseElement) {
  // Non-empty case starts with an element.
  if (parseElement())
    return failure();

  // Otherwise we have a list of comma separated elements.
  while (consumeIf(Token::comma)) {
    if (parseElement())
      return failure();
  }
  return success();
}

/// Parse a comma-separated list of elements, terminated with an arbitrary
/// token.  This allows empty lists if allowEmptyList is true.
///
///   abstract-list ::= rightToken                  // if allowEmptyList == true
///   abstract-list ::= element (',' element)* rightToken
///
ParseResult Parser::parseCommaSeparatedListUntil(
    Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
    bool allowEmptyList) {
  // Handle the empty case.
  if (getToken().is(rightToken)) {
    if (!allowEmptyList)
      return emitError("expected list element");
    consumeToken(rightToken);
    return success();
  }

  if (parseCommaSeparatedList(parseElement) ||
      parseToken(rightToken, "expected ',' or '" +
                                 Token::getTokenSpelling(rightToken) + "'"))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// DialectAsmParser
//===----------------------------------------------------------------------===//

namespace {
/// This class provides the main implementation of the DialectAsmParser that
/// allows for dialects to parse attributes and types. This allows for dialect
/// hooking into the main MLIR parsing logic.
class CustomDialectAsmParser : public DialectAsmParser {
public:
  /// `fullSpec` is the complete textual symbol body being parsed; the name
  /// location is captured from the parser's current token at construction.
  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
        parser(parser) {}
  ~CustomDialectAsmParser() override {}

  /// Emit a diagnostic at the specified location and return failure.
  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
    return parser.emitError(loc, message);
  }

  /// Return a builder which provides useful access to MLIRContext, global
  /// objects like types and attributes.
  Builder &getBuilder() const override { return parser.builder; }

  /// Get the location of the next token and store it into the argument. This
  /// always succeeds.
  llvm::SMLoc getCurrentLocation() override {
    return parser.getToken().getLoc();
  }

  /// Return the location of the original name token.
  llvm::SMLoc getNameLoc() const override { return nameLoc; }

  /// Re-encode the given source location as an MLIR location and return it.
  Location getEncodedSourceLoc(llvm::SMLoc loc) override {
    return parser.getEncodedSourceLocation(loc);
  }

  /// Returns the full specification of the symbol being parsed. This allows
  /// for using a separate parser if necessary.
  StringRef getFullSymbolSpec() const override { return fullSpec; }

  /// Parse a floating point value from the stream. An optional leading '-' is
  /// consumed before the literal is checked, so it stays consumed even when
  /// the following token is not a float literal.
  ParseResult parseFloat(double &result) override {
    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();

    // Check for a floating point value.
    if (curTok.is(Token::floatliteral)) {
      auto val = curTok.getFloatingPointValue();
      if (!val.hasValue())
        return emitError(curTok.getLoc(), "floating point value too large");
      parser.consumeToken(Token::floatliteral);
      result = negative ? -*val : *val;
      return success();
    }

    // TODO(riverriddle) support hex floating point values.
    return emitError(getCurrentLocation(), "expected floating point literal");
  }

  /// Parse an optional integer value from the stream. Returns None (consuming
  /// nothing) unless the current token is an integer or '-'.
  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
    Token curToken = parser.getToken();
    if (curToken.isNot(Token::integer, Token::minus))
      return llvm::None;

    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();
    if (parser.parseToken(Token::integer, "expected integer value"))
      return failure();

    auto val = curTok.getUInt64IntegerValue();
    if (!val)
      return emitError(curTok.getLoc(), "integer value too large");
    // The negation is performed on the unsigned 64-bit value, so negative
    // inputs are stored in their two's-complement (wrapped) form.
    result = negative ? -*val : *val;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//

  /// Parse a `->` token.
  ParseResult parseArrow() override {
    return parser.parseToken(Token::arrow, "expected '->'");
  }

  /// Parses a `->` if present.
  ParseResult parseOptionalArrow() override {
    return success(parser.consumeIf(Token::arrow));
  }

  /// Parse a '{' token.
  ParseResult parseLBrace() override {
    return parser.parseToken(Token::l_brace, "expected '{'");
  }

  /// Parse a '{' token if present.
  ParseResult parseOptionalLBrace() override {
    return success(parser.consumeIf(Token::l_brace));
  }

  /// Parse a `}` token.
  ParseResult parseRBrace() override {
    return parser.parseToken(Token::r_brace, "expected '}'");
  }

  /// Parse a `}` token if present.
  ParseResult parseOptionalRBrace() override {
    return success(parser.consumeIf(Token::r_brace));
  }

  /// Parse a `:` token.
  ParseResult parseColon() override {
    return parser.parseToken(Token::colon, "expected ':'");
  }

  /// Parse a `:` token if present.
  ParseResult parseOptionalColon() override {
    return success(parser.consumeIf(Token::colon));
  }

  /// Parse a `,` token.
  ParseResult parseComma() override {
    return parser.parseToken(Token::comma, "expected ','");
  }

  /// Parse a `,` token if present.
  ParseResult parseOptionalComma() override {
    return success(parser.consumeIf(Token::comma));
  }

  /// Parses a `...` if present.
  ParseResult parseOptionalEllipsis() override {
    return success(parser.consumeIf(Token::ellipsis));
  }

  /// Parse a `=` token.
  ParseResult parseEqual() override {
    return parser.parseToken(Token::equal, "expected '='");
  }

  /// Parse a '<' token.
  ParseResult parseLess() override {
    return parser.parseToken(Token::less, "expected '<'");
  }

  /// Parse a `<` token if present.
  ParseResult parseOptionalLess() override {
    return success(parser.consumeIf(Token::less));
  }

  /// Parse a '>' token.
  ParseResult parseGreater() override {
    return parser.parseToken(Token::greater, "expected '>'");
  }

  /// Parse a `>` token if present.
  ParseResult parseOptionalGreater() override {
    return success(parser.consumeIf(Token::greater));
  }

  /// Parse a `(` token.
  ParseResult parseLParen() override {
    return parser.parseToken(Token::l_paren, "expected '('");
  }

  /// Parses a '(' if present.
  ParseResult parseOptionalLParen() override {
    return success(parser.consumeIf(Token::l_paren));
  }

  /// Parse a `)` token.
  ParseResult parseRParen() override {
    return parser.parseToken(Token::r_paren, "expected ')'");
  }

  /// Parses a ')' if present.
  ParseResult parseOptionalRParen() override {
    return success(parser.consumeIf(Token::r_paren));
  }

  /// Parse a `[` token.
  ParseResult parseLSquare() override {
    return parser.parseToken(Token::l_square, "expected '['");
  }

  /// Parses a '[' if present.
  ParseResult parseOptionalLSquare() override {
    return success(parser.consumeIf(Token::l_square));
  }

  /// Parse a `]` token.
  ParseResult parseRSquare() override {
    return parser.parseToken(Token::r_square, "expected ']'");
  }

  /// Parses a ']' if present.
  ParseResult parseOptionalRSquare() override {
    return success(parser.consumeIf(Token::r_square));
  }

  /// Parses a '?' if present.
  ParseResult parseOptionalQuestion() override {
    return success(parser.consumeIf(Token::question));
  }

  /// Parses a '*' if present.
  ParseResult parseOptionalStar() override {
    return success(parser.consumeIf(Token::star));
  }

  /// Returns if the current token corresponds to a keyword: either a bare
  /// identifier or one of the lexer's reserved keywords.
  bool isCurrentTokenAKeyword() const {
    return parser.getToken().is(Token::bare_identifier) ||
           parser.getToken().isKeyword();
  }

  /// Parse the given keyword if present.
  ParseResult parseOptionalKeyword(StringRef keyword) override {
    // Check that the current token has the same spelling.
    if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
      return failure();
    parser.consumeToken();
    return success();
  }

  /// Parse a keyword, if present, into 'keyword'.
  ParseResult parseOptionalKeyword(StringRef *keyword) override {
    // Check that the current token is a keyword.
    if (!isCurrentTokenAKeyword())
      return failure();

    *keyword = parser.getTokenSpelling();
    parser.consumeToken();
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute and return it in result.
  ParseResult parseAttribute(Attribute &result, Type type) override {
    result = parser.parseAttribute(type);
    return success(static_cast<bool>(result));
  }

  /// Parse an affine map instance into 'map'.
  ParseResult parseAffineMap(AffineMap &map) override {
    return parser.parseAffineMapReference(map);
  }

  /// Parse an integer set instance into 'set'.
  /// NOTE(review): despite the `print` prefix this *parses*; the name must
  /// match the overridden DialectAsmParser virtual, so renaming it requires a
  /// coordinated change to the base class.
  ParseResult printIntegerSet(IntegerSet &set) override {
    return parser.parseIntegerSetReference(set);
  }

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary type into `result`; fails if the type is null.
  ParseResult parseType(Type &result) override {
    result = parser.parseType();
    return success(static_cast<bool>(result));
  }

  /// Parse a ranked dimension list, optionally allowing `?` dimensions.
  ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic) override {
    return parser.parseDimensionListRanked(dimensions, allowDynamic);
  }

private:
  /// The full symbol specification.
  StringRef fullSpec;

  /// The source location of the dialect symbol.
  SMLoc nameLoc;

  /// The main parser.
  Parser &parser;
};
} // namespace

/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
/// and may be recursive.  Return with the 'prettyName' StringRef encompassing
/// the entire pretty name.
///
///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
///                                  | '(' pretty-dialect-sym-contents+ ')'
///                                  | '[' pretty-dialect-sym-contents+ ']'
///                                  | '{' pretty-dialect-sym-contents+ '}'
///                                  | '[^[<({>\])}\0]+'
///
ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
  // Pretty symbol names are a relatively unstructured format that contains a
  // series of properly nested punctuation, with anything else in the middle.
  // Scan ahead to find it and consume it if successful, otherwise emit an
  // error.
  auto *curPtr = getTokenSpelling().data();

  // Stack of currently open bracket characters.
  SmallVector<char, 8> nestedPunctuation;

  // Scan over the nested punctuation, bailing out on error and consuming until
  // we find the end. We know that we're currently looking at the '<', so we
  // can go until we find the matching '>' character.
  assert(*curPtr == '<');
  do {
    char c = *curPtr++;
    switch (c) {
    case '\0':
      // This also handles the EOF case.
      return emitError("unexpected nul or EOF in pretty dialect name");
    case '<':
    case '[':
    case '(':
    case '{':
      nestedPunctuation.push_back(c);
      continue;

    case '-':
      // The sequence `->` is treated as special token.
      if (*curPtr == '>')
        ++curPtr;
      continue;

    case '>':
      if (nestedPunctuation.pop_back_val() != '<')
        return emitError("unbalanced '>' character in pretty dialect name");
      break;
    case ']':
      if (nestedPunctuation.pop_back_val() != '[')
        return emitError("unbalanced ']' character in pretty dialect name");
      break;
    case ')':
      if (nestedPunctuation.pop_back_val() != '(')
        return emitError("unbalanced ')' character in pretty dialect name");
      break;
    case '}':
      if (nestedPunctuation.pop_back_val() != '{')
        return emitError("unbalanced '}' character in pretty dialect name");
      break;

    default:
      continue;
    }
  } while (!nestedPunctuation.empty());

  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
  // consuming all this stuff, and return.
  state.lex.resetPointer(curPtr);

  // Extend prettyName from its original start up to the end of the scanned
  // body.
  unsigned length = curPtr - prettyName.begin();
  prettyName = StringRef(prettyName.begin(), length);
  // Lex the token that follows the body, starting from the reset pointer.
  consumeToken();
  return success();
}

/// Parse an extended dialect symbol (either an alias reference, the verbose
/// `dialect<"data">` form, or the pretty `dialect.name<...>` form).
///
/// `identifierTok` is the token kind of the leading identifier; `aliases`
/// maps alias names to already-parsed symbols; `createSymbol` is invoked as
/// `createSymbol(dialectName, symbolData, loc)` for non-alias forms. Returns a
/// null symbol on error.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
                                  SymbolAliasMap &aliases,
                                  CreateFn &&createSymbol) {
  // Parse the dialect namespace.
  StringRef identifier = p.getTokenSpelling().drop_front();
  auto loc = p.getToken().getLoc();
  p.consumeToken(identifierTok);

  // If there is no '<' token following this, and if the typename contains no
  // dot, then we are parsing a symbol alias.
  if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
    // Check for an alias for this type.
    auto aliasIt = aliases.find(identifier);
    if (aliasIt == aliases.end())
      return (p.emitError("undefined symbol alias id '" + identifier + "'"),
              nullptr);
    return aliasIt->second;
  }

  // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
  // a dot, then this is the "pretty" form.  If not, it is the verbose form that
  // looks like <"...">.
  std::string symbolData;
  auto dialectName = identifier;

  // Handle the verbose form, where "identifier" is a simple dialect name.
  if (!identifier.contains('.')) {
    // Consume the '<'.
    if (p.parseToken(Token::less, "expected '<' in dialect type"))
      return nullptr;

    // Parse the symbol specific data.
    if (p.getToken().isNot(Token::string))
      return (p.emitError("expected string literal data in dialect symbol"),
              nullptr);
    symbolData = p.getToken().getStringValue();
    // Point just past the opening quote so nested diagnostics line up with
    // the string contents.
    loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
    p.consumeToken(Token::string);

    // Consume the '>'.
    if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
      return nullptr;
  } else {
    // Ok, the dialect name is the part of the identifier before the dot, the
    // part after the dot is the dialect's symbol, or the start thereof.
    auto dotHalves = identifier.split('.');
    dialectName = dotHalves.first;
    auto prettyName = dotHalves.second;
    loc = llvm::SMLoc::getFromPointer(prettyName.data());

    // If the dialect's symbol is followed immediately by a <, then lex the body
    // of it into prettyName.
    if (p.getToken().is(Token::less) &&
        prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
      if (p.parsePrettyDialectSymbolName(prettyName))
        return nullptr;
    }

    symbolData = prettyName.str();
  }

  // Record the name location of the type remapped to the top level buffer.
  llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
  p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);

  // Call into the provided symbol construction function.
  Symbol sym = createSymbol(dialectName, symbolData, loc);

  // Pop the last parser location.
  p.getState().symbols.nestedParserLocs.pop_back();
  return sym;
}

/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
/// parsing failed, nullptr is returned. The number of bytes read from the input
/// string is returned in 'numRead'.
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
                     SymbolState &symbolState, ParserFn &&parserFn,
                     size_t *numRead = nullptr) {
  // Wrap `inputStr` (without copying) in its own buffer and parser, sharing
  // `symbolState` with the caller.
  SourceMgr sourceMgr;
  auto memBuffer = MemoryBuffer::getMemBuffer(
      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
      /*RequiresNullTerminator=*/false);
  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  ParserState state(sourceMgr, context, symbolState);
  Parser parser(state);

  Token startTok = parser.getToken();
  T symbol = parserFn(parser);
  if (!symbol)
    return T();

  // If 'numRead' is valid, then provide the number of bytes that were read.
  Token endTok = parser.getToken();
  if (numRead) {
    *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
                                   startTok.getLoc().getPointer());

    // Otherwise, ensure that all of the tokens were parsed. A parse that did
    // not advance the token stream at all is accepted as is (presumably the
    // callback used the raw symbol text instead of the token stream -- TODO
    // confirm).
  } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
    parser.emitError(endTok.getLoc(), "encountered unexpected token");
    return T();
  }
  return symbol;
}

//===----------------------------------------------------------------------===//
// Error Handling
//===----------------------------------------------------------------------===//

InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
  auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);

  // If we hit a parse error in response to a lexer error, then the lexer
  // already reported the error.
  if (getToken().is(Token::error))
    diag.abandon();
  return diag;
}

//===----------------------------------------------------------------------===//
// Token Parsing
//===----------------------------------------------------------------------===//

/// Consume the specified token if present and return success.  On failure,
/// output a diagnostic and return failure.
ParseResult Parser::parseToken(Token::Kind expectedToken,
                               const Twine &message) {
  if (consumeIf(expectedToken))
    return success();
  return emitError(message);
}

//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//

/// Parse an arbitrary type.
///
///   type ::= function-type
///          | non-function-type
///
Type Parser::parseType() {
  if (getToken().is(Token::l_paren))
    return parseFunctionType();
  return parseNonFunctionType();
}

/// Parse a function result type.
913 /// 914 /// function-result-type ::= type-list-parens 915 /// | non-function-type 916 /// 917 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 918 if (getToken().is(Token::l_paren)) 919 return parseTypeListParens(elements); 920 921 Type t = parseNonFunctionType(); 922 if (!t) 923 return failure(); 924 elements.push_back(t); 925 return success(); 926 } 927 928 /// Parse a list of types without an enclosing parenthesis. The list must have 929 /// at least one member. 930 /// 931 /// type-list-no-parens ::= type (`,` type)* 932 /// 933 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 934 auto parseElt = [&]() -> ParseResult { 935 auto elt = parseType(); 936 elements.push_back(elt); 937 return elt ? success() : failure(); 938 }; 939 940 return parseCommaSeparatedList(parseElt); 941 } 942 943 /// Parse a parenthesized list of types. 944 /// 945 /// type-list-parens ::= `(` `)` 946 /// | `(` type-list-no-parens `)` 947 /// 948 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 949 if (parseToken(Token::l_paren, "expected '('")) 950 return failure(); 951 952 // Handle empty lists. 953 if (getToken().is(Token::r_paren)) 954 return consumeToken(), success(); 955 956 if (parseTypeListNoParens(elements) || 957 parseToken(Token::r_paren, "expected ')'")) 958 return failure(); 959 return success(); 960 } 961 962 /// Parse a complex type. 963 /// 964 /// complex-type ::= `complex` `<` type `>` 965 /// 966 Type Parser::parseComplexType() { 967 consumeToken(Token::kw_complex); 968 969 // Parse the '<'. 
970 if (parseToken(Token::less, "expected '<' in complex type")) 971 return nullptr; 972 973 llvm::SMLoc elementTypeLoc = getToken().getLoc(); 974 auto elementType = parseType(); 975 if (!elementType || 976 parseToken(Token::greater, "expected '>' in complex type")) 977 return nullptr; 978 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) 979 return emitError(elementTypeLoc, "invalid element type for complex"), 980 nullptr; 981 982 return ComplexType::get(elementType); 983 } 984 985 /// Parse an extended type. 986 /// 987 /// extended-type ::= (dialect-type | type-alias) 988 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 989 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 990 /// type-alias ::= `!` alias-name 991 /// 992 Type Parser::parseExtendedType() { 993 return parseExtendedSymbol<Type>( 994 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 995 [&](StringRef dialectName, StringRef symbolData, 996 llvm::SMLoc loc) -> Type { 997 // If we found a registered dialect, then ask it to parse the type. 998 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 999 return parseSymbol<Type>( 1000 symbolData, state.context, state.symbols, [&](Parser &parser) { 1001 CustomDialectAsmParser customParser(symbolData, parser); 1002 return dialect->parseType(customParser); 1003 }); 1004 } 1005 1006 // Otherwise, form a new opaque type. 1007 return OpaqueType::getChecked( 1008 Identifier::get(dialectName, state.context), symbolData, 1009 state.context, getEncodedSourceLocation(loc)); 1010 }); 1011 } 1012 1013 /// Parse a function type. 
1014 /// 1015 /// function-type ::= type-list-parens `->` function-result-type 1016 /// 1017 Type Parser::parseFunctionType() { 1018 assert(getToken().is(Token::l_paren)); 1019 1020 SmallVector<Type, 4> arguments, results; 1021 if (parseTypeListParens(arguments) || 1022 parseToken(Token::arrow, "expected '->' in function type") || 1023 parseFunctionResultTypes(results)) 1024 return nullptr; 1025 1026 return builder.getFunctionType(arguments, results); 1027 } 1028 1029 /// Parse the offset and strides from a strided layout specification. 1030 /// 1031 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1032 /// 1033 ParseResult Parser::parseStridedLayout(int64_t &offset, 1034 SmallVectorImpl<int64_t> &strides) { 1035 // Parse offset. 1036 consumeToken(Token::kw_offset); 1037 if (!consumeIf(Token::colon)) 1038 return emitError("expected colon after `offset` keyword"); 1039 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1040 bool question = getToken().is(Token::question); 1041 if (!maybeOffset && !question) 1042 return emitError("invalid offset"); 1043 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1044 : MemRefType::getDynamicStrideOrOffset(); 1045 consumeToken(); 1046 1047 if (!consumeIf(Token::comma)) 1048 return emitError("expected comma after offset value"); 1049 1050 // Parse stride list. 1051 if (!consumeIf(Token::kw_strides)) 1052 return emitError("expected `strides` keyword after offset specification"); 1053 if (!consumeIf(Token::colon)) 1054 return emitError("expected colon after `strides` keyword"); 1055 if (failed(parseStrideList(strides))) 1056 return emitError("invalid braces-enclosed stride list"); 1057 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1058 return emitError("invalid memref stride"); 1059 1060 return success(); 1061 } 1062 1063 /// Parse a memref type. 
1064 /// 1065 /// memref-type ::= ranked-memref-type | unranked-memref-type 1066 /// 1067 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1068 /// (`,` semi-affine-map-composition)? (`,` 1069 /// memory-space)? `>` 1070 /// 1071 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1072 /// 1073 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1074 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1075 /// 1076 Type Parser::parseMemRefType() { 1077 consumeToken(Token::kw_memref); 1078 1079 if (parseToken(Token::less, "expected '<' in memref type")) 1080 return nullptr; 1081 1082 bool isUnranked; 1083 SmallVector<int64_t, 4> dimensions; 1084 1085 if (consumeIf(Token::star)) { 1086 // This is an unranked memref type. 1087 isUnranked = true; 1088 if (parseXInDimensionList()) 1089 return nullptr; 1090 1091 } else { 1092 isUnranked = false; 1093 if (parseDimensionListRanked(dimensions)) 1094 return nullptr; 1095 } 1096 1097 // Parse the element type. 1098 auto typeLoc = getToken().getLoc(); 1099 auto elementType = parseType(); 1100 if (!elementType) 1101 return nullptr; 1102 1103 // Check that memref is formed from allowed types. 1104 if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() && 1105 !elementType.isa<ComplexType>()) 1106 return emitError(typeLoc, "invalid memref element type"), nullptr; 1107 1108 // Parse semi-affine-map-composition. 1109 SmallVector<AffineMap, 2> affineMapComposition; 1110 Optional<unsigned> memorySpace; 1111 unsigned numDims = dimensions.size(); 1112 1113 auto parseElt = [&]() -> ParseResult { 1114 // Check for the memory space. 
1115 if (getToken().is(Token::integer)) { 1116 if (memorySpace) 1117 return emitError("multiple memory spaces specified in memref type"); 1118 memorySpace = getToken().getUnsignedIntegerValue(); 1119 if (!memorySpace.hasValue()) 1120 return emitError("invalid memory space in memref type"); 1121 consumeToken(Token::integer); 1122 return success(); 1123 } 1124 if (isUnranked) 1125 return emitError("cannot have affine map for unranked memref type"); 1126 if (memorySpace) 1127 return emitError("expected memory space to be last in memref type"); 1128 1129 AffineMap map; 1130 llvm::SMLoc mapLoc = getToken().getLoc(); 1131 if (getToken().is(Token::kw_offset)) { 1132 int64_t offset; 1133 SmallVector<int64_t, 4> strides; 1134 if (failed(parseStridedLayout(offset, strides))) 1135 return failure(); 1136 // Construct strided affine map. 1137 map = makeStridedLinearLayoutMap(strides, offset, state.context); 1138 } else { 1139 // Parse an affine map attribute. 1140 auto affineMap = parseAttribute(); 1141 if (!affineMap) 1142 return failure(); 1143 auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>(); 1144 if (!affineMapAttr) 1145 return emitError("expected affine map in memref type"); 1146 map = affineMapAttr.getValue(); 1147 } 1148 1149 if (map.getNumDims() != numDims) { 1150 size_t i = affineMapComposition.size(); 1151 return emitError(mapLoc, "memref affine map dimension mismatch between ") 1152 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 1153 << " and affine map" << i + 1 << ": " << numDims 1154 << " != " << map.getNumDims(); 1155 } 1156 numDims = map.getNumResults(); 1157 affineMapComposition.push_back(map); 1158 return success(); 1159 }; 1160 1161 // Parse a list of mappings and address space if present. 1162 if (!consumeIf(Token::greater)) { 1163 // Parse comma separated list of affine maps, followed by memory space. 
1164 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 1165 parseCommaSeparatedListUntil(Token::greater, parseElt, 1166 /*allowEmptyList=*/false)) { 1167 return nullptr; 1168 } 1169 } 1170 1171 if (isUnranked) 1172 return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); 1173 1174 return MemRefType::get(dimensions, elementType, affineMapComposition, 1175 memorySpace.getValueOr(0)); 1176 } 1177 1178 /// Parse any type except the function type. 1179 /// 1180 /// non-function-type ::= integer-type 1181 /// | index-type 1182 /// | float-type 1183 /// | extended-type 1184 /// | vector-type 1185 /// | tensor-type 1186 /// | memref-type 1187 /// | complex-type 1188 /// | tuple-type 1189 /// | none-type 1190 /// 1191 /// index-type ::= `index` 1192 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1193 /// none-type ::= `none` 1194 /// 1195 Type Parser::parseNonFunctionType() { 1196 switch (getToken().getKind()) { 1197 default: 1198 return (emitError("expected non-function type"), nullptr); 1199 case Token::kw_memref: 1200 return parseMemRefType(); 1201 case Token::kw_tensor: 1202 return parseTensorType(); 1203 case Token::kw_complex: 1204 return parseComplexType(); 1205 case Token::kw_tuple: 1206 return parseTupleType(); 1207 case Token::kw_vector: 1208 return parseVectorType(); 1209 // integer-type 1210 case Token::inttype: { 1211 auto width = getToken().getIntTypeBitwidth(); 1212 if (!width.hasValue()) 1213 return (emitError("invalid integer width"), nullptr); 1214 if (width.getValue() > IntegerType::kMaxWidth) { 1215 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 1216 << IntegerType::kMaxWidth << " bits"; 1217 return nullptr; 1218 } 1219 1220 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 1221 if (Optional<bool> signedness = getToken().getIntTypeSignedness()) 1222 signSemantics = *signedness ? 
IntegerType::Signed : IntegerType::Unsigned; 1223 1224 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1225 consumeToken(Token::inttype); 1226 return IntegerType::getChecked(width.getValue(), signSemantics, loc); 1227 } 1228 1229 // float-type 1230 case Token::kw_bf16: 1231 consumeToken(Token::kw_bf16); 1232 return builder.getBF16Type(); 1233 case Token::kw_f16: 1234 consumeToken(Token::kw_f16); 1235 return builder.getF16Type(); 1236 case Token::kw_f32: 1237 consumeToken(Token::kw_f32); 1238 return builder.getF32Type(); 1239 case Token::kw_f64: 1240 consumeToken(Token::kw_f64); 1241 return builder.getF64Type(); 1242 1243 // index-type 1244 case Token::kw_index: 1245 consumeToken(Token::kw_index); 1246 return builder.getIndexType(); 1247 1248 // none-type 1249 case Token::kw_none: 1250 consumeToken(Token::kw_none); 1251 return builder.getNoneType(); 1252 1253 // extended type 1254 case Token::exclamation_identifier: 1255 return parseExtendedType(); 1256 } 1257 } 1258 1259 /// Parse a tensor type. 1260 /// 1261 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1262 /// dimension-list ::= dimension-list-ranked | `*x` 1263 /// 1264 Type Parser::parseTensorType() { 1265 consumeToken(Token::kw_tensor); 1266 1267 if (parseToken(Token::less, "expected '<' in tensor type")) 1268 return nullptr; 1269 1270 bool isUnranked; 1271 SmallVector<int64_t, 4> dimensions; 1272 1273 if (consumeIf(Token::star)) { 1274 // This is an unranked tensor type. 1275 isUnranked = true; 1276 1277 if (parseXInDimensionList()) 1278 return nullptr; 1279 1280 } else { 1281 isUnranked = false; 1282 if (parseDimensionListRanked(dimensions)) 1283 return nullptr; 1284 } 1285 1286 // Parse the element type. 
1287 auto elementTypeLoc = getToken().getLoc(); 1288 auto elementType = parseType(); 1289 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1290 return nullptr; 1291 if (!TensorType::isValidElementType(elementType)) 1292 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 1293 1294 if (isUnranked) 1295 return UnrankedTensorType::get(elementType); 1296 return RankedTensorType::get(dimensions, elementType); 1297 } 1298 1299 /// Parse a tuple type. 1300 /// 1301 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1302 /// 1303 Type Parser::parseTupleType() { 1304 consumeToken(Token::kw_tuple); 1305 1306 // Parse the '<'. 1307 if (parseToken(Token::less, "expected '<' in tuple type")) 1308 return nullptr; 1309 1310 // Check for an empty tuple by directly parsing '>'. 1311 if (consumeIf(Token::greater)) 1312 return TupleType::get(getContext()); 1313 1314 // Parse the element types and the '>'. 1315 SmallVector<Type, 4> types; 1316 if (parseTypeListNoParens(types) || 1317 parseToken(Token::greater, "expected '>' in tuple type")) 1318 return nullptr; 1319 1320 return TupleType::get(types, getContext()); 1321 } 1322 1323 /// Parse a vector type. 
1324 /// 1325 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1326 /// non-empty-static-dimension-list ::= decimal-literal `x` 1327 /// static-dimension-list 1328 /// static-dimension-list ::= (decimal-literal `x`)* 1329 /// 1330 VectorType Parser::parseVectorType() { 1331 consumeToken(Token::kw_vector); 1332 1333 if (parseToken(Token::less, "expected '<' in vector type")) 1334 return nullptr; 1335 1336 SmallVector<int64_t, 4> dimensions; 1337 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1338 return nullptr; 1339 if (dimensions.empty()) 1340 return (emitError("expected dimension size in vector type"), nullptr); 1341 if (any_of(dimensions, [](int64_t i) { return i <= 0; })) 1342 return emitError(getToken().getLoc(), 1343 "vector types must have positive constant sizes"), 1344 nullptr; 1345 1346 // Parse the element type. 1347 auto typeLoc = getToken().getLoc(); 1348 auto elementType = parseType(); 1349 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1350 return nullptr; 1351 if (!VectorType::isValidElementType(elementType)) 1352 return emitError(typeLoc, "vector elements must be int or float type"), 1353 nullptr; 1354 1355 return VectorType::get(dimensions, elementType); 1356 } 1357 1358 /// Parse a dimension list of a tensor or memref type. This populates the 1359 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1360 /// errors out on `?` otherwise. 
1361 /// 1362 /// dimension-list-ranked ::= (dimension `x`)* 1363 /// dimension ::= `?` | decimal-literal 1364 /// 1365 /// When `allowDynamic` is not set, this is used to parse: 1366 /// 1367 /// static-dimension-list ::= (decimal-literal `x`)* 1368 ParseResult 1369 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1370 bool allowDynamic) { 1371 while (getToken().isAny(Token::integer, Token::question)) { 1372 if (consumeIf(Token::question)) { 1373 if (!allowDynamic) 1374 return emitError("expected static shape"); 1375 dimensions.push_back(-1); 1376 } else { 1377 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1378 // aggregate type declarations. Therefore, `0xf32` should be processed as 1379 // a sequence of separate elements `0`, `x`, `f32`. 1380 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1381 // We can get here only if the token is an integer literal. Hexadecimal 1382 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1383 // literal, just `1` would, at which point we don't get into this 1384 // branch). 1385 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1386 dimensions.push_back(0); 1387 state.lex.resetPointer(getTokenSpelling().data() + 1); 1388 consumeToken(); 1389 } else { 1390 // Make sure this integer value is in bound and valid. 1391 auto dimension = getToken().getUnsignedIntegerValue(); 1392 if (!dimension.hasValue()) 1393 return emitError("invalid dimension"); 1394 dimensions.push_back((int64_t)dimension.getValue()); 1395 consumeToken(Token::integer); 1396 } 1397 } 1398 1399 // Make sure we have an 'x' or something like 'xbf32'. 1400 if (parseXInDimensionList()) 1401 return failure(); 1402 } 1403 1404 return success(); 1405 } 1406 1407 /// Parse an 'x' token in a dimension list, handling the case where the x is 1408 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1409 /// token. 
1410 ParseResult Parser::parseXInDimensionList() { 1411 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1412 return emitError("expected 'x' in dimension list"); 1413 1414 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1415 if (getTokenSpelling().size() != 1) 1416 state.lex.resetPointer(getTokenSpelling().data() + 1); 1417 1418 // Consume the 'x'. 1419 consumeToken(Token::bare_identifier); 1420 1421 return success(); 1422 } 1423 1424 // Parse a comma-separated list of dimensions, possibly empty: 1425 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1426 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1427 if (!consumeIf(Token::l_square)) 1428 return failure(); 1429 // Empty list early exit. 1430 if (consumeIf(Token::r_square)) 1431 return success(); 1432 while (true) { 1433 if (consumeIf(Token::question)) { 1434 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1435 } else { 1436 // This must be an integer value. 1437 int64_t val; 1438 if (getToken().getSpelling().getAsInteger(10, val)) 1439 return emitError("invalid integer value: ") << getToken().getSpelling(); 1440 // Make sure it is not the one value for `?`. 1441 if (ShapedType::isDynamic(val)) 1442 return emitError("invalid integer value: ") 1443 << getToken().getSpelling() 1444 << ", use `?` to specify a dynamic dimension"; 1445 dimensions.push_back(val); 1446 consumeToken(Token::integer); 1447 } 1448 if (!consumeIf(Token::comma)) 1449 break; 1450 } 1451 if (!consumeIf(Token::r_square)) 1452 return failure(); 1453 return success(); 1454 } 1455 1456 //===----------------------------------------------------------------------===// 1457 // Attribute parsing. 1458 //===----------------------------------------------------------------------===// 1459 1460 /// Return the symbol reference referred to by the given token, that is known to 1461 /// be an @-identifier. 
1462 static std::string extractSymbolReference(Token tok) { 1463 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1464 StringRef nameStr = tok.getSpelling().drop_front(); 1465 1466 // Check to see if the reference is a string literal, or a bare identifier. 1467 if (nameStr.front() == '"') 1468 return tok.getStringValue(); 1469 return std::string(nameStr); 1470 } 1471 1472 /// Parse an arbitrary attribute. 1473 /// 1474 /// attribute-value ::= `unit` 1475 /// | bool-literal 1476 /// | integer-literal (`:` (index-type | integer-type))? 1477 /// | float-literal (`:` float-type)? 1478 /// | string-literal (`:` type)? 1479 /// | type 1480 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1481 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1482 /// | symbol-ref-id (`::` symbol-ref-id)* 1483 /// | `dense` `<` attribute-value `>` `:` 1484 /// (tensor-type | vector-type) 1485 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1486 /// `:` (tensor-type | vector-type) 1487 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1488 /// `>` `:` (tensor-type | vector-type) 1489 /// | extended-attribute 1490 /// 1491 Attribute Parser::parseAttribute(Type type) { 1492 switch (getToken().getKind()) { 1493 // Parse an AffineMap or IntegerSet attribute. 
1494 case Token::kw_affine_map: { 1495 consumeToken(Token::kw_affine_map); 1496 1497 AffineMap map; 1498 if (parseToken(Token::less, "expected '<' in affine map") || 1499 parseAffineMapReference(map) || 1500 parseToken(Token::greater, "expected '>' in affine map")) 1501 return Attribute(); 1502 return AffineMapAttr::get(map); 1503 } 1504 case Token::kw_affine_set: { 1505 consumeToken(Token::kw_affine_set); 1506 1507 IntegerSet set; 1508 if (parseToken(Token::less, "expected '<' in integer set") || 1509 parseIntegerSetReference(set) || 1510 parseToken(Token::greater, "expected '>' in integer set")) 1511 return Attribute(); 1512 return IntegerSetAttr::get(set); 1513 } 1514 1515 // Parse an array attribute. 1516 case Token::l_square: { 1517 consumeToken(Token::l_square); 1518 1519 SmallVector<Attribute, 4> elements; 1520 auto parseElt = [&]() -> ParseResult { 1521 elements.push_back(parseAttribute()); 1522 return elements.back() ? success() : failure(); 1523 }; 1524 1525 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1526 return nullptr; 1527 return builder.getArrayAttr(elements); 1528 } 1529 1530 // Parse a boolean attribute. 1531 case Token::kw_false: 1532 consumeToken(Token::kw_false); 1533 return builder.getBoolAttr(false); 1534 case Token::kw_true: 1535 consumeToken(Token::kw_true); 1536 return builder.getBoolAttr(true); 1537 1538 // Parse a dense elements attribute. 1539 case Token::kw_dense: 1540 return parseDenseElementsAttr(type); 1541 1542 // Parse a dictionary attribute. 1543 case Token::l_brace: { 1544 SmallVector<NamedAttribute, 4> elements; 1545 if (parseAttributeDict(elements)) 1546 return nullptr; 1547 return builder.getDictionaryAttr(elements); 1548 } 1549 1550 // Parse an extended attribute, i.e. alias or dialect attribute. 1551 case Token::hash_identifier: 1552 return parseExtendedAttr(type); 1553 1554 // Parse floating point and integer attributes. 
1555 case Token::floatliteral: 1556 return parseFloatAttr(type, /*isNegative=*/false); 1557 case Token::integer: 1558 return parseDecOrHexAttr(type, /*isNegative=*/false); 1559 case Token::minus: { 1560 consumeToken(Token::minus); 1561 if (getToken().is(Token::integer)) 1562 return parseDecOrHexAttr(type, /*isNegative=*/true); 1563 if (getToken().is(Token::floatliteral)) 1564 return parseFloatAttr(type, /*isNegative=*/true); 1565 1566 return (emitError("expected constant integer or floating point value"), 1567 nullptr); 1568 } 1569 1570 // Parse a location attribute. 1571 case Token::kw_loc: { 1572 LocationAttr attr; 1573 return failed(parseLocation(attr)) ? Attribute() : attr; 1574 } 1575 1576 // Parse an opaque elements attribute. 1577 case Token::kw_opaque: 1578 return parseOpaqueElementsAttr(type); 1579 1580 // Parse a sparse elements attribute. 1581 case Token::kw_sparse: 1582 return parseSparseElementsAttr(type); 1583 1584 // Parse a string attribute. 1585 case Token::string: { 1586 auto val = getToken().getStringValue(); 1587 consumeToken(Token::string); 1588 // Parse the optional trailing colon type if one wasn't explicitly provided. 1589 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1590 return Attribute(); 1591 1592 return type ? StringAttr::get(val, type) 1593 : StringAttr::get(val, getContext()); 1594 } 1595 1596 // Parse a symbol reference attribute. 1597 case Token::at_identifier: { 1598 std::string nameStr = extractSymbolReference(getToken()); 1599 consumeToken(Token::at_identifier); 1600 1601 // Parse any nested references. 1602 std::vector<FlatSymbolRefAttr> nestedRefs; 1603 while (getToken().is(Token::colon)) { 1604 // Check for the '::' prefix. 1605 const char *curPointer = getToken().getLoc().getPointer(); 1606 consumeToken(Token::colon); 1607 if (!consumeIf(Token::colon)) { 1608 state.lex.resetPointer(curPointer); 1609 consumeToken(); 1610 break; 1611 } 1612 // Parse the reference itself. 
1613 auto curLoc = getToken().getLoc(); 1614 if (getToken().isNot(Token::at_identifier)) { 1615 emitError(curLoc, "expected nested symbol reference identifier"); 1616 return Attribute(); 1617 } 1618 1619 std::string nameStr = extractSymbolReference(getToken()); 1620 consumeToken(Token::at_identifier); 1621 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1622 } 1623 1624 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1625 } 1626 1627 // Parse a 'unit' attribute. 1628 case Token::kw_unit: 1629 consumeToken(Token::kw_unit); 1630 return builder.getUnitAttr(); 1631 1632 default: 1633 // Parse a type attribute. 1634 if (Type type = parseType()) 1635 return TypeAttr::get(type); 1636 return nullptr; 1637 } 1638 } 1639 1640 /// Attribute dictionary. 1641 /// 1642 /// attribute-dict ::= `{` `}` 1643 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1644 /// attribute-entry ::= bare-id `=` attribute-value 1645 /// 1646 ParseResult 1647 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1648 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1649 return failure(); 1650 1651 auto parseElt = [&]() -> ParseResult { 1652 // We allow keywords as attribute names. 1653 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1654 !getToken().isKeyword()) 1655 return emitError("expected attribute name"); 1656 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1657 consumeToken(); 1658 1659 // Try to parse the '=' for the attribute value. 1660 if (!consumeIf(Token::equal)) { 1661 // If there is no '=', we treat this as a unit attribute. 
1662 attributes.push_back({nameId, builder.getUnitAttr()}); 1663 return success(); 1664 } 1665 1666 auto attr = parseAttribute(); 1667 if (!attr) 1668 return failure(); 1669 1670 attributes.push_back({nameId, attr}); 1671 return success(); 1672 }; 1673 1674 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1675 return failure(); 1676 1677 return success(); 1678 } 1679 1680 /// Parse an extended attribute. 1681 /// 1682 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1683 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1684 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1685 /// attribute-alias ::= `#` alias-name 1686 /// 1687 Attribute Parser::parseExtendedAttr(Type type) { 1688 Attribute attr = parseExtendedSymbol<Attribute>( 1689 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1690 [&](StringRef dialectName, StringRef symbolData, 1691 llvm::SMLoc loc) -> Attribute { 1692 // Parse an optional trailing colon type. 1693 Type attrType = type; 1694 if (consumeIf(Token::colon) && !(attrType = parseType())) 1695 return Attribute(); 1696 1697 // If we found a registered dialect, then ask it to parse the attribute. 1698 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1699 return parseSymbol<Attribute>( 1700 symbolData, state.context, state.symbols, [&](Parser &parser) { 1701 CustomDialectAsmParser customParser(symbolData, parser); 1702 return dialect->parseAttribute(customParser, attrType); 1703 }); 1704 } 1705 1706 // Otherwise, form a new opaque attribute. 1707 return OpaqueAttr::getChecked( 1708 Identifier::get(dialectName, state.context), symbolData, 1709 attrType ? attrType : NoneType::get(state.context), 1710 getEncodedSourceLocation(loc)); 1711 }); 1712 1713 // Ensure that the attribute has the same type as requested. 
1714 if (attr && type && attr.getType() != type) { 1715 emitError("attribute type different than expected: expected ") 1716 << type << ", but got " << attr.getType(); 1717 return nullptr; 1718 } 1719 return attr; 1720 } 1721 1722 /// Parse a float attribute. 1723 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1724 auto val = getToken().getFloatingPointValue(); 1725 if (!val.hasValue()) 1726 return (emitError("floating point value too large for attribute"), nullptr); 1727 consumeToken(Token::floatliteral); 1728 if (!type) { 1729 // Default to F64 when no type is specified. 1730 if (!consumeIf(Token::colon)) 1731 type = builder.getF64Type(); 1732 else if (!(type = parseType())) 1733 return nullptr; 1734 } 1735 if (!type.isa<FloatType>()) 1736 return (emitError("floating point value not valid for specified type"), 1737 nullptr); 1738 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1739 } 1740 1741 /// Construct a float attribute bitwise equivalent to the integer literal. 1742 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1743 uint64_t value) { 1744 // FIXME: bfloat is currently stored as a double internally because it doesn't 1745 // have valid APFloat semantics. 1746 if (type.isF64() || type.isBF16()) 1747 return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1748 1749 APInt apInt(type.getWidth(), value); 1750 if (apInt != value) { 1751 p->emitError("hexadecimal float constant out of range for type"); 1752 return llvm::None; 1753 } 1754 return APFloat(type.getFloatSemantics(), apInt); 1755 } 1756 1757 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1758 /// or a float attribute. 
1759 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1760 auto val = getToken().getUInt64IntegerValue(); 1761 if (!val.hasValue()) 1762 return (emitError("integer constant out of range for attribute"), nullptr); 1763 1764 // Remember if the literal is hexadecimal. 1765 StringRef spelling = getToken().getSpelling(); 1766 auto loc = state.curToken.getLoc(); 1767 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1768 1769 consumeToken(Token::integer); 1770 if (!type) { 1771 // Default to i64 if not type is specified. 1772 if (!consumeIf(Token::colon)) 1773 type = builder.getIntegerType(64); 1774 else if (!(type = parseType())) 1775 return nullptr; 1776 } 1777 1778 if (auto floatType = type.dyn_cast<FloatType>()) { 1779 if (isNegative) 1780 return emitError( 1781 loc, 1782 "hexadecimal float literal should not have a leading minus"), 1783 nullptr; 1784 if (!isHex) { 1785 emitError(loc, "unexpected decimal integer literal for a float attribute") 1786 .attachNote() 1787 << "add a trailing dot to make the literal a float"; 1788 return nullptr; 1789 } 1790 1791 // Construct a float attribute bitwise equivalent to the integer literal. 1792 Optional<APFloat> apVal = 1793 buildHexadecimalFloatLiteral(this, floatType, *val); 1794 return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); 1795 } 1796 1797 if (!type.isa<IntegerType>() && !type.isa<IndexType>()) 1798 return emitError(loc, "integer literal not valid for specified type"), 1799 nullptr; 1800 1801 if (isNegative && type.isUnsignedInteger()) { 1802 emitError(loc, 1803 "negative integer literal not valid for unsigned integer type"); 1804 return nullptr; 1805 } 1806 1807 // Parse the integer literal. 1808 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1809 APInt apInt(width, *val, isNegative); 1810 if (apInt != *val) 1811 return emitError(loc, "integer constant out of range for attribute"), 1812 nullptr; 1813 1814 // Otherwise construct an integer attribute. 
1815 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1816 return emitError(loc, "integer constant out of range for attribute"), 1817 nullptr; 1818 1819 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1820 } 1821 1822 /// Parse elements values stored within a hex etring. On success, the values are 1823 /// stored into 'result'. 1824 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 1825 std::string &result) { 1826 std::string val = tok.getStringValue(); 1827 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1828 return parser.emitError(tok.getLoc(), 1829 "elements hex string should start with '0x'"); 1830 1831 StringRef hexValues = StringRef(val).drop_front(2); 1832 if (!llvm::all_of(hexValues, llvm::isHexDigit)) 1833 return parser.emitError(tok.getLoc(), 1834 "elements hex string only contains hex digits"); 1835 1836 result = llvm::fromHex(hexValues); 1837 return success(); 1838 } 1839 1840 /// Parse an opaque elements attribute. 1841 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1842 consumeToken(Token::kw_opaque); 1843 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1844 return nullptr; 1845 1846 if (getToken().isNot(Token::string)) 1847 return (emitError("expected dialect namespace"), nullptr); 1848 1849 auto name = getToken().getStringValue(); 1850 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1851 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1852 // attribute. Otherwise, it can't be roundtripped without having the dialect 1853 // registered. 
1854 if (!dialect) 1855 return (emitError("no registered dialect with namespace '" + name + "'"), 1856 nullptr); 1857 consumeToken(Token::string); 1858 1859 if (parseToken(Token::comma, "expected ','")) 1860 return nullptr; 1861 1862 Token hexTok = getToken(); 1863 if (parseToken(Token::string, "elements hex string should start with '0x'") || 1864 parseToken(Token::greater, "expected '>'")) 1865 return nullptr; 1866 auto type = parseElementsLiteralType(attrType); 1867 if (!type) 1868 return nullptr; 1869 1870 std::string data; 1871 if (parseElementAttrHexValues(*this, hexTok, data)) 1872 return nullptr; 1873 return OpaqueElementsAttr::get(dialect, type, data); 1874 } 1875 1876 namespace { 1877 class TensorLiteralParser { 1878 public: 1879 TensorLiteralParser(Parser &p) : p(p) {} 1880 1881 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1882 /// may also parse a tensor literal that is store as a hex string. 1883 ParseResult parse(bool allowHex); 1884 1885 /// Build a dense attribute instance with the parsed elements and the given 1886 /// shaped type. 1887 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1888 1889 ArrayRef<int64_t> getShape() const { return shape; } 1890 1891 private: 1892 enum class ElementKind { Boolean, Integer, Float }; 1893 1894 /// Return a string to represent the given element kind. 1895 const char *getElementKindStr(ElementKind kind) { 1896 switch (kind) { 1897 case ElementKind::Boolean: 1898 return "'boolean'"; 1899 case ElementKind::Integer: 1900 return "'integer'"; 1901 case ElementKind::Float: 1902 return "'float'"; 1903 } 1904 llvm_unreachable("unknown element kind"); 1905 } 1906 1907 /// Build a Dense Integer attribute for the given type. 1908 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1909 IntegerType eltTy); 1910 1911 /// Build a Dense Float attribute for the given type. 
1912 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1913 FloatType eltTy); 1914 1915 /// Build a Dense attribute with hex data for the given type. 1916 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); 1917 1918 /// Parse a single element, returning failure if it isn't a valid element 1919 /// literal. For example: 1920 /// parseElement(1) -> Success, 1 1921 /// parseElement([1]) -> Failure 1922 ParseResult parseElement(); 1923 1924 /// Parse a list of either lists or elements, returning the dimensions of the 1925 /// parsed sub-tensors in dims. For example: 1926 /// parseList([1, 2, 3]) -> Success, [3] 1927 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1928 /// parseList([[1, 2], 3]) -> Failure 1929 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1930 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1931 1932 /// Parse a literal that was printed as a hex string. 1933 ParseResult parseHexElements(); 1934 1935 Parser &p; 1936 1937 /// The shape inferred from the parsed elements. 1938 SmallVector<int64_t, 4> shape; 1939 1940 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1941 std::vector<std::pair<bool, Token>> storage; 1942 1943 /// A flag that indicates the type of elements that have been parsed. 1944 Optional<ElementKind> knownEltKind; 1945 1946 /// Storage used when parsing elements that were stored as hex values. 1947 Optional<Token> hexStorage; 1948 }; 1949 } // namespace 1950 1951 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1952 /// may also parse a tensor literal that is store as a hex string. 1953 ParseResult TensorLiteralParser::parse(bool allowHex) { 1954 // If hex is allowed, check for a string literal. 1955 if (allowHex && p.getToken().is(Token::string)) { 1956 hexStorage = p.getToken(); 1957 p.consumeToken(Token::string); 1958 return success(); 1959 } 1960 // Otherwise, parse a list or an individual element. 
1961 if (p.getToken().is(Token::l_square)) 1962 return parseList(shape); 1963 return parseElement(); 1964 } 1965 1966 /// Build a dense attribute instance with the parsed elements and the given 1967 /// shaped type. 1968 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1969 ShapedType type) { 1970 // Check to see if we parsed the literal from a hex string. 1971 if (hexStorage.hasValue()) 1972 return getHexAttr(loc, type); 1973 1974 // Check that the parsed storage size has the same number of elements to the 1975 // type, or is a known splat. 1976 if (!shape.empty() && getShape() != type.getShape()) { 1977 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1978 << "]) does not match type ([" << type.getShape() << "])"; 1979 return nullptr; 1980 } 1981 1982 // If the type is an integer, build a set of APInt values from the storage 1983 // with the correct bitwidth. 1984 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1985 return getIntAttr(loc, type, intTy); 1986 1987 // Otherwise, this must be a floating point type. 1988 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1989 if (!floatTy) { 1990 p.emitError(loc) << "expected floating-point or integer element type, got " 1991 << type.getElementType(); 1992 return nullptr; 1993 } 1994 return getFloatAttr(loc, type, floatTy); 1995 } 1996 1997 /// Build a Dense Integer attribute for the given type. 
1998 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1999 ShapedType type, 2000 IntegerType eltTy) { 2001 std::vector<APInt> intElements; 2002 intElements.reserve(storage.size()); 2003 auto isUintType = type.getElementType().isUnsignedInteger(); 2004 for (const auto &signAndToken : storage) { 2005 bool isNegative = signAndToken.first; 2006 const Token &token = signAndToken.second; 2007 auto tokenLoc = token.getLoc(); 2008 2009 if (isNegative && isUintType) { 2010 p.emitError(tokenLoc) 2011 << "expected unsigned integer elements, but parsed negative value"; 2012 return nullptr; 2013 } 2014 2015 // Check to see if floating point values were parsed. 2016 if (token.is(Token::floatliteral)) { 2017 p.emitError(tokenLoc) 2018 << "expected integer elements, but parsed floating-point"; 2019 return nullptr; 2020 } 2021 2022 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 2023 "unexpected token type"); 2024 if (token.isAny(Token::kw_true, Token::kw_false)) { 2025 if (!eltTy.isInteger(1)) 2026 p.emitError(tokenLoc) 2027 << "expected i1 type for 'true' or 'false' values"; 2028 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 2029 /*isSigned=*/false); 2030 intElements.push_back(apInt); 2031 continue; 2032 } 2033 2034 // Create APInt values for each element with the correct bitwidth. 2035 auto val = token.getUInt64IntegerValue(); 2036 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 2037 : (int64_t)val.getValue() < 0)) { 2038 p.emitError(tokenLoc, "integer constant out of range for attribute"); 2039 return nullptr; 2040 } 2041 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 2042 if (apInt != val.getValue()) 2043 return (p.emitError(tokenLoc, "integer constant out of range for type"), 2044 nullptr); 2045 intElements.push_back(isNegative ? -apInt : apInt); 2046 } 2047 2048 return DenseElementsAttr::get(type, intElements); 2049 } 2050 2051 /// Build a Dense Float attribute for the given type. 
2052 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 2053 ShapedType type, 2054 FloatType eltTy) { 2055 std::vector<APFloat> floatValues; 2056 floatValues.reserve(storage.size()); 2057 for (const auto &signAndToken : storage) { 2058 bool isNegative = signAndToken.first; 2059 const Token &token = signAndToken.second; 2060 2061 // Handle hexadecimal float literals. 2062 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 2063 if (isNegative) { 2064 p.emitError(token.getLoc()) 2065 << "hexadecimal float literal should not have a leading minus"; 2066 return nullptr; 2067 } 2068 auto val = token.getUInt64IntegerValue(); 2069 if (!val.hasValue()) { 2070 p.emitError("hexadecimal float constant out of range for attribute"); 2071 return nullptr; 2072 } 2073 Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); 2074 if (!apVal) 2075 return nullptr; 2076 floatValues.push_back(*apVal); 2077 continue; 2078 } 2079 2080 // Check to see if any decimal integers or booleans were parsed. 2081 if (!token.is(Token::floatliteral)) { 2082 p.emitError() << "expected floating-point elements, but parsed integer"; 2083 return nullptr; 2084 } 2085 2086 // Build the float values from tokens. 2087 auto val = token.getFloatingPointValue(); 2088 if (!val.hasValue()) { 2089 p.emitError("floating point value too large for attribute"); 2090 return nullptr; 2091 } 2092 // Treat BF16 as double because it is not supported in LLVM's APFloat. 2093 APFloat apVal(isNegative ? -*val : *val); 2094 if (!eltTy.isBF16() && !eltTy.isF64()) { 2095 bool unused; 2096 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 2097 &unused); 2098 } 2099 floatValues.push_back(apVal); 2100 } 2101 2102 return DenseElementsAttr::get(type, floatValues); 2103 } 2104 2105 /// Build a Dense attribute with hex data for the given type. 
2106 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, 2107 ShapedType type) { 2108 Type elementType = type.getElementType(); 2109 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { 2110 p.emitError(loc) << "expected floating-point or integer element type, got " 2111 << elementType; 2112 return nullptr; 2113 } 2114 2115 std::string data; 2116 if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) 2117 return nullptr; 2118 2119 // Check that the size of the hex data correpsonds to the size of the type, or 2120 // a splat of the type. 2121 // TODO: bf16 is currently stored as a double, this should be removed when 2122 // APFloat properly supports it. 2123 int64_t elementWidth = 2124 elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth(); 2125 if (static_cast<int64_t>(data.size() * CHAR_BIT) != 2126 (type.getNumElements() * elementWidth)) { 2127 p.emitError(loc) << "elements hex data size is invalid for provided type: " 2128 << type; 2129 return nullptr; 2130 } 2131 2132 return DenseElementsAttr::getFromRawBuffer( 2133 type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false); 2134 } 2135 2136 ParseResult TensorLiteralParser::parseElement() { 2137 switch (p.getToken().getKind()) { 2138 // Parse a boolean element. 2139 case Token::kw_true: 2140 case Token::kw_false: 2141 case Token::floatliteral: 2142 case Token::integer: 2143 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2144 p.consumeToken(); 2145 break; 2146 2147 // Parse a signed integer or a negative floating-point element. 
2148 case Token::minus: 2149 p.consumeToken(Token::minus); 2150 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2151 return p.emitError("expected integer or floating point literal"); 2152 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2153 p.consumeToken(); 2154 break; 2155 2156 default: 2157 return p.emitError("expected element literal of primitive type"); 2158 } 2159 2160 return success(); 2161 } 2162 2163 /// Parse a list of either lists or elements, returning the dimensions of the 2164 /// parsed sub-tensors in dims. For example: 2165 /// parseList([1, 2, 3]) -> Success, [3] 2166 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2167 /// parseList([[1, 2], 3]) -> Failure 2168 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2169 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2170 p.consumeToken(Token::l_square); 2171 2172 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2173 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2174 if (prevDims == newDims) 2175 return success(); 2176 return p.emitError("tensor literal is invalid; ranks are not consistent " 2177 "between elements"); 2178 }; 2179 2180 bool first = true; 2181 SmallVector<int64_t, 4> newDims; 2182 unsigned size = 0; 2183 auto parseCommaSeparatedList = [&]() -> ParseResult { 2184 SmallVector<int64_t, 4> thisDims; 2185 if (p.getToken().getKind() == Token::l_square) { 2186 if (parseList(thisDims)) 2187 return failure(); 2188 } else if (parseElement()) { 2189 return failure(); 2190 } 2191 ++size; 2192 if (!first) 2193 return checkDims(newDims, thisDims); 2194 newDims = thisDims; 2195 first = false; 2196 return success(); 2197 }; 2198 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2199 return failure(); 2200 2201 // Return the sublists' dimensions with 'size' prepended. 
2202 dims.clear(); 2203 dims.push_back(size); 2204 dims.append(newDims.begin(), newDims.end()); 2205 return success(); 2206 } 2207 2208 /// Parse a dense elements attribute. 2209 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2210 consumeToken(Token::kw_dense); 2211 if (parseToken(Token::less, "expected '<' after 'dense'")) 2212 return nullptr; 2213 2214 // Parse the literal data. 2215 TensorLiteralParser literalParser(*this); 2216 if (literalParser.parse(/*allowHex=*/true)) 2217 return nullptr; 2218 2219 if (parseToken(Token::greater, "expected '>'")) 2220 return nullptr; 2221 2222 auto typeLoc = getToken().getLoc(); 2223 auto type = parseElementsLiteralType(attrType); 2224 if (!type) 2225 return nullptr; 2226 return literalParser.getAttr(typeLoc, type); 2227 } 2228 2229 /// Shaped type for elements attribute. 2230 /// 2231 /// elements-literal-type ::= vector-type | ranked-tensor-type 2232 /// 2233 /// This method also checks the type has static shape. 2234 ShapedType Parser::parseElementsLiteralType(Type type) { 2235 // If the user didn't provide a type, parse the colon type for the literal. 2236 if (!type) { 2237 if (parseToken(Token::colon, "expected ':'")) 2238 return nullptr; 2239 if (!(type = parseType())) 2240 return nullptr; 2241 } 2242 2243 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2244 emitError("elements literal must be a ranked tensor or vector type"); 2245 return nullptr; 2246 } 2247 2248 auto sType = type.cast<ShapedType>(); 2249 if (!sType.hasStaticShape()) 2250 return (emitError("elements literal type must have static shape"), nullptr); 2251 2252 return sType; 2253 } 2254 2255 /// Parse a sparse elements attribute. 2256 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2257 consumeToken(Token::kw_sparse); 2258 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2259 return nullptr; 2260 2261 /// Parse the indices. We don't allow hex values here as we may need to use 2262 /// the inferred shape. 
2263 auto indicesLoc = getToken().getLoc(); 2264 TensorLiteralParser indiceParser(*this); 2265 if (indiceParser.parse(/*allowHex=*/false)) 2266 return nullptr; 2267 2268 if (parseToken(Token::comma, "expected ','")) 2269 return nullptr; 2270 2271 /// Parse the values. 2272 auto valuesLoc = getToken().getLoc(); 2273 TensorLiteralParser valuesParser(*this); 2274 if (valuesParser.parse(/*allowHex=*/true)) 2275 return nullptr; 2276 2277 if (parseToken(Token::greater, "expected '>'")) 2278 return nullptr; 2279 2280 auto type = parseElementsLiteralType(attrType); 2281 if (!type) 2282 return nullptr; 2283 2284 // If the indices are a splat, i.e. the literal parser parsed an element and 2285 // not a list, we set the shape explicitly. The indices are represented by a 2286 // 2-dimensional shape where the second dimension is the rank of the type. 2287 // Given that the parsed indices is a splat, we know that we only have one 2288 // indice and thus one for the first dimension. 2289 auto indiceEltType = builder.getIntegerType(64); 2290 ShapedType indicesType; 2291 if (indiceParser.getShape().empty()) { 2292 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2293 } else { 2294 // Otherwise, set the shape to the one parsed by the literal parser. 2295 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2296 } 2297 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2298 2299 // If the values are a splat, set the shape explicitly based on the number of 2300 // indices. The number of indices is encoded in the first dimension of the 2301 // indice shape type. 2302 auto valuesEltType = type.getElementType(); 2303 ShapedType valuesType = 2304 valuesParser.getShape().empty() 2305 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2306 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2307 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2308 2309 /// Sanity check. 
2310 if (valuesType.getRank() != 1) 2311 return (emitError("expected 1-d tensor for values"), nullptr); 2312 2313 auto sameShape = (indicesType.getRank() == 1) || 2314 (type.getRank() == indicesType.getDimSize(1)); 2315 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2316 if (!sameShape || !sameElementNum) { 2317 emitError() << "expected shape ([" << type.getShape() 2318 << "]); inferred shape of indices literal ([" 2319 << indicesType.getShape() 2320 << "]); inferred shape of values literal ([" 2321 << valuesType.getShape() << "])"; 2322 return nullptr; 2323 } 2324 2325 // Build the sparse elements attribute by the indices and values. 2326 return SparseElementsAttr::get(type, indices, values); 2327 } 2328 2329 //===----------------------------------------------------------------------===// 2330 // Location parsing. 2331 //===----------------------------------------------------------------------===// 2332 2333 /// Parse a location. 2334 /// 2335 /// location ::= `loc` inline-location 2336 /// inline-location ::= '(' location-inst ')' 2337 /// 2338 ParseResult Parser::parseLocation(LocationAttr &loc) { 2339 // Check for 'loc' identifier. 2340 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2341 return emitError(); 2342 2343 // Parse the inline-location. 2344 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2345 parseLocationInstance(loc) || 2346 parseToken(Token::r_paren, "expected ')' in inline location")) 2347 return failure(); 2348 return success(); 2349 } 2350 2351 /// Specific location instances. 
2352 /// 2353 /// location-inst ::= filelinecol-location | 2354 /// name-location | 2355 /// callsite-location | 2356 /// fused-location | 2357 /// unknown-location 2358 /// filelinecol-location ::= string-literal ':' integer-literal 2359 /// ':' integer-literal 2360 /// name-location ::= string-literal 2361 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2362 /// fused-location ::= fused ('<' attribute-value '>')? 2363 /// '[' location-inst (location-inst ',')* ']' 2364 /// unknown-location ::= 'unknown' 2365 /// 2366 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2367 consumeToken(Token::bare_identifier); 2368 2369 // Parse the '('. 2370 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2371 return failure(); 2372 2373 // Parse the callee location. 2374 LocationAttr calleeLoc; 2375 if (parseLocationInstance(calleeLoc)) 2376 return failure(); 2377 2378 // Parse the 'at'. 2379 if (getToken().isNot(Token::bare_identifier) || 2380 getToken().getSpelling() != "at") 2381 return emitError("expected 'at' in callsite location"); 2382 consumeToken(Token::bare_identifier); 2383 2384 // Parse the caller location. 2385 LocationAttr callerLoc; 2386 if (parseLocationInstance(callerLoc)) 2387 return failure(); 2388 2389 // Parse the ')'. 2390 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2391 return failure(); 2392 2393 // Return the callsite location. 2394 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2395 return success(); 2396 } 2397 2398 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2399 consumeToken(Token::bare_identifier); 2400 2401 // Try to parse the optional metadata. 2402 Attribute metadata; 2403 if (consumeIf(Token::less)) { 2404 metadata = parseAttribute(); 2405 if (!metadata) 2406 return emitError("expected valid attribute metadata"); 2407 // Parse the '>' token. 
2408 if (parseToken(Token::greater, 2409 "expected '>' after fused location metadata")) 2410 return failure(); 2411 } 2412 2413 SmallVector<Location, 4> locations; 2414 auto parseElt = [&] { 2415 LocationAttr newLoc; 2416 if (parseLocationInstance(newLoc)) 2417 return failure(); 2418 locations.push_back(newLoc); 2419 return success(); 2420 }; 2421 2422 if (parseToken(Token::l_square, "expected '[' in fused location") || 2423 parseCommaSeparatedList(parseElt) || 2424 parseToken(Token::r_square, "expected ']' in fused location")) 2425 return failure(); 2426 2427 // Return the fused location. 2428 loc = FusedLoc::get(locations, metadata, getContext()); 2429 return success(); 2430 } 2431 2432 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2433 auto *ctx = getContext(); 2434 auto str = getToken().getStringValue(); 2435 consumeToken(Token::string); 2436 2437 // If the next token is ':' this is a filelinecol location. 2438 if (consumeIf(Token::colon)) { 2439 // Parse the line number. 2440 if (getToken().isNot(Token::integer)) 2441 return emitError("expected integer line number in FileLineColLoc"); 2442 auto line = getToken().getUnsignedIntegerValue(); 2443 if (!line.hasValue()) 2444 return emitError("expected integer line number in FileLineColLoc"); 2445 consumeToken(Token::integer); 2446 2447 // Parse the ':'. 2448 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2449 return failure(); 2450 2451 // Parse the column number. 2452 if (getToken().isNot(Token::integer)) 2453 return emitError("expected integer column number in FileLineColLoc"); 2454 auto column = getToken().getUnsignedIntegerValue(); 2455 if (!column.hasValue()) 2456 return emitError("expected integer column number in FileLineColLoc"); 2457 consumeToken(Token::integer); 2458 2459 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2460 return success(); 2461 } 2462 2463 // Otherwise, this is a NameLoc. 2464 2465 // Check for a child location. 
2466 if (consumeIf(Token::l_paren)) { 2467 auto childSourceLoc = getToken().getLoc(); 2468 2469 // Parse the child location. 2470 LocationAttr childLoc; 2471 if (parseLocationInstance(childLoc)) 2472 return failure(); 2473 2474 // The child must not be another NameLoc. 2475 if (childLoc.isa<NameLoc>()) 2476 return emitError(childSourceLoc, 2477 "child of NameLoc cannot be another NameLoc"); 2478 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2479 2480 // Parse the closing ')'. 2481 if (parseToken(Token::r_paren, 2482 "expected ')' after child location of NameLoc")) 2483 return failure(); 2484 } else { 2485 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2486 } 2487 2488 return success(); 2489 } 2490 2491 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2492 // Handle either name or filelinecol locations. 2493 if (getToken().is(Token::string)) 2494 return parseNameOrFileLineColLocation(loc); 2495 2496 // Bare tokens required for other cases. 2497 if (!getToken().is(Token::bare_identifier)) 2498 return emitError("expected location instance"); 2499 2500 // Check for the 'callsite' signifying a callsite location. 2501 if (getToken().getSpelling() == "callsite") 2502 return parseCallSiteLocation(loc); 2503 2504 // If the token is 'fused', then this is a fused location. 2505 if (getToken().getSpelling() == "fused") 2506 return parseFusedLocation(loc); 2507 2508 // Check for a 'unknown' for an unknown location. 2509 if (getToken().getSpelling() == "unknown") { 2510 consumeToken(Token::bare_identifier); 2511 loc = UnknownLoc::get(getContext()); 2512 return success(); 2513 } 2514 2515 return emitError("expected location instance"); 2516 } 2517 2518 //===----------------------------------------------------------------------===// 2519 // Affine parsing. 2520 //===----------------------------------------------------------------------===// 2521 2522 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 2523 /// the boolean sense. 2524 enum AffineLowPrecOp { 2525 /// Null value. 2526 LNoOp, 2527 Add, 2528 Sub 2529 }; 2530 2531 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2532 /// in the boolean sense. 2533 enum AffineHighPrecOp { 2534 /// Null value. 2535 HNoOp, 2536 Mul, 2537 FloorDiv, 2538 CeilDiv, 2539 Mod 2540 }; 2541 2542 namespace { 2543 /// This is a specialized parser for affine structures (affine maps, affine 2544 /// expressions, and integer sets), maintaining the state transient to their 2545 /// bodies. 2546 class AffineParser : public Parser { 2547 public: 2548 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2549 function_ref<ParseResult(bool)> parseElement = nullptr) 2550 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2551 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2552 2553 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2554 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2555 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2556 ParseResult parseAffineMapOfSSAIds(AffineMap &map, 2557 OpAsmParser::Delimiter delimiter); 2558 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2559 unsigned &numDims); 2560 2561 private: 2562 // Binary affine op parsing. 2563 AffineLowPrecOp consumeIfLowPrecOp(); 2564 AffineHighPrecOp consumeIfHighPrecOp(); 2565 2566 // Identifier lists for polyhedral structures. 
2567 ParseResult parseDimIdList(unsigned &numDims); 2568 ParseResult parseSymbolIdList(unsigned &numSymbols); 2569 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2570 unsigned &numSymbols); 2571 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2572 2573 AffineExpr parseAffineExpr(); 2574 AffineExpr parseParentheticalExpr(); 2575 AffineExpr parseNegateExpression(AffineExpr lhs); 2576 AffineExpr parseIntegerExpr(); 2577 AffineExpr parseBareIdExpr(); 2578 AffineExpr parseSSAIdExpr(bool isSymbol); 2579 AffineExpr parseSymbolSSAIdExpr(); 2580 2581 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2582 AffineExpr rhs, SMLoc opLoc); 2583 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2584 AffineExpr rhs); 2585 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2586 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2587 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2588 SMLoc llhsOpLoc); 2589 AffineExpr parseAffineConstraint(bool *isEq); 2590 2591 private: 2592 bool allowParsingSSAIds; 2593 function_ref<ParseResult(bool)> parseElement; 2594 unsigned numDimOperands; 2595 unsigned numSymbolOperands; 2596 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2597 }; 2598 } // end anonymous namespace 2599 2600 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2601 /// opLoc is the location of the op token to be used to report errors 2602 /// for non-conforming expressions. 2603 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2604 AffineExpr lhs, AffineExpr rhs, 2605 SMLoc opLoc) { 2606 // TODO: make the error location info accurate. 
2607 switch (op) { 2608 case Mul: 2609 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2610 emitError(opLoc, "non-affine expression: at least one of the multiply " 2611 "operands has to be either a constant or symbolic"); 2612 return nullptr; 2613 } 2614 return lhs * rhs; 2615 case FloorDiv: 2616 if (!rhs.isSymbolicOrConstant()) { 2617 emitError(opLoc, "non-affine expression: right operand of floordiv " 2618 "has to be either a constant or symbolic"); 2619 return nullptr; 2620 } 2621 return lhs.floorDiv(rhs); 2622 case CeilDiv: 2623 if (!rhs.isSymbolicOrConstant()) { 2624 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2625 "has to be either a constant or symbolic"); 2626 return nullptr; 2627 } 2628 return lhs.ceilDiv(rhs); 2629 case Mod: 2630 if (!rhs.isSymbolicOrConstant()) { 2631 emitError(opLoc, "non-affine expression: right operand of mod " 2632 "has to be either a constant or symbolic"); 2633 return nullptr; 2634 } 2635 return lhs % rhs; 2636 case HNoOp: 2637 llvm_unreachable("can't create affine expression for null high prec op"); 2638 return nullptr; 2639 } 2640 llvm_unreachable("Unknown AffineHighPrecOp"); 2641 } 2642 2643 /// Create an affine binary low precedence op expression (add, sub). 2644 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2645 AffineExpr lhs, AffineExpr rhs) { 2646 switch (op) { 2647 case AffineLowPrecOp::Add: 2648 return lhs + rhs; 2649 case AffineLowPrecOp::Sub: 2650 return lhs - rhs; 2651 case AffineLowPrecOp::LNoOp: 2652 llvm_unreachable("can't create affine expression for null low prec op"); 2653 return nullptr; 2654 } 2655 llvm_unreachable("Unknown AffineLowPrecOp"); 2656 } 2657 2658 /// Consume this token if it is a lower precedence affine op (there are only 2659 /// two precedence levels). 
2660 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2661 switch (getToken().getKind()) { 2662 case Token::plus: 2663 consumeToken(Token::plus); 2664 return AffineLowPrecOp::Add; 2665 case Token::minus: 2666 consumeToken(Token::minus); 2667 return AffineLowPrecOp::Sub; 2668 default: 2669 return AffineLowPrecOp::LNoOp; 2670 } 2671 } 2672 2673 /// Consume this token if it is a higher precedence affine op (there are only 2674 /// two precedence levels) 2675 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2676 switch (getToken().getKind()) { 2677 case Token::star: 2678 consumeToken(Token::star); 2679 return Mul; 2680 case Token::kw_floordiv: 2681 consumeToken(Token::kw_floordiv); 2682 return FloorDiv; 2683 case Token::kw_ceildiv: 2684 consumeToken(Token::kw_ceildiv); 2685 return CeilDiv; 2686 case Token::kw_mod: 2687 consumeToken(Token::kw_mod); 2688 return Mod; 2689 default: 2690 return HNoOp; 2691 } 2692 } 2693 2694 /// Parse a high precedence op expression list: mul, div, and mod are high 2695 /// precedence binary ops, i.e., parse a 2696 /// expr_1 op_1 expr_2 op_2 ... expr_n 2697 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2698 /// All affine binary ops are left associative. 2699 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2700 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2701 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2702 /// report an error for non-conforming expressions. 2703 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2704 AffineHighPrecOp llhsOp, 2705 SMLoc llhsOpLoc) { 2706 AffineExpr lhs = parseAffineOperandExpr(llhs); 2707 if (!lhs) 2708 return nullptr; 2709 2710 // Found an LHS. Parse the remaining expression. 
2711 auto opLoc = getToken().getLoc(); 2712 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2713 if (llhs) { 2714 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2715 if (!expr) 2716 return nullptr; 2717 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2718 } 2719 // No LLHS, get RHS 2720 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2721 } 2722 2723 // This is the last operand in this expression. 2724 if (llhs) 2725 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2726 2727 // No llhs, 'lhs' itself is the expression. 2728 return lhs; 2729 } 2730 2731 /// Parse an affine expression inside parentheses. 2732 /// 2733 /// affine-expr ::= `(` affine-expr `)` 2734 AffineExpr AffineParser::parseParentheticalExpr() { 2735 if (parseToken(Token::l_paren, "expected '('")) 2736 return nullptr; 2737 if (getToken().is(Token::r_paren)) 2738 return (emitError("no expression inside parentheses"), nullptr); 2739 2740 auto expr = parseAffineExpr(); 2741 if (!expr) 2742 return nullptr; 2743 if (parseToken(Token::r_paren, "expected ')'")) 2744 return nullptr; 2745 2746 return expr; 2747 } 2748 2749 /// Parse the negation expression. 2750 /// 2751 /// affine-expr ::= `-` affine-expr 2752 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2753 if (parseToken(Token::minus, "expected '-'")) 2754 return nullptr; 2755 2756 AffineExpr operand = parseAffineOperandExpr(lhs); 2757 // Since negation has the highest precedence of all ops (including high 2758 // precedence ops) but lower than parentheses, we are only going to use 2759 // parseAffineOperandExpr instead of parseAffineExpr here. 2760 if (!operand) 2761 // Extra error message although parseAffineOperandExpr would have 2762 // complained. Leads to a better diagnostic. 2763 return (emitError("missing operand of negation"), nullptr); 2764 return (-1) * operand; 2765 } 2766 2767 /// Parse a bare id that may appear in an affine expression. 
2768 /// 2769 /// affine-expr ::= bare-id 2770 AffineExpr AffineParser::parseBareIdExpr() { 2771 if (getToken().isNot(Token::bare_identifier)) 2772 return (emitError("expected bare identifier"), nullptr); 2773 2774 StringRef sRef = getTokenSpelling(); 2775 for (auto entry : dimsAndSymbols) { 2776 if (entry.first == sRef) { 2777 consumeToken(Token::bare_identifier); 2778 return entry.second; 2779 } 2780 } 2781 2782 return (emitError("use of undeclared identifier"), nullptr); 2783 } 2784 2785 /// Parse an SSA id which may appear in an affine expression. 2786 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2787 if (!allowParsingSSAIds) 2788 return (emitError("unexpected ssa identifier"), nullptr); 2789 if (getToken().isNot(Token::percent_identifier)) 2790 return (emitError("expected ssa identifier"), nullptr); 2791 auto name = getTokenSpelling(); 2792 // Check if we already parsed this SSA id. 2793 for (auto entry : dimsAndSymbols) { 2794 if (entry.first == name) { 2795 consumeToken(Token::percent_identifier); 2796 return entry.second; 2797 } 2798 } 2799 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2800 if (parseElement(isSymbol)) 2801 return (emitError("failed to parse ssa identifier"), nullptr); 2802 auto idExpr = isSymbol 2803 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2804 : getAffineDimExpr(numDimOperands++, getContext()); 2805 dimsAndSymbols.push_back({name, idExpr}); 2806 return idExpr; 2807 } 2808 2809 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2810 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2811 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2812 return nullptr; 2813 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2814 if (!symbolExpr) 2815 return nullptr; 2816 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2817 return nullptr; 2818 return symbolExpr; 2819 } 2820 2821 /// Parse a positive integral constant appearing in an affine expression. 2822 /// 2823 /// affine-expr ::= integer-literal 2824 AffineExpr AffineParser::parseIntegerExpr() { 2825 auto val = getToken().getUInt64IntegerValue(); 2826 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2827 return (emitError("constant too large for index"), nullptr); 2828 2829 consumeToken(Token::integer); 2830 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2831 } 2832 2833 /// Parses an expression that can be a valid operand of an affine expression. 2834 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2835 /// operator, the rhs of which is being parsed. This is used to determine 2836 /// whether an error should be emitted for a missing right operand. 2837 // Eg: for an expression without parentheses (like i + j + k + l), each 2838 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2839 // operand expression, it's an op expression and will be parsed via 2840 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2841 // -l are valid operands that will be parsed by this function. 
2842 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2843 switch (getToken().getKind()) { 2844 case Token::bare_identifier: 2845 return parseBareIdExpr(); 2846 case Token::kw_symbol: 2847 return parseSymbolSSAIdExpr(); 2848 case Token::percent_identifier: 2849 return parseSSAIdExpr(/*isSymbol=*/false); 2850 case Token::integer: 2851 return parseIntegerExpr(); 2852 case Token::l_paren: 2853 return parseParentheticalExpr(); 2854 case Token::minus: 2855 return parseNegateExpression(lhs); 2856 case Token::kw_ceildiv: 2857 case Token::kw_floordiv: 2858 case Token::kw_mod: 2859 case Token::plus: 2860 case Token::star: 2861 if (lhs) 2862 emitError("missing right operand of binary operator"); 2863 else 2864 emitError("missing left operand of binary operator"); 2865 return nullptr; 2866 default: 2867 if (lhs) 2868 emitError("missing right operand of binary operator"); 2869 else 2870 emitError("expected affine expression"); 2871 return nullptr; 2872 } 2873 } 2874 2875 /// Parse affine expressions that are bare-id's, integer constants, 2876 /// parenthetical affine expressions, and affine op expressions that are a 2877 /// composition of those. 2878 /// 2879 /// All binary op's associate from left to right. 2880 /// 2881 /// {add, sub} have lower precedence than {mul, div, and mod}. 2882 /// 2883 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2884 /// ceildiv, and mod are at the same higher precedence level. Negation has 2885 /// higher precedence than any binary op. 2886 /// 2887 /// llhs: the affine expression appearing on the left of the one being parsed. 2888 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2889 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2890 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2891 /// associativity. 
2892 /// 2893 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2894 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2895 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2896 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2897 AffineLowPrecOp llhsOp) { 2898 AffineExpr lhs; 2899 if (!(lhs = parseAffineOperandExpr(llhs))) 2900 return nullptr; 2901 2902 // Found an LHS. Deal with the ops. 2903 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2904 if (llhs) { 2905 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2906 return parseAffineLowPrecOpExpr(sum, lOp); 2907 } 2908 // No LLHS, get RHS and form the expression. 2909 return parseAffineLowPrecOpExpr(lhs, lOp); 2910 } 2911 auto opLoc = getToken().getLoc(); 2912 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2913 // We have a higher precedence op here. Get the rhs operand for the llhs 2914 // through parseAffineHighPrecOpExpr. 2915 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2916 if (!highRes) 2917 return nullptr; 2918 2919 // If llhs is null, the product forms the first operand of the yet to be 2920 // found expression. If non-null, the op to associate with llhs is llhsOp. 2921 AffineExpr expr = 2922 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2923 2924 // Recurse for subsequent low prec op's after the affine high prec op 2925 // expression. 2926 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2927 return parseAffineLowPrecOpExpr(expr, nextOp); 2928 return expr; 2929 } 2930 // Last operand in the expression list. 2931 if (llhs) 2932 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2933 // No llhs, 'lhs' itself is the expression. 2934 return lhs; 2935 } 2936 2937 /// Parse an affine expression. 
2938 /// affine-expr ::= `(` affine-expr `)` 2939 /// | `-` affine-expr 2940 /// | affine-expr `+` affine-expr 2941 /// | affine-expr `-` affine-expr 2942 /// | affine-expr `*` affine-expr 2943 /// | affine-expr `floordiv` affine-expr 2944 /// | affine-expr `ceildiv` affine-expr 2945 /// | affine-expr `mod` affine-expr 2946 /// | bare-id 2947 /// | integer-literal 2948 /// 2949 /// Additional conditions are checked depending on the production. For eg., 2950 /// one of the operands for `*` has to be either constant/symbolic; the second 2951 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2952 AffineExpr AffineParser::parseAffineExpr() { 2953 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2954 } 2955 2956 /// Parse a dim or symbol from the lists appearing before the actual 2957 /// expressions of the affine map. Update our state to store the 2958 /// dimensional/symbolic identifier. 2959 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2960 if (getToken().isNot(Token::bare_identifier)) 2961 return emitError("expected bare identifier"); 2962 2963 auto name = getTokenSpelling(); 2964 for (auto entry : dimsAndSymbols) { 2965 if (entry.first == name) 2966 return emitError("redefinition of identifier '" + name + "'"); 2967 } 2968 consumeToken(Token::bare_identifier); 2969 2970 dimsAndSymbols.push_back({name, idExpr}); 2971 return success(); 2972 } 2973 2974 /// Parse the list of dimensional identifiers to an affine map. 
2975 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2976 if (parseToken(Token::l_paren, 2977 "expected '(' at start of dimensional identifiers list")) { 2978 return failure(); 2979 } 2980 2981 auto parseElt = [&]() -> ParseResult { 2982 auto dimension = getAffineDimExpr(numDims++, getContext()); 2983 return parseIdentifierDefinition(dimension); 2984 }; 2985 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2986 } 2987 2988 /// Parse the list of symbolic identifiers to an affine map. 2989 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2990 consumeToken(Token::l_square); 2991 auto parseElt = [&]() -> ParseResult { 2992 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2993 return parseIdentifierDefinition(symbol); 2994 }; 2995 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2996 } 2997 2998 /// Parse the list of symbolic identifiers to an affine map. 2999 ParseResult 3000 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 3001 unsigned &numSymbols) { 3002 if (parseDimIdList(numDims)) { 3003 return failure(); 3004 } 3005 if (!getToken().is(Token::l_square)) { 3006 numSymbols = 0; 3007 return success(); 3008 } 3009 return parseSymbolIdList(numSymbols); 3010 } 3011 3012 /// Parses an ambiguous affine map or integer set definition inline. 3013 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 3014 IntegerSet &set) { 3015 unsigned numDims = 0, numSymbols = 0; 3016 3017 // List of dimensional and optional symbol identifiers. 3018 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 3019 return failure(); 3020 } 3021 3022 // This is needed for parsing attributes as we wouldn't know whether we would 3023 // be parsing an integer set attribute or an affine map attribute. 
3024 bool isArrow = getToken().is(Token::arrow); 3025 bool isColon = getToken().is(Token::colon); 3026 if (!isArrow && !isColon) { 3027 return emitError("expected '->' or ':'"); 3028 } else if (isArrow) { 3029 parseToken(Token::arrow, "expected '->' or '['"); 3030 map = parseAffineMapRange(numDims, numSymbols); 3031 return map ? success() : failure(); 3032 } else if (parseToken(Token::colon, "expected ':' or '['")) { 3033 return failure(); 3034 } 3035 3036 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 3037 return success(); 3038 3039 return failure(); 3040 } 3041 3042 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 3043 ParseResult 3044 AffineParser::parseAffineMapOfSSAIds(AffineMap &map, 3045 OpAsmParser::Delimiter delimiter) { 3046 Token::Kind rightToken; 3047 switch (delimiter) { 3048 case OpAsmParser::Delimiter::Square: 3049 if (parseToken(Token::l_square, "expected '['")) 3050 return failure(); 3051 rightToken = Token::r_square; 3052 break; 3053 case OpAsmParser::Delimiter::Paren: 3054 if (parseToken(Token::l_paren, "expected '('")) 3055 return failure(); 3056 rightToken = Token::r_paren; 3057 break; 3058 default: 3059 return emitError("unexpected delimiter"); 3060 } 3061 3062 SmallVector<AffineExpr, 4> exprs; 3063 auto parseElt = [&]() -> ParseResult { 3064 auto elt = parseAffineExpr(); 3065 exprs.push_back(elt); 3066 return elt ? success() : failure(); 3067 }; 3068 3069 // Parse a multi-dimensional affine expression (a comma-separated list of 3070 // 1-d affine expressions); the list cannot be empty. Grammar: 3071 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 3072 if (parseCommaSeparatedListUntil(rightToken, parseElt, 3073 /*allowEmptyList=*/true)) 3074 return failure(); 3075 // Parsed a valid affine map. 
3076 if (exprs.empty()) 3077 map = AffineMap::get(getContext()); 3078 else 3079 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3080 exprs); 3081 return success(); 3082 } 3083 3084 /// Parse the range and sizes affine map definition inline. 3085 /// 3086 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 3087 /// 3088 /// multi-dim-affine-expr ::= `(` `)` 3089 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 3090 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 3091 unsigned numSymbols) { 3092 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 3093 3094 SmallVector<AffineExpr, 4> exprs; 3095 auto parseElt = [&]() -> ParseResult { 3096 auto elt = parseAffineExpr(); 3097 ParseResult res = elt ? success() : failure(); 3098 exprs.push_back(elt); 3099 return res; 3100 }; 3101 3102 // Parse a multi-dimensional affine expression (a comma-separated list of 3103 // 1-d affine expressions); the list cannot be empty. Grammar: 3104 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 3105 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3106 return AffineMap(); 3107 3108 if (exprs.empty()) 3109 return AffineMap::get(getContext()); 3110 3111 // Parsed a valid affine map. 3112 return AffineMap::get(numDims, numSymbols, exprs); 3113 } 3114 3115 /// Parse an affine constraint. 3116 /// affine-constraint ::= affine-expr `>=` `0` 3117 /// | affine-expr `==` `0` 3118 /// 3119 /// isEq is set to true if the parsed constraint is an equality, false if it 3120 /// is an inequality (greater than or equal). 
3121 /// 3122 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 3123 AffineExpr expr = parseAffineExpr(); 3124 if (!expr) 3125 return nullptr; 3126 3127 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 3128 getToken().is(Token::integer)) { 3129 auto dim = getToken().getUnsignedIntegerValue(); 3130 if (dim.hasValue() && dim.getValue() == 0) { 3131 consumeToken(Token::integer); 3132 *isEq = false; 3133 return expr; 3134 } 3135 return (emitError("expected '0' after '>='"), nullptr); 3136 } 3137 3138 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3139 getToken().is(Token::integer)) { 3140 auto dim = getToken().getUnsignedIntegerValue(); 3141 if (dim.hasValue() && dim.getValue() == 0) { 3142 consumeToken(Token::integer); 3143 *isEq = true; 3144 return expr; 3145 } 3146 return (emitError("expected '0' after '=='"), nullptr); 3147 } 3148 3149 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3150 nullptr); 3151 } 3152 3153 /// Parse the constraints that are part of an integer set definition. 3154 /// integer-set-inline 3155 /// ::= dim-and-symbol-id-lists `:` 3156 /// '(' affine-constraint-conjunction? ')' 3157 /// affine-constraint-conjunction ::= affine-constraint (`,` 3158 /// affine-constraint)* 3159 /// 3160 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3161 unsigned numSymbols) { 3162 if (parseToken(Token::l_paren, 3163 "expected '(' at start of integer set constraint list")) 3164 return IntegerSet(); 3165 3166 SmallVector<AffineExpr, 4> constraints; 3167 SmallVector<bool, 4> isEqs; 3168 auto parseElt = [&]() -> ParseResult { 3169 bool isEq; 3170 auto elt = parseAffineConstraint(&isEq); 3171 ParseResult res = elt ? success() : failure(); 3172 if (elt) { 3173 constraints.push_back(elt); 3174 isEqs.push_back(isEq); 3175 } 3176 return res; 3177 }; 3178 3179 // Parse a list of affine constraints (comma-separated). 
3180 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3181 return IntegerSet(); 3182 3183 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3184 if (constraints.empty()) { 3185 /* 0 == 0 */ 3186 auto zero = getAffineConstantExpr(0, getContext()); 3187 return IntegerSet::get(numDims, numSymbols, zero, true); 3188 } 3189 3190 // Parsed a valid integer set. 3191 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3192 } 3193 3194 /// Parse an ambiguous reference to either and affine map or an integer set. 3195 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3196 IntegerSet &set) { 3197 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3198 } 3199 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3200 llvm::SMLoc curLoc = getToken().getLoc(); 3201 IntegerSet set; 3202 if (parseAffineMapOrIntegerSetReference(map, set)) 3203 return failure(); 3204 if (set) 3205 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3206 return success(); 3207 } 3208 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3209 llvm::SMLoc curLoc = getToken().getLoc(); 3210 AffineMap map; 3211 if (parseAffineMapOrIntegerSetReference(map, set)) 3212 return failure(); 3213 if (map) 3214 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3215 return success(); 3216 } 3217 3218 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3219 /// parse SSA value uses encountered while parsing affine expressions. 
3220 ParseResult 3221 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3222 function_ref<ParseResult(bool)> parseElement, 3223 OpAsmParser::Delimiter delimiter) { 3224 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3225 .parseAffineMapOfSSAIds(map, delimiter); 3226 } 3227 3228 //===----------------------------------------------------------------------===// 3229 // OperationParser 3230 //===----------------------------------------------------------------------===// 3231 3232 namespace { 3233 /// This class provides support for parsing operations and regions of 3234 /// operations. 3235 class OperationParser : public Parser { 3236 public: 3237 OperationParser(ParserState &state, ModuleOp moduleOp) 3238 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3239 } 3240 3241 ~OperationParser(); 3242 3243 /// After parsing is finished, this function must be called to see if there 3244 /// are any remaining issues. 3245 ParseResult finalize(); 3246 3247 //===--------------------------------------------------------------------===// 3248 // SSA Value Handling 3249 //===--------------------------------------------------------------------===// 3250 3251 /// This represents a use of an SSA value in the program. The first two 3252 /// entries in the tuple are the name and result number of a reference. The 3253 /// third is the location of the reference, which is used in case this ends 3254 /// up being a use of an undefined value. 3255 struct SSAUseInfo { 3256 StringRef name; // Value name, e.g. %42 or %abc 3257 unsigned number; // Number, specified with #12 3258 SMLoc loc; // Location of first definition or use. 3259 }; 3260 3261 /// Push a new SSA name scope to the parser. 3262 void pushSSANameScope(bool isIsolated); 3263 3264 /// Pop the last SSA name scope from the parser. 3265 ParseResult popSSANameScope(); 3266 3267 /// Register a definition of a value with the symbol table. 
3268 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3269 3270 /// Parse an optional list of SSA uses into 'results'. 3271 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3272 3273 /// Parse a single SSA use into 'result'. 3274 ParseResult parseSSAUse(SSAUseInfo &result); 3275 3276 /// Given a reference to an SSA value and its type, return a reference. This 3277 /// returns null on failure. 3278 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3279 3280 ParseResult parseSSADefOrUseAndType( 3281 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3282 3283 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3284 3285 /// Return the location of the value identified by its name and number if it 3286 /// has been already reference. 3287 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3288 auto &values = isolatedNameScopes.back().values; 3289 if (!values.count(name) || number >= values[name].size()) 3290 return {}; 3291 if (values[name][number].first) 3292 return values[name][number].second; 3293 return {}; 3294 } 3295 3296 //===--------------------------------------------------------------------===// 3297 // Operation Parsing 3298 //===--------------------------------------------------------------------===// 3299 3300 /// Parse an operation instance. 3301 ParseResult parseOperation(); 3302 3303 /// Parse a single operation successor and its operand list. 3304 ParseResult parseSuccessorAndUseList(Block *&dest, 3305 SmallVectorImpl<Value> &operands); 3306 3307 /// Parse a comma-separated list of operation successors in brackets. 3308 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations, 3309 SmallVectorImpl<SmallVector<Value, 4>> &operands); 3310 3311 /// Parse an operation instance that is in the generic form. 
3312 Operation *parseGenericOperation(); 3313 3314 /// Parse an operation instance that is in the generic form and insert it at 3315 /// the provided insertion point. 3316 Operation *parseGenericOperation(Block *insertBlock, 3317 Block::iterator insertPt); 3318 3319 /// Parse an operation instance that is in the op-defined custom form. 3320 Operation *parseCustomOperation(); 3321 3322 //===--------------------------------------------------------------------===// 3323 // Region Parsing 3324 //===--------------------------------------------------------------------===// 3325 3326 /// Parse a region into 'region' with the provided entry block arguments. 3327 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3328 /// isolated from those above. 3329 ParseResult parseRegion(Region ®ion, 3330 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3331 bool isIsolatedNameScope = false); 3332 3333 /// Parse a region body into 'region'. 3334 ParseResult parseRegionBody(Region ®ion); 3335 3336 //===--------------------------------------------------------------------===// 3337 // Block Parsing 3338 //===--------------------------------------------------------------------===// 3339 3340 /// Parse a new block into 'block'. 3341 ParseResult parseBlock(Block *&block); 3342 3343 /// Parse a list of operations into 'block'. 3344 ParseResult parseBlockBody(Block *block); 3345 3346 /// Parse a (possibly empty) list of block arguments. 3347 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3348 Block *owner); 3349 3350 /// Get the block with the specified name, creating it if it doesn't 3351 /// already exist. The location specified is the point of use, which allows 3352 /// us to diagnose references to blocks that are not defined precisely. 3353 Block *getBlockNamed(StringRef name, SMLoc loc); 3354 3355 /// Define the block with the specified name. Returns the Block* or nullptr in 3356 /// the case of redefinition. 
3357 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3358 3359 private: 3360 /// Returns the info for a block at the current scope for the given name. 3361 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3362 return blocksByName.back()[name]; 3363 } 3364 3365 /// Insert a new forward reference to the given block. 3366 void insertForwardRef(Block *block, SMLoc loc) { 3367 forwardRef.back().try_emplace(block, loc); 3368 } 3369 3370 /// Erase any forward reference to the given block. 3371 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3372 3373 /// Record that a definition was added at the current scope. 3374 void recordDefinition(StringRef def); 3375 3376 /// Get the value entry for the given SSA name. 3377 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3378 3379 /// Create a forward reference placeholder value with the given location and 3380 /// result type. 3381 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3382 3383 /// Return true if this is a forward reference. 3384 bool isForwardRefPlaceholder(Value value) { 3385 return forwardRefPlaceholders.count(value); 3386 } 3387 3388 /// This struct represents an isolated SSA name scope. This scope may contain 3389 /// other nested non-isolated scopes. These scopes are used for operations 3390 /// that are known to be isolated to allow for reusing names within their 3391 /// regions, even if those names are used above. 3392 struct IsolatedSSANameScope { 3393 /// Record that a definition was added at the current scope. 3394 void recordDefinition(StringRef def) { 3395 definitionsPerScope.back().insert(def); 3396 } 3397 3398 /// Push a nested name scope. 3399 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3400 3401 /// Pop a nested name scope. 
3402 void popSSANameScope() { 3403 for (auto &def : definitionsPerScope.pop_back_val()) 3404 values.erase(def.getKey()); 3405 } 3406 3407 /// This keeps track of all of the SSA values we are tracking for each name 3408 /// scope, indexed by their name. This has one entry per result number. 3409 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3410 3411 /// This keeps track of all of the values defined by a specific name scope. 3412 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3413 }; 3414 3415 /// A list of isolated name scopes. 3416 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3417 3418 /// This keeps track of the block names as well as the location of the first 3419 /// reference for each nested name scope. This is used to diagnose invalid 3420 /// block references and memorize them. 3421 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3422 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3423 3424 /// These are all of the placeholders we've made along with the location of 3425 /// their first reference, to allow checking for use of undefined values. 3426 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3427 3428 /// The builder used when creating parsed operation instances. 3429 OpBuilder opBuilder; 3430 3431 /// The top level module operation. 3432 ModuleOp moduleOp; 3433 }; 3434 } // end anonymous namespace 3435 3436 OperationParser::~OperationParser() { 3437 for (auto &fwd : forwardRefPlaceholders) { 3438 // Drop all uses of undefined forward declared reference and destroy 3439 // defining operation. 3440 fwd.first.dropAllUses(); 3441 fwd.first.getDefiningOp()->destroy(); 3442 } 3443 } 3444 3445 /// After parsing is finished, this function must be called to see if there are 3446 /// any remaining issues. 3447 ParseResult OperationParser::finalize() { 3448 // Check for any forward references that are left. If we find any, error 3449 // out. 
3450 if (!forwardRefPlaceholders.empty()) { 3451 SmallVector<std::pair<const char *, Value>, 4> errors; 3452 // Iteration over the map isn't deterministic, so sort by source location. 3453 for (auto entry : forwardRefPlaceholders) 3454 errors.push_back({entry.second.getPointer(), entry.first}); 3455 llvm::array_pod_sort(errors.begin(), errors.end()); 3456 3457 for (auto entry : errors) { 3458 auto loc = SMLoc::getFromPointer(entry.first); 3459 emitError(loc, "use of undeclared SSA value name"); 3460 } 3461 return failure(); 3462 } 3463 3464 return success(); 3465 } 3466 3467 //===----------------------------------------------------------------------===// 3468 // SSA Value Handling 3469 //===----------------------------------------------------------------------===// 3470 3471 void OperationParser::pushSSANameScope(bool isIsolated) { 3472 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3473 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3474 3475 // Push back a new name definition scope. 3476 if (isIsolated) 3477 isolatedNameScopes.push_back({}); 3478 isolatedNameScopes.back().pushSSANameScope(); 3479 } 3480 3481 ParseResult OperationParser::popSSANameScope() { 3482 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3483 3484 // Verify that all referenced blocks were defined. 3485 if (!forwardRefInCurrentScope.empty()) { 3486 SmallVector<std::pair<const char *, Block *>, 4> errors; 3487 // Iteration over the map isn't deterministic, so sort by source location. 3488 for (auto entry : forwardRefInCurrentScope) { 3489 errors.push_back({entry.second.getPointer(), entry.first}); 3490 // Add this block to the top-level region to allow for automatic cleanup. 
3491 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3492 } 3493 llvm::array_pod_sort(errors.begin(), errors.end()); 3494 3495 for (auto entry : errors) { 3496 auto loc = SMLoc::getFromPointer(entry.first); 3497 emitError(loc, "reference to an undefined block"); 3498 } 3499 return failure(); 3500 } 3501 3502 // Pop the next nested namescope. If there is only one internal namescope, 3503 // just pop the isolated scope. 3504 auto ¤tNameScope = isolatedNameScopes.back(); 3505 if (currentNameScope.definitionsPerScope.size() == 1) 3506 isolatedNameScopes.pop_back(); 3507 else 3508 currentNameScope.popSSANameScope(); 3509 3510 blocksByName.pop_back(); 3511 return success(); 3512 } 3513 3514 /// Register a definition of a value with the symbol table. 3515 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3516 auto &entries = getSSAValueEntry(useInfo.name); 3517 3518 // Make sure there is a slot for this value. 3519 if (entries.size() <= useInfo.number) 3520 entries.resize(useInfo.number + 1); 3521 3522 // If we already have an entry for this, check to see if it was a definition 3523 // or a forward reference. 3524 if (auto existing = entries[useInfo.number].first) { 3525 if (!isForwardRefPlaceholder(existing)) { 3526 return emitError(useInfo.loc) 3527 .append("redefinition of SSA value '", useInfo.name, "'") 3528 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3529 .append("previously defined here"); 3530 } 3531 3532 // If it was a forward reference, update everything that used it to use 3533 // the actual definition instead, delete the forward ref, and remove it 3534 // from our set of forward references we track. 3535 existing.replaceAllUsesWith(value); 3536 existing.getDefiningOp()->destroy(); 3537 forwardRefPlaceholders.erase(existing); 3538 } 3539 3540 /// Record this definition for the current scope. 
3541 entries[useInfo.number] = {value, useInfo.loc}; 3542 recordDefinition(useInfo.name); 3543 return success(); 3544 } 3545 3546 /// Parse a (possibly empty) list of SSA operands. 3547 /// 3548 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3549 /// ssa-use-list-opt ::= ssa-use-list? 3550 /// 3551 ParseResult 3552 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3553 if (getToken().isNot(Token::percent_identifier)) 3554 return success(); 3555 return parseCommaSeparatedList([&]() -> ParseResult { 3556 SSAUseInfo result; 3557 if (parseSSAUse(result)) 3558 return failure(); 3559 results.push_back(result); 3560 return success(); 3561 }); 3562 } 3563 3564 /// Parse a SSA operand for an operation. 3565 /// 3566 /// ssa-use ::= ssa-id 3567 /// 3568 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3569 result.name = getTokenSpelling(); 3570 result.number = 0; 3571 result.loc = getToken().getLoc(); 3572 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3573 return failure(); 3574 3575 // If we have an attribute ID, it is a result number. 3576 if (getToken().is(Token::hash_identifier)) { 3577 if (auto value = getToken().getHashIdentifierNumber()) 3578 result.number = value.getValue(); 3579 else 3580 return emitError("invalid SSA value result number"); 3581 consumeToken(Token::hash_identifier); 3582 } 3583 3584 return success(); 3585 } 3586 3587 /// Given an unbound reference to an SSA value and its type, return the value 3588 /// it specifies. This returns null on failure. 3589 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3590 auto &entries = getSSAValueEntry(useInfo.name); 3591 3592 // If we have already seen a value of this name, return it. 3593 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3594 auto result = entries[useInfo.number].first; 3595 // Check that the type matches the other uses. 
3596 if (result.getType() == type) 3597 return result; 3598 3599 emitError(useInfo.loc, "use of value '") 3600 .append(useInfo.name, 3601 "' expects different type than prior uses: ", type, " vs ", 3602 result.getType()) 3603 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3604 .append("prior use here"); 3605 return nullptr; 3606 } 3607 3608 // Make sure we have enough slots for this. 3609 if (entries.size() <= useInfo.number) 3610 entries.resize(useInfo.number + 1); 3611 3612 // If the value has already been defined and this is an overly large result 3613 // number, diagnose that. 3614 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3615 return (emitError(useInfo.loc, "reference to invalid result number"), 3616 nullptr); 3617 3618 // Otherwise, this is a forward reference. Create a placeholder and remember 3619 // that we did so. 3620 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3621 entries[useInfo.number].first = result; 3622 entries[useInfo.number].second = useInfo.loc; 3623 return result; 3624 } 3625 3626 /// Parse an SSA use with an associated type. 3627 /// 3628 /// ssa-use-and-type ::= ssa-use `:` type 3629 ParseResult OperationParser::parseSSADefOrUseAndType( 3630 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3631 SSAUseInfo useInfo; 3632 if (parseSSAUse(useInfo) || 3633 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3634 return failure(); 3635 3636 auto type = parseType(); 3637 if (!type) 3638 return failure(); 3639 3640 return action(useInfo, type); 3641 } 3642 3643 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3644 /// followed by a type list. 
3645 /// 3646 /// ssa-use-and-type-list 3647 /// ::= ssa-use-list ':' type-list-no-parens 3648 /// 3649 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3650 SmallVectorImpl<Value> &results) { 3651 SmallVector<SSAUseInfo, 4> valueIDs; 3652 if (parseOptionalSSAUseList(valueIDs)) 3653 return failure(); 3654 3655 // If there were no operands, then there is no colon or type lists. 3656 if (valueIDs.empty()) 3657 return success(); 3658 3659 SmallVector<Type, 4> types; 3660 if (parseToken(Token::colon, "expected ':' in operand list") || 3661 parseTypeListNoParens(types)) 3662 return failure(); 3663 3664 if (valueIDs.size() != types.size()) 3665 return emitError("expected ") 3666 << valueIDs.size() << " types to match operand list"; 3667 3668 results.reserve(valueIDs.size()); 3669 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3670 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3671 results.push_back(value); 3672 else 3673 return failure(); 3674 } 3675 3676 return success(); 3677 } 3678 3679 /// Record that a definition was added at the current scope. 3680 void OperationParser::recordDefinition(StringRef def) { 3681 isolatedNameScopes.back().recordDefinition(def); 3682 } 3683 3684 /// Get the value entry for the given SSA name. 3685 SmallVectorImpl<std::pair<Value, SMLoc>> & 3686 OperationParser::getSSAValueEntry(StringRef name) { 3687 return isolatedNameScopes.back().values[name]; 3688 } 3689 3690 /// Create and remember a new placeholder for a forward reference. 3691 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3692 // Forward references are always created as operations, because we just need 3693 // something with a def/use chain. 3694 // 3695 // We create these placeholders as having an empty name, which we know 3696 // cannot be created through normal user input, allowing us to distinguish 3697 // them. 
3698 auto name = OperationName("placeholder", getContext()); 3699 auto *op = Operation::create( 3700 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3701 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3702 /*resizableOperandList=*/false); 3703 forwardRefPlaceholders[op->getResult(0)] = loc; 3704 return op->getResult(0); 3705 } 3706 3707 //===----------------------------------------------------------------------===// 3708 // Operation Parsing 3709 //===----------------------------------------------------------------------===// 3710 3711 /// Parse an operation. 3712 /// 3713 /// operation ::= op-result-list? 3714 /// (generic-operation | custom-operation) 3715 /// trailing-location? 3716 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3717 /// successor-list? (`(` region-list `)`)? 3718 /// attribute-dict? `:` function-type 3719 /// custom-operation ::= bare-id custom-operation-format 3720 /// op-result-list ::= op-result (`,` op-result)* `=` 3721 /// op-result ::= ssa-id (`:` integer-literal) 3722 /// 3723 ParseResult OperationParser::parseOperation() { 3724 auto loc = getToken().getLoc(); 3725 SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; 3726 size_t numExpectedResults = 0; 3727 if (getToken().is(Token::percent_identifier)) { 3728 // Parse the group of result ids. 3729 auto parseNextResult = [&]() -> ParseResult { 3730 // Parse the next result id. 3731 if (!getToken().is(Token::percent_identifier)) 3732 return emitError("expected valid ssa identifier"); 3733 3734 Token nameTok = getToken(); 3735 consumeToken(Token::percent_identifier); 3736 3737 // If the next token is a ':', we parse the expected result count. 3738 size_t expectedSubResults = 1; 3739 if (consumeIf(Token::colon)) { 3740 // Check that the next token is an integer. 3741 if (!getToken().is(Token::integer)) 3742 return emitError("expected integer number of results"); 3743 3744 // Check that number of results is > 0. 
3745 auto val = getToken().getUInt64IntegerValue(); 3746 if (!val.hasValue() || val.getValue() < 1) 3747 return emitError("expected named operation to have atleast 1 result"); 3748 consumeToken(Token::integer); 3749 expectedSubResults = *val; 3750 } 3751 3752 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3753 nameTok.getLoc()); 3754 numExpectedResults += expectedSubResults; 3755 return success(); 3756 }; 3757 if (parseCommaSeparatedList(parseNextResult)) 3758 return failure(); 3759 3760 if (parseToken(Token::equal, "expected '=' after SSA name")) 3761 return failure(); 3762 } 3763 3764 Operation *op; 3765 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3766 op = parseCustomOperation(); 3767 else if (getToken().is(Token::string)) 3768 op = parseGenericOperation(); 3769 else 3770 return emitError("expected operation name in quotes"); 3771 3772 // If parsing of the basic operation failed, then this whole thing fails. 3773 if (!op) 3774 return failure(); 3775 3776 // If the operation had a name, register it. 3777 if (!resultIDs.empty()) { 3778 if (op->getNumResults() == 0) 3779 return emitError(loc, "cannot name an operation with no results"); 3780 if (numExpectedResults != op->getNumResults()) 3781 return emitError(loc, "operation defines ") 3782 << op->getNumResults() << " results but was provided " 3783 << numExpectedResults << " to bind"; 3784 3785 // Add definitions for each of the result groups. 3786 unsigned opResI = 0; 3787 for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { 3788 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3789 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3790 op->getResult(opResI++))) 3791 return failure(); 3792 } 3793 } 3794 } 3795 3796 return success(); 3797 } 3798 3799 /// Parse a single operation successor and its operand list. 3800 /// 3801 /// successor ::= block-id branch-use-list? 
3802 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3803 /// 3804 ParseResult 3805 OperationParser::parseSuccessorAndUseList(Block *&dest, 3806 SmallVectorImpl<Value> &operands) { 3807 // Verify branch is identifier and get the matching block. 3808 if (!getToken().is(Token::caret_identifier)) 3809 return emitError("expected block name"); 3810 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3811 consumeToken(); 3812 3813 // Handle optional arguments. 3814 if (consumeIf(Token::l_paren) && 3815 (parseOptionalSSAUseAndTypeList(operands) || 3816 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3817 return failure(); 3818 } 3819 3820 return success(); 3821 } 3822 3823 /// Parse a comma-separated list of operation successors in brackets. 3824 /// 3825 /// successor-list ::= `[` successor (`,` successor )* `]` 3826 /// 3827 ParseResult OperationParser::parseSuccessors( 3828 SmallVectorImpl<Block *> &destinations, 3829 SmallVectorImpl<SmallVector<Value, 4>> &operands) { 3830 if (parseToken(Token::l_square, "expected '['")) 3831 return failure(); 3832 3833 auto parseElt = [this, &destinations, &operands]() { 3834 Block *dest; 3835 SmallVector<Value, 4> destOperands; 3836 auto res = parseSuccessorAndUseList(dest, destOperands); 3837 destinations.push_back(dest); 3838 operands.push_back(destOperands); 3839 return res; 3840 }; 3841 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3842 /*allowEmptyList=*/false); 3843 } 3844 3845 namespace { 3846 // RAII-style guard for cleaning up the regions in the operation state before 3847 // deleting them. Within the parser, regions may get deleted if parsing failed, 3848 // and other errors may be present, in particular undominated uses. This makes 3849 // sure such uses are deleted. 
3850 struct CleanupOpStateRegions { 3851 ~CleanupOpStateRegions() { 3852 SmallVector<Region *, 4> regionsToClean; 3853 regionsToClean.reserve(state.regions.size()); 3854 for (auto ®ion : state.regions) 3855 if (region) 3856 for (auto &block : *region) 3857 block.dropAllDefinedValueUses(); 3858 } 3859 OperationState &state; 3860 }; 3861 } // namespace 3862 3863 Operation *OperationParser::parseGenericOperation() { 3864 // Get location information for the operation. 3865 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3866 3867 auto name = getToken().getStringValue(); 3868 if (name.empty()) 3869 return (emitError("empty operation name is invalid"), nullptr); 3870 if (name.find('\0') != StringRef::npos) 3871 return (emitError("null character not allowed in operation name"), nullptr); 3872 3873 consumeToken(Token::string); 3874 3875 OperationState result(srcLocation, name); 3876 3877 // Generic operations have a resizable operation list. 3878 result.setOperandListToResizable(); 3879 3880 // Parse the operand list. 3881 SmallVector<SSAUseInfo, 8> operandInfos; 3882 3883 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3884 parseOptionalSSAUseList(operandInfos) || 3885 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3886 return nullptr; 3887 } 3888 3889 // Parse the successor list but don't add successors to the result yet to 3890 // avoid messing up with the argument order. 3891 SmallVector<Block *, 2> successors; 3892 SmallVector<SmallVector<Value, 4>, 2> successorOperands; 3893 if (getToken().is(Token::l_square)) { 3894 // Check if the operation is a known terminator. 3895 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3896 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3897 return emitError("successors in non-terminator"), nullptr; 3898 if (parseSuccessors(successors, successorOperands)) 3899 return nullptr; 3900 } 3901 3902 // Parse the region list. 
3903 CleanupOpStateRegions guard{result}; 3904 if (consumeIf(Token::l_paren)) { 3905 do { 3906 // Create temporary regions with the top level region as parent. 3907 result.regions.emplace_back(new Region(moduleOp)); 3908 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3909 return nullptr; 3910 } while (consumeIf(Token::comma)); 3911 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3912 return nullptr; 3913 } 3914 3915 if (getToken().is(Token::l_brace)) { 3916 if (parseAttributeDict(result.attributes)) 3917 return nullptr; 3918 } 3919 3920 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3921 return nullptr; 3922 3923 auto typeLoc = getToken().getLoc(); 3924 auto type = parseType(); 3925 if (!type) 3926 return nullptr; 3927 auto fnType = type.dyn_cast<FunctionType>(); 3928 if (!fnType) 3929 return (emitError(typeLoc, "expected function type"), nullptr); 3930 3931 result.addTypes(fnType.getResults()); 3932 3933 // Check that we have the right number of types for the operands. 3934 auto operandTypes = fnType.getInputs(); 3935 if (operandTypes.size() != operandInfos.size()) { 3936 auto plural = "s"[operandInfos.size() == 1]; 3937 return (emitError(typeLoc, "expected ") 3938 << operandInfos.size() << " operand type" << plural 3939 << " but had " << operandTypes.size(), 3940 nullptr); 3941 } 3942 3943 // Resolve all of the operands. 3944 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3945 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3946 if (!result.operands.back()) 3947 return nullptr; 3948 } 3949 3950 // Add the successors, and their operands after the proper operands. 3951 for (auto succ : llvm::zip(successors, successorOperands)) { 3952 Block *successor = std::get<0>(succ); 3953 const SmallVector<Value, 4> &operands = std::get<1>(succ); 3954 result.addSuccessor(successor, operands); 3955 } 3956 3957 // Parse a location if one is present. 
3958 if (parseOptionalTrailingLocation(result.location)) 3959 return nullptr; 3960 3961 return opBuilder.createOperation(result); 3962 } 3963 3964 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3965 Block::iterator insertPt) { 3966 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3967 opBuilder.setInsertionPoint(insertBlock, insertPt); 3968 return parseGenericOperation(); 3969 } 3970 3971 namespace { 3972 class CustomOpAsmParser : public OpAsmParser { 3973 public: 3974 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3975 OperationParser &parser) 3976 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3977 3978 /// Parse an instance of the operation described by 'opDefinition' into the 3979 /// provided operation state. 3980 ParseResult parseOperation(OperationState &opState) { 3981 if (opDefinition->parseAssembly(*this, opState)) 3982 return failure(); 3983 return success(); 3984 } 3985 3986 Operation *parseGenericOperation(Block *insertBlock, 3987 Block::iterator insertPt) final { 3988 return parser.parseGenericOperation(insertBlock, insertPt); 3989 } 3990 3991 //===--------------------------------------------------------------------===// 3992 // Utilities 3993 //===--------------------------------------------------------------------===// 3994 3995 /// Return if any errors were emitted during parsing. 3996 bool didEmitError() const { return emittedError; } 3997 3998 /// Emit a diagnostic at the specified location and return failure. 
3999 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 4000 emittedError = true; 4001 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 4002 message); 4003 } 4004 4005 llvm::SMLoc getCurrentLocation() override { 4006 return parser.getToken().getLoc(); 4007 } 4008 4009 Builder &getBuilder() const override { return parser.builder; } 4010 4011 llvm::SMLoc getNameLoc() const override { return nameLoc; } 4012 4013 //===--------------------------------------------------------------------===// 4014 // Token Parsing 4015 //===--------------------------------------------------------------------===// 4016 4017 /// Parse a `->` token. 4018 ParseResult parseArrow() override { 4019 return parser.parseToken(Token::arrow, "expected '->'"); 4020 } 4021 4022 /// Parses a `->` if present. 4023 ParseResult parseOptionalArrow() override { 4024 return success(parser.consumeIf(Token::arrow)); 4025 } 4026 4027 /// Parse a `:` token. 4028 ParseResult parseColon() override { 4029 return parser.parseToken(Token::colon, "expected ':'"); 4030 } 4031 4032 /// Parse a `:` token if present. 4033 ParseResult parseOptionalColon() override { 4034 return success(parser.consumeIf(Token::colon)); 4035 } 4036 4037 /// Parse a `,` token. 4038 ParseResult parseComma() override { 4039 return parser.parseToken(Token::comma, "expected ','"); 4040 } 4041 4042 /// Parse a `,` token if present. 4043 ParseResult parseOptionalComma() override { 4044 return success(parser.consumeIf(Token::comma)); 4045 } 4046 4047 /// Parses a `...` if present. 4048 ParseResult parseOptionalEllipsis() override { 4049 return success(parser.consumeIf(Token::ellipsis)); 4050 } 4051 4052 /// Parse a `=` token. 4053 ParseResult parseEqual() override { 4054 return parser.parseToken(Token::equal, "expected '='"); 4055 } 4056 4057 /// Parse a '<' token. 
4058 ParseResult parseLess() override { 4059 return parser.parseToken(Token::less, "expected '<'"); 4060 } 4061 4062 /// Parse a '>' token. 4063 ParseResult parseGreater() override { 4064 return parser.parseToken(Token::greater, "expected '>'"); 4065 } 4066 4067 /// Parse a `(` token. 4068 ParseResult parseLParen() override { 4069 return parser.parseToken(Token::l_paren, "expected '('"); 4070 } 4071 4072 /// Parses a '(' if present. 4073 ParseResult parseOptionalLParen() override { 4074 return success(parser.consumeIf(Token::l_paren)); 4075 } 4076 4077 /// Parse a `)` token. 4078 ParseResult parseRParen() override { 4079 return parser.parseToken(Token::r_paren, "expected ')'"); 4080 } 4081 4082 /// Parses a ')' if present. 4083 ParseResult parseOptionalRParen() override { 4084 return success(parser.consumeIf(Token::r_paren)); 4085 } 4086 4087 /// Parse a `[` token. 4088 ParseResult parseLSquare() override { 4089 return parser.parseToken(Token::l_square, "expected '['"); 4090 } 4091 4092 /// Parses a '[' if present. 4093 ParseResult parseOptionalLSquare() override { 4094 return success(parser.consumeIf(Token::l_square)); 4095 } 4096 4097 /// Parse a `]` token. 4098 ParseResult parseRSquare() override { 4099 return parser.parseToken(Token::r_square, "expected ']'"); 4100 } 4101 4102 /// Parses a ']' if present. 4103 ParseResult parseOptionalRSquare() override { 4104 return success(parser.consumeIf(Token::r_square)); 4105 } 4106 4107 //===--------------------------------------------------------------------===// 4108 // Attribute Parsing 4109 //===--------------------------------------------------------------------===// 4110 4111 /// Parse an arbitrary attribute of a given type and return it in result. This 4112 /// also adds the attribute to the specified attribute list with the specified 4113 /// name. 
4114 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 4115 SmallVectorImpl<NamedAttribute> &attrs) override { 4116 result = parser.parseAttribute(type); 4117 if (!result) 4118 return failure(); 4119 4120 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 4121 return success(); 4122 } 4123 4124 /// Parse a named dictionary into 'result' if it is present. 4125 ParseResult 4126 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 4127 if (parser.getToken().isNot(Token::l_brace)) 4128 return success(); 4129 return parser.parseAttributeDict(result); 4130 } 4131 4132 /// Parse a named dictionary into 'result' if the `attributes` keyword is 4133 /// present. 4134 ParseResult parseOptionalAttrDictWithKeyword( 4135 SmallVectorImpl<NamedAttribute> &result) override { 4136 if (failed(parseOptionalKeyword("attributes"))) 4137 return success(); 4138 return parser.parseAttributeDict(result); 4139 } 4140 4141 /// Parse an affine map instance into 'map'. 4142 ParseResult parseAffineMap(AffineMap &map) override { 4143 return parser.parseAffineMapReference(map); 4144 } 4145 4146 /// Parse an integer set instance into 'set'. 4147 ParseResult printIntegerSet(IntegerSet &set) override { 4148 return parser.parseIntegerSetReference(set); 4149 } 4150 4151 //===--------------------------------------------------------------------===// 4152 // Identifier Parsing 4153 //===--------------------------------------------------------------------===// 4154 4155 /// Returns if the current token corresponds to a keyword. 4156 bool isCurrentTokenAKeyword() const { 4157 return parser.getToken().is(Token::bare_identifier) || 4158 parser.getToken().isKeyword(); 4159 } 4160 4161 /// Parse the given keyword if present. 4162 ParseResult parseOptionalKeyword(StringRef keyword) override { 4163 // Check that the current token has the same spelling. 
4164 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4165 return failure(); 4166 parser.consumeToken(); 4167 return success(); 4168 } 4169 4170 /// Parse a keyword, if present, into 'keyword'. 4171 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4172 // Check that the current token is a keyword. 4173 if (!isCurrentTokenAKeyword()) 4174 return failure(); 4175 4176 *keyword = parser.getTokenSpelling(); 4177 parser.consumeToken(); 4178 return success(); 4179 } 4180 4181 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4182 /// string attribute named 'attrName'. 4183 ParseResult 4184 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4185 SmallVectorImpl<NamedAttribute> &attrs) override { 4186 Token atToken = parser.getToken(); 4187 if (atToken.isNot(Token::at_identifier)) 4188 return failure(); 4189 4190 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4191 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4192 parser.consumeToken(); 4193 return success(); 4194 } 4195 4196 //===--------------------------------------------------------------------===// 4197 // Operand Parsing 4198 //===--------------------------------------------------------------------===// 4199 4200 /// Parse a single operand. 4201 ParseResult parseOperand(OperandType &result) override { 4202 OperationParser::SSAUseInfo useInfo; 4203 if (parser.parseSSAUse(useInfo)) 4204 return failure(); 4205 4206 result = {useInfo.loc, useInfo.name, useInfo.number}; 4207 return success(); 4208 } 4209 4210 /// Parse zero or more SSA comma-separated operand references with a specified 4211 /// surrounding delimiter, and an optional required operand count. 
4212 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4213 int requiredOperandCount = -1, 4214 Delimiter delimiter = Delimiter::None) override { 4215 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4216 requiredOperandCount, delimiter); 4217 } 4218 4219 /// Parse zero or more SSA comma-separated operand or region arguments with 4220 /// optional surrounding delimiter and required operand count. 4221 ParseResult 4222 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4223 bool isOperandList, int requiredOperandCount = -1, 4224 Delimiter delimiter = Delimiter::None) { 4225 auto startLoc = parser.getToken().getLoc(); 4226 4227 // Handle delimiters. 4228 switch (delimiter) { 4229 case Delimiter::None: 4230 // Don't check for the absence of a delimiter if the number of operands 4231 // is unknown (and hence the operand list could be empty). 4232 if (requiredOperandCount == -1) 4233 break; 4234 // Token already matches an identifier and so can't be a delimiter. 4235 if (parser.getToken().is(Token::percent_identifier)) 4236 break; 4237 // Test against known delimiters. 4238 if (parser.getToken().is(Token::l_paren) || 4239 parser.getToken().is(Token::l_square)) 4240 return emitError(startLoc, "unexpected delimiter"); 4241 return emitError(startLoc, "invalid operand"); 4242 case Delimiter::OptionalParen: 4243 if (parser.getToken().isNot(Token::l_paren)) 4244 return success(); 4245 LLVM_FALLTHROUGH; 4246 case Delimiter::Paren: 4247 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4248 return failure(); 4249 break; 4250 case Delimiter::OptionalSquare: 4251 if (parser.getToken().isNot(Token::l_square)) 4252 return success(); 4253 LLVM_FALLTHROUGH; 4254 case Delimiter::Square: 4255 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4256 return failure(); 4257 break; 4258 } 4259 4260 // Check for zero operands. 
4261 if (parser.getToken().is(Token::percent_identifier)) { 4262 do { 4263 OperandType operandOrArg; 4264 if (isOperandList ? parseOperand(operandOrArg) 4265 : parseRegionArgument(operandOrArg)) 4266 return failure(); 4267 result.push_back(operandOrArg); 4268 } while (parser.consumeIf(Token::comma)); 4269 } 4270 4271 // Handle delimiters. If we reach here, the optional delimiters were 4272 // present, so we need to parse their closing one. 4273 switch (delimiter) { 4274 case Delimiter::None: 4275 break; 4276 case Delimiter::OptionalParen: 4277 case Delimiter::Paren: 4278 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4279 return failure(); 4280 break; 4281 case Delimiter::OptionalSquare: 4282 case Delimiter::Square: 4283 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4284 return failure(); 4285 break; 4286 } 4287 4288 if (requiredOperandCount != -1 && 4289 result.size() != static_cast<size_t>(requiredOperandCount)) 4290 return emitError(startLoc, "expected ") 4291 << requiredOperandCount << " operands"; 4292 return success(); 4293 } 4294 4295 /// Parse zero or more trailing SSA comma-separated trailing operand 4296 /// references with a specified surrounding delimiter, and an optional 4297 /// required operand count. A leading comma is expected before the operands. 4298 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4299 int requiredOperandCount, 4300 Delimiter delimiter) override { 4301 if (parser.getToken().is(Token::comma)) { 4302 parseComma(); 4303 return parseOperandList(result, requiredOperandCount, delimiter); 4304 } 4305 if (requiredOperandCount != -1) 4306 return emitError(parser.getToken().getLoc(), "expected ") 4307 << requiredOperandCount << " operands"; 4308 return success(); 4309 } 4310 4311 /// Resolve an operand to an SSA value, emitting an error on failure. 
4312 ParseResult resolveOperand(const OperandType &operand, Type type, 4313 SmallVectorImpl<Value> &result) override { 4314 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4315 operand.location}; 4316 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4317 result.push_back(value); 4318 return success(); 4319 } 4320 return failure(); 4321 } 4322 4323 /// Parse an AffineMap of SSA ids. 4324 ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4325 Attribute &mapAttr, StringRef attrName, 4326 SmallVectorImpl<NamedAttribute> &attrs, 4327 Delimiter delimiter) override { 4328 SmallVector<OperandType, 2> dimOperands; 4329 SmallVector<OperandType, 1> symOperands; 4330 4331 auto parseElement = [&](bool isSymbol) -> ParseResult { 4332 OperandType operand; 4333 if (parseOperand(operand)) 4334 return failure(); 4335 if (isSymbol) 4336 symOperands.push_back(operand); 4337 else 4338 dimOperands.push_back(operand); 4339 return success(); 4340 }; 4341 4342 AffineMap map; 4343 if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) 4344 return failure(); 4345 // Add AffineMap attribute. 4346 if (map) { 4347 mapAttr = AffineMapAttr::get(map); 4348 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4349 } 4350 4351 // Add dim operands before symbol operands in 'operands'. 4352 operands.assign(dimOperands.begin(), dimOperands.end()); 4353 operands.append(symOperands.begin(), symOperands.end()); 4354 return success(); 4355 } 4356 4357 //===--------------------------------------------------------------------===// 4358 // Region Parsing 4359 //===--------------------------------------------------------------------===// 4360 4361 /// Parse a region that takes `arguments` of `argTypes` types. This 4362 /// effectively defines the SSA values of `arguments` and assigns their type. 
4363 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4364 ArrayRef<Type> argTypes, 4365 bool enableNameShadowing) override { 4366 assert(arguments.size() == argTypes.size() && 4367 "mismatching number of arguments and types"); 4368 4369 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4370 regionArguments; 4371 for (auto pair : llvm::zip(arguments, argTypes)) { 4372 const OperandType &operand = std::get<0>(pair); 4373 Type type = std::get<1>(pair); 4374 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4375 operand.location}; 4376 regionArguments.emplace_back(operandInfo, type); 4377 } 4378 4379 // Try to parse the region. 4380 assert((!enableNameShadowing || 4381 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4382 "name shadowing is only allowed on isolated regions"); 4383 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4384 return failure(); 4385 return success(); 4386 } 4387 4388 /// Parses a region if present. 4389 ParseResult parseOptionalRegion(Region ®ion, 4390 ArrayRef<OperandType> arguments, 4391 ArrayRef<Type> argTypes, 4392 bool enableNameShadowing) override { 4393 if (parser.getToken().isNot(Token::l_brace)) 4394 return success(); 4395 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4396 } 4397 4398 /// Parse a region argument. The type of the argument will be resolved later 4399 /// by a call to `parseRegion`. 4400 ParseResult parseRegionArgument(OperandType &argument) override { 4401 return parseOperand(argument); 4402 } 4403 4404 /// Parse a region argument if present. 
4405 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4406 if (parser.getToken().isNot(Token::percent_identifier)) 4407 return success(); 4408 return parseRegionArgument(argument); 4409 } 4410 4411 ParseResult 4412 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4413 int requiredOperandCount = -1, 4414 Delimiter delimiter = Delimiter::None) override { 4415 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4416 requiredOperandCount, delimiter); 4417 } 4418 4419 //===--------------------------------------------------------------------===// 4420 // Successor Parsing 4421 //===--------------------------------------------------------------------===// 4422 4423 /// Parse a single operation successor and its operand list. 4424 ParseResult 4425 parseSuccessorAndUseList(Block *&dest, 4426 SmallVectorImpl<Value> &operands) override { 4427 return parser.parseSuccessorAndUseList(dest, operands); 4428 } 4429 4430 /// Parse an optional operation successor and its operand list. 4431 OptionalParseResult 4432 parseOptionalSuccessorAndUseList(Block *&dest, 4433 SmallVectorImpl<Value> &operands) override { 4434 if (parser.getToken().isNot(Token::caret_identifier)) 4435 return llvm::None; 4436 return parseSuccessorAndUseList(dest, operands); 4437 } 4438 4439 //===--------------------------------------------------------------------===// 4440 // Type Parsing 4441 //===--------------------------------------------------------------------===// 4442 4443 /// Parse a type. 4444 ParseResult parseType(Type &result) override { 4445 return failure(!(result = parser.parseType())); 4446 } 4447 4448 /// Parse an arrow followed by a type list. 4449 ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override { 4450 if (parseArrow() || parser.parseFunctionResultTypes(result)) 4451 return failure(); 4452 return success(); 4453 } 4454 4455 /// Parse an optional arrow followed by a type list. 
4456 ParseResult 4457 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4458 if (!parser.consumeIf(Token::arrow)) 4459 return success(); 4460 return parser.parseFunctionResultTypes(result); 4461 } 4462 4463 /// Parse a colon followed by a type. 4464 ParseResult parseColonType(Type &result) override { 4465 return failure(parser.parseToken(Token::colon, "expected ':'") || 4466 !(result = parser.parseType())); 4467 } 4468 4469 /// Parse a colon followed by a type list, which must have at least one type. 4470 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4471 if (parser.parseToken(Token::colon, "expected ':'")) 4472 return failure(); 4473 return parser.parseTypeListNoParens(result); 4474 } 4475 4476 /// Parse an optional colon followed by a type list, which if present must 4477 /// have at least one type. 4478 ParseResult 4479 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4480 if (!parser.consumeIf(Token::colon)) 4481 return success(); 4482 return parser.parseTypeListNoParens(result); 4483 } 4484 4485 /// Parse a list of assignments of the form 4486 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). 4487 /// The list must contain at least one entry 4488 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 4489 SmallVectorImpl<OperandType> &rhs) override { 4490 auto parseElt = [&]() -> ParseResult { 4491 OperandType regionArg, operand; 4492 if (parseRegionArgument(regionArg) || parseEqual() || 4493 parseOperand(operand)) 4494 return failure(); 4495 lhs.push_back(regionArg); 4496 rhs.push_back(operand); 4497 return success(); 4498 }; 4499 if (parseLParen()) 4500 return failure(); 4501 return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); 4502 } 4503 4504 private: 4505 /// The source location of the operation name. 4506 SMLoc nameLoc; 4507 4508 /// The abstract information of the operation. 4509 const AbstractOperation *opDefinition; 4510 4511 /// The main operation parser. 
4512 OperationParser &parser; 4513 4514 /// A flag that indicates if any errors were emitted during parsing. 4515 bool emittedError = false; 4516 }; 4517 } // end anonymous namespace. 4518 4519 Operation *OperationParser::parseCustomOperation() { 4520 auto opLoc = getToken().getLoc(); 4521 auto opName = getTokenSpelling(); 4522 4523 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4524 if (!opDefinition && !opName.contains('.')) { 4525 // If the operation name has no namespace prefix we treat it as a standard 4526 // operation and prefix it with "std". 4527 // TODO: Would it be better to just build a mapping of the registered 4528 // operations in the standard dialect? 4529 opDefinition = 4530 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4531 } 4532 4533 if (!opDefinition) { 4534 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4535 return nullptr; 4536 } 4537 4538 consumeToken(); 4539 4540 // If the custom op parser crashes, produce some indication to help 4541 // debugging. 4542 std::string opNameStr = opName.str(); 4543 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4544 opNameStr.c_str()); 4545 4546 // Get location information for the operation. 4547 auto srcLocation = getEncodedSourceLocation(opLoc); 4548 4549 // Have the op implementation take a crack and parsing this. 4550 OperationState opState(srcLocation, opDefinition->name); 4551 CleanupOpStateRegions guard{opState}; 4552 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 4553 if (opAsmParser.parseOperation(opState)) 4554 return nullptr; 4555 4556 // If it emitted an error, we failed. 4557 if (opAsmParser.didEmitError()) 4558 return nullptr; 4559 4560 // Parse a location if one is present. 4561 if (parseOptionalTrailingLocation(opState.location)) 4562 return nullptr; 4563 4564 // Otherwise, we succeeded. Use the state it parsed as our op information. 
4565 return opBuilder.createOperation(opState); 4566 } 4567 4568 //===----------------------------------------------------------------------===// 4569 // Region Parsing 4570 //===----------------------------------------------------------------------===// 4571 4572 /// Region. 4573 /// 4574 /// region ::= '{' region-body 4575 /// 4576 ParseResult OperationParser::parseRegion( 4577 Region ®ion, 4578 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4579 bool isIsolatedNameScope) { 4580 // Parse the '{'. 4581 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4582 return failure(); 4583 4584 // Check for an empty region. 4585 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4586 return success(); 4587 auto currentPt = opBuilder.saveInsertionPoint(); 4588 4589 // Push a new named value scope. 4590 pushSSANameScope(isIsolatedNameScope); 4591 4592 // Parse the first block directly to allow for it to be unnamed. 4593 Block *block = new Block(); 4594 4595 // Add arguments to the entry block. 4596 if (!entryArguments.empty()) { 4597 for (auto &placeholderArgPair : entryArguments) { 4598 auto &argInfo = placeholderArgPair.first; 4599 // Ensure that the argument was not already defined. 4600 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4601 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4602 "' is already in use") 4603 .attachNote(getEncodedSourceLocation(*defLoc)) 4604 << "previously referenced here"; 4605 } 4606 if (addDefinition(placeholderArgPair.first, 4607 block->addArgument(placeholderArgPair.second))) { 4608 delete block; 4609 return failure(); 4610 } 4611 } 4612 4613 // If we had named arguments, then don't allow a block name. 
4614 if (getToken().is(Token::caret_identifier)) 4615 return emitError("invalid block name in region with named arguments"); 4616 } 4617 4618 if (parseBlock(block)) { 4619 delete block; 4620 return failure(); 4621 } 4622 4623 // Verify that no other arguments were parsed. 4624 if (!entryArguments.empty() && 4625 block->getNumArguments() > entryArguments.size()) { 4626 delete block; 4627 return emitError("entry block arguments were already defined"); 4628 } 4629 4630 // Parse the rest of the region. 4631 region.push_back(block); 4632 if (parseRegionBody(region)) 4633 return failure(); 4634 4635 // Pop the SSA value scope for this region. 4636 if (popSSANameScope()) 4637 return failure(); 4638 4639 // Reset the original insertion point. 4640 opBuilder.restoreInsertionPoint(currentPt); 4641 return success(); 4642 } 4643 4644 /// Region. 4645 /// 4646 /// region-body ::= block* '}' 4647 /// 4648 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4649 // Parse the list of blocks. 4650 while (!consumeIf(Token::r_brace)) { 4651 Block *newBlock = nullptr; 4652 if (parseBlock(newBlock)) 4653 return failure(); 4654 region.push_back(newBlock); 4655 } 4656 return success(); 4657 } 4658 4659 //===----------------------------------------------------------------------===// 4660 // Block Parsing 4661 //===----------------------------------------------------------------------===// 4662 4663 /// Block declaration. 4664 /// 4665 /// block ::= block-label? operation* 4666 /// block-label ::= block-id block-arg-list? `:` 4667 /// block-id ::= caret-id 4668 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4669 /// 4670 ParseResult OperationParser::parseBlock(Block *&block) { 4671 // The first block of a region may already exist, if it does the caret 4672 // identifier is optional. 
4673 if (block && getToken().isNot(Token::caret_identifier)) 4674 return parseBlockBody(block); 4675 4676 SMLoc nameLoc = getToken().getLoc(); 4677 auto name = getTokenSpelling(); 4678 if (parseToken(Token::caret_identifier, "expected block name")) 4679 return failure(); 4680 4681 block = defineBlockNamed(name, nameLoc, block); 4682 4683 // Fail if the block was already defined. 4684 if (!block) 4685 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4686 4687 // If an argument list is present, parse it. 4688 if (consumeIf(Token::l_paren)) { 4689 SmallVector<BlockArgument, 8> bbArgs; 4690 if (parseOptionalBlockArgList(bbArgs, block) || 4691 parseToken(Token::r_paren, "expected ')' to end argument list")) 4692 return failure(); 4693 } 4694 4695 if (parseToken(Token::colon, "expected ':' after block name")) 4696 return failure(); 4697 4698 return parseBlockBody(block); 4699 } 4700 4701 ParseResult OperationParser::parseBlockBody(Block *block) { 4702 // Set the insertion point to the end of the block to parse. 4703 opBuilder.setInsertionPointToEnd(block); 4704 4705 // Parse the list of operations that make up the body of the block. 4706 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4707 if (parseOperation()) 4708 return failure(); 4709 4710 return success(); 4711 } 4712 4713 /// Get the block with the specified name, creating it if it doesn't already 4714 /// exist. The location specified is the point of use, which allows 4715 /// us to diagnose references to blocks that are not defined precisely. 4716 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4717 auto &blockAndLoc = getBlockInfoByName(name); 4718 if (!blockAndLoc.first) { 4719 blockAndLoc = {new Block(), loc}; 4720 insertForwardRef(blockAndLoc.first, loc); 4721 } 4722 4723 return blockAndLoc.first; 4724 } 4725 4726 /// Define the block with the specified name. Returns the Block* or nullptr in 4727 /// the case of redefinition. 
4728 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4729 Block *existing) { 4730 auto &blockAndLoc = getBlockInfoByName(name); 4731 if (!blockAndLoc.first) { 4732 // If the caller provided a block, use it. Otherwise create a new one. 4733 if (!existing) 4734 existing = new Block(); 4735 blockAndLoc.first = existing; 4736 blockAndLoc.second = loc; 4737 return blockAndLoc.first; 4738 } 4739 4740 // Forward declarations are removed once defined, so if we are defining a 4741 // existing block and it is not a forward declaration, then it is a 4742 // redeclaration. 4743 if (!eraseForwardRef(blockAndLoc.first)) 4744 return nullptr; 4745 return blockAndLoc.first; 4746 } 4747 4748 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4749 /// 4750 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4751 /// 4752 ParseResult OperationParser::parseOptionalBlockArgList( 4753 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4754 if (getToken().is(Token::r_brace)) 4755 return success(); 4756 4757 // If the block already has arguments, then we're handling the entry block. 4758 // Parse and register the names for the arguments, but do not add them. 4759 bool definingExistingArgs = owner->getNumArguments() != 0; 4760 unsigned nextArgument = 0; 4761 4762 return parseCommaSeparatedList([&]() -> ParseResult { 4763 return parseSSADefOrUseAndType( 4764 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4765 // If this block did not have existing arguments, define a new one. 4766 if (!definingExistingArgs) 4767 return addDefinition(useInfo, owner->addArgument(type)); 4768 4769 // Otherwise, ensure that this argument has already been created. 4770 if (nextArgument >= owner->getNumArguments()) 4771 return emitError("too many arguments specified in argument list"); 4772 4773 // Finally, make sure the existing argument has the correct type. 
4774 auto arg = owner->getArgument(nextArgument++); 4775 if (arg.getType() != type) 4776 return emitError("argument and block argument type mismatch"); 4777 return addDefinition(useInfo, arg); 4778 }); 4779 }); 4780 } 4781 4782 //===----------------------------------------------------------------------===// 4783 // Top-level entity parsing. 4784 //===----------------------------------------------------------------------===// 4785 4786 namespace { 4787 /// This parser handles entities that are only valid at the top level of the 4788 /// file. 4789 class ModuleParser : public Parser { 4790 public: 4791 explicit ModuleParser(ParserState &state) : Parser(state) {} 4792 4793 ParseResult parseModule(ModuleOp module); 4794 4795 private: 4796 /// Parse an attribute alias declaration. 4797 ParseResult parseAttributeAliasDef(); 4798 4799 /// Parse an attribute alias declaration. 4800 ParseResult parseTypeAliasDef(); 4801 }; 4802 } // end anonymous namespace 4803 4804 /// Parses an attribute alias declaration. 4805 /// 4806 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4807 /// 4808 ParseResult ModuleParser::parseAttributeAliasDef() { 4809 assert(getToken().is(Token::hash_identifier)); 4810 StringRef aliasName = getTokenSpelling().drop_front(); 4811 4812 // Check for redefinitions. 4813 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4814 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4815 4816 // Make sure this isn't invading the dialect attribute namespace. 4817 if (aliasName.contains('.')) 4818 return emitError("attribute names with a '.' are reserved for " 4819 "dialect-defined names"); 4820 4821 consumeToken(Token::hash_identifier); 4822 4823 // Parse the '='. 4824 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4825 return failure(); 4826 4827 // Parse the attribute value. 
4828 Attribute attr = parseAttribute(); 4829 if (!attr) 4830 return failure(); 4831 4832 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4833 return success(); 4834 } 4835 4836 /// Parse a type alias declaration. 4837 /// 4838 /// type-alias-def ::= '!' alias-name `=` 'type' type 4839 /// 4840 ParseResult ModuleParser::parseTypeAliasDef() { 4841 assert(getToken().is(Token::exclamation_identifier)); 4842 StringRef aliasName = getTokenSpelling().drop_front(); 4843 4844 // Check for redefinitions. 4845 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4846 return emitError("redefinition of type alias id '" + aliasName + "'"); 4847 4848 // Make sure this isn't invading the dialect type namespace. 4849 if (aliasName.contains('.')) 4850 return emitError("type names with a '.' are reserved for " 4851 "dialect-defined names"); 4852 4853 consumeToken(Token::exclamation_identifier); 4854 4855 // Parse the '=' and 'type'. 4856 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4857 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4858 return failure(); 4859 4860 // Parse the type. 4861 Type aliasedType = parseType(); 4862 if (!aliasedType) 4863 return failure(); 4864 4865 // Register this alias with the parser state. 4866 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4867 return success(); 4868 } 4869 4870 /// This is the top-level module parser. 4871 ParseResult ModuleParser::parseModule(ModuleOp module) { 4872 OperationParser opParser(getState(), module); 4873 4874 // Module itself is a name scope. 4875 opParser.pushSSANameScope(/*isIsolated=*/true); 4876 4877 while (true) { 4878 switch (getToken().getKind()) { 4879 default: 4880 // Parse a top-level operation. 4881 if (opParser.parseOperation()) 4882 return failure(); 4883 break; 4884 4885 // If we got to the end of the file, then we're done. 
4886 case Token::eof: { 4887 if (opParser.finalize()) 4888 return failure(); 4889 4890 // Handle the case where the top level module was explicitly defined. 4891 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4892 auto &operations = bodyBlocks.front().getOperations(); 4893 assert(!operations.empty() && "expected a valid module terminator"); 4894 4895 // Check that the first operation is a module, and it is the only 4896 // non-terminator operation. 4897 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4898 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4899 // Merge the data of the nested module operation into 'module'. 4900 module.setLoc(nested.getLoc()); 4901 module.setAttrs(nested.getOperation()->getAttrList()); 4902 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4903 4904 // Erase the original module body. 4905 bodyBlocks.pop_front(); 4906 } 4907 4908 return opParser.popSSANameScope(); 4909 } 4910 4911 // If we got an error token, then the lexer already emitted an error, just 4912 // stop. Someday we could introduce error recovery if there was demand 4913 // for it. 4914 case Token::error: 4915 return failure(); 4916 4917 // Parse an attribute alias. 4918 case Token::hash_identifier: 4919 if (parseAttributeAliasDef()) 4920 return failure(); 4921 break; 4922 4923 // Parse a type alias. 4924 case Token::exclamation_identifier: 4925 if (parseTypeAliasDef()) 4926 return failure(); 4927 break; 4928 } 4929 } 4930 } 4931 4932 //===----------------------------------------------------------------------===// 4933 4934 /// This parses the file specified by the indicated SourceMgr and returns an 4935 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4936 /// null. 
4937 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4938 MLIRContext *context) { 4939 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4940 4941 // This is the result module we are parsing into. 4942 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4943 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4944 4945 SymbolState aliasState; 4946 ParserState state(sourceMgr, context, aliasState); 4947 if (ModuleParser(state).parseModule(*module)) 4948 return nullptr; 4949 4950 // Make sure the parse module has no other structural problems detected by 4951 // the verifier. 4952 if (failed(verify(*module))) 4953 return nullptr; 4954 4955 return module; 4956 } 4957 4958 /// This parses the file specified by the indicated filename and returns an 4959 /// MLIR module if it was valid. If not, the error message is emitted through 4960 /// the error handler registered in the context, and a null pointer is returned. 4961 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4962 MLIRContext *context) { 4963 llvm::SourceMgr sourceMgr; 4964 return parseSourceFile(filename, sourceMgr, context); 4965 } 4966 4967 /// This parses the file specified by the indicated filename using the provided 4968 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4969 /// message is emitted through the error handler registered in the context, and 4970 /// a null pointer is returned. 4971 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4972 llvm::SourceMgr &sourceMgr, 4973 MLIRContext *context) { 4974 if (sourceMgr.getNumBuffers() != 0) { 4975 // TODO(b/136086478): Extend to support multiple buffers. 
4976 emitError(mlir::UnknownLoc::get(context), 4977 "only main buffer parsed at the moment"); 4978 return nullptr; 4979 } 4980 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4981 if (std::error_code error = file_or_err.getError()) { 4982 emitError(mlir::UnknownLoc::get(context), 4983 "could not open input file " + filename); 4984 return nullptr; 4985 } 4986 4987 // Load the MLIR module. 4988 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4989 return parseSourceFile(sourceMgr, context); 4990 } 4991 4992 /// This parses the program string to a MLIR module if it was valid. If not, 4993 /// it emits diagnostics and returns null. 4994 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4995 MLIRContext *context) { 4996 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4997 if (!memBuffer) 4998 return nullptr; 4999 5000 SourceMgr sourceMgr; 5001 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 5002 return parseSourceFile(sourceMgr, context); 5003 } 5004 5005 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 5006 /// parsing failed, nullptr is returned. The number of bytes read from the input 5007 /// string is returned in 'numRead'. 
5008 template <typename T, typename ParserFn> 5009 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 5010 ParserFn &&parserFn) { 5011 SymbolState aliasState; 5012 return parseSymbol<T>( 5013 inputStr, context, aliasState, 5014 [&](Parser &parser) { 5015 SourceMgrDiagnosticHandler handler( 5016 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 5017 parser.getContext()); 5018 return parserFn(parser); 5019 }, 5020 &numRead); 5021 } 5022 5023 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 5024 size_t numRead = 0; 5025 return parseAttribute(attrStr, context, numRead); 5026 } 5027 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 5028 size_t numRead = 0; 5029 return parseAttribute(attrStr, type, numRead); 5030 } 5031 5032 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 5033 size_t &numRead) { 5034 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 5035 return parser.parseAttribute(); 5036 }); 5037 } 5038 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 5039 return parseSymbol<Attribute>( 5040 attrStr, type.getContext(), numRead, 5041 [type](Parser &parser) { return parser.parseAttribute(type); }); 5042 } 5043 5044 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 5045 size_t numRead = 0; 5046 return parseType(typeStr, context, numRead); 5047 } 5048 5049 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 5050 return parseSymbol<Type>(typeStr, context, numRead, 5051 [](Parser &parser) { return parser.parseType(); }); 5052 } 5053