1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR textual form. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Parser.h" 14 #include "Lexer.h" 15 #include "mlir/Analysis/Verifier.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/Module.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include "mlir/Support/STLExtras.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/ADT/bit.h" 33 #include "llvm/Support/PrettyStackTrace.h" 34 #include "llvm/Support/SMLoc.h" 35 #include "llvm/Support/SourceMgr.h" 36 #include <algorithm> 37 using namespace mlir; 38 using llvm::MemoryBuffer; 39 using llvm::SMLoc; 40 using llvm::SourceMgr; 41 42 namespace { 43 class Parser; 44 45 //===----------------------------------------------------------------------===// 46 // SymbolState 47 //===----------------------------------------------------------------------===// 48 49 /// This class contains record of any parsed top-level symbols. 50 struct SymbolState { 51 // A map from attribute alias identifier to Attribute. 52 llvm::StringMap<Attribute> attributeAliasDefinitions; 53 54 // A map from type alias identifier to Type. 
55 llvm::StringMap<Type> typeAliasDefinitions; 56 57 /// A set of locations into the main parser memory buffer for each of the 58 /// active nested parsers. Given that some nested parsers, i.e. custom dialect 59 /// parsers, operate on a temporary memory buffer, this provides an anchor 60 /// point for emitting diagnostics. 61 SmallVector<llvm::SMLoc, 1> nestedParserLocs; 62 63 /// The top-level lexer that contains the original memory buffer provided by 64 /// the user. This is used by nested parsers to get a properly encoded source 65 /// location. 66 Lexer *topLevelLexer = nullptr; 67 }; 68 69 //===----------------------------------------------------------------------===// 70 // ParserState 71 //===----------------------------------------------------------------------===// 72 73 /// This class refers to all of the state maintained globally by the parser, 74 /// such as the current lexer position etc. 75 struct ParserState { 76 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, 77 SymbolState &symbols) 78 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), 79 symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { 80 // Set the top level lexer for the symbol state if one doesn't exist. 81 if (!symbols.topLevelLexer) 82 symbols.topLevelLexer = &lex; 83 } 84 ~ParserState() { 85 // Reset the top level lexer if it refers the lexer in our state. 86 if (symbols.topLevelLexer == &lex) 87 symbols.topLevelLexer = nullptr; 88 } 89 ParserState(const ParserState &) = delete; 90 void operator=(const ParserState &) = delete; 91 92 /// The context we're parsing into. 93 MLIRContext *const context; 94 95 /// The lexer for the source file we're parsing. 96 Lexer lex; 97 98 /// This is the next token that hasn't been consumed yet. 99 Token curToken; 100 101 /// The current state for symbol parsing. 102 SymbolState &symbols; 103 104 /// The depth of this parser in the nested parsing stack. 
105 size_t parserDepth; 106 }; 107 108 //===----------------------------------------------------------------------===// 109 // Parser 110 //===----------------------------------------------------------------------===// 111 112 /// This class implement support for parsing global entities like types and 113 /// shared entities like SSA names. It is intended to be subclassed by 114 /// specialized subparsers that include state, e.g. when a local symbol table. 115 class Parser { 116 public: 117 Builder builder; 118 119 Parser(ParserState &state) : builder(state.context), state(state) {} 120 121 // Helper methods to get stuff from the parser-global state. 122 ParserState &getState() const { return state; } 123 MLIRContext *getContext() const { return state.context; } 124 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 125 126 /// Parse a comma-separated list of elements up until the specified end token. 127 ParseResult 128 parseCommaSeparatedListUntil(Token::Kind rightToken, 129 const std::function<ParseResult()> &parseElement, 130 bool allowEmptyList = true); 131 132 /// Parse a comma separated list of elements that must have at least one entry 133 /// in it. 134 ParseResult 135 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 136 137 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 138 139 // We have two forms of parsing methods - those that return a non-null 140 // pointer on success, and those that return a ParseResult to indicate whether 141 // they returned a failure. The second class fills in by-reference arguments 142 // as the results of their action. 143 144 //===--------------------------------------------------------------------===// 145 // Error Handling 146 //===--------------------------------------------------------------------===// 147 148 /// Emit an error and return failure. 
149 InFlightDiagnostic emitError(const Twine &message = {}) { 150 return emitError(state.curToken.getLoc(), message); 151 } 152 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 153 154 /// Encode the specified source location information into an attribute for 155 /// attachment to the IR. 156 Location getEncodedSourceLocation(llvm::SMLoc loc) { 157 // If there are no active nested parsers, we can get the encoded source 158 // location directly. 159 if (state.parserDepth == 0) 160 return state.lex.getEncodedSourceLocation(loc); 161 // Otherwise, we need to re-encode it to point to the top level buffer. 162 return state.symbols.topLevelLexer->getEncodedSourceLocation( 163 remapLocationToTopLevelBuffer(loc)); 164 } 165 166 /// Remaps the given SMLoc to the top level lexer of the parser. This is used 167 /// to adjust locations of potentially nested parsers to ensure that they can 168 /// be emitted properly as diagnostics. 169 llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { 170 // If there are no active nested parsers, we can return location directly. 171 SymbolState &symbols = state.symbols; 172 if (state.parserDepth == 0) 173 return loc; 174 assert(symbols.topLevelLexer && "expected valid top-level lexer"); 175 176 // Otherwise, we need to remap the location to the main parser. This is 177 // simply offseting the location onto the location of the last nested 178 // parser. 179 size_t offset = loc.getPointer() - state.lex.getBufferBegin(); 180 auto *rawLoc = 181 symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; 182 return llvm::SMLoc::getFromPointer(rawLoc); 183 } 184 185 //===--------------------------------------------------------------------===// 186 // Token Parsing 187 //===--------------------------------------------------------------------===// 188 189 /// Return the current token the parser is inspecting. 
190 const Token &getToken() const { return state.curToken; } 191 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 192 193 /// If the current token has the specified kind, consume it and return true. 194 /// If not, return false. 195 bool consumeIf(Token::Kind kind) { 196 if (state.curToken.isNot(kind)) 197 return false; 198 consumeToken(kind); 199 return true; 200 } 201 202 /// Advance the current lexer onto the next token. 203 void consumeToken() { 204 assert(state.curToken.isNot(Token::eof, Token::error) && 205 "shouldn't advance past EOF or errors"); 206 state.curToken = state.lex.lexToken(); 207 } 208 209 /// Advance the current lexer onto the next token, asserting what the expected 210 /// current token is. This is preferred to the above method because it leads 211 /// to more self-documenting code with better checking. 212 void consumeToken(Token::Kind kind) { 213 assert(state.curToken.is(kind) && "consumed an unexpected token"); 214 consumeToken(); 215 } 216 217 /// Consume the specified token if present and return success. On failure, 218 /// output a diagnostic and return failure. 219 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 220 221 //===--------------------------------------------------------------------===// 222 // Type Parsing 223 //===--------------------------------------------------------------------===// 224 225 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 226 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 227 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 228 229 /// Parse an arbitrary type. 230 Type parseType(); 231 232 /// Parse a complex type. 233 Type parseComplexType(); 234 235 /// Parse an extended type. 236 Type parseExtendedType(); 237 238 /// Parse a function type. 239 Type parseFunctionType(); 240 241 /// Parse a memref type. 242 Type parseMemRefType(); 243 244 /// Parse a non function type. 
245 Type parseNonFunctionType(); 246 247 /// Parse a tensor type. 248 Type parseTensorType(); 249 250 /// Parse a tuple type. 251 Type parseTupleType(); 252 253 /// Parse a vector type. 254 VectorType parseVectorType(); 255 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 256 bool allowDynamic = true); 257 ParseResult parseXInDimensionList(); 258 259 /// Parse strided layout specification. 260 ParseResult parseStridedLayout(int64_t &offset, 261 SmallVectorImpl<int64_t> &strides); 262 263 // Parse a brace-delimiter list of comma-separated integers with `?` as an 264 // unknown marker. 265 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 266 267 //===--------------------------------------------------------------------===// 268 // Attribute Parsing 269 //===--------------------------------------------------------------------===// 270 271 /// Parse an arbitrary attribute with an optional type. 272 Attribute parseAttribute(Type type = {}); 273 274 /// Parse an attribute dictionary. 275 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 276 277 /// Parse an extended attribute. 278 Attribute parseExtendedAttr(Type type); 279 280 /// Parse a float attribute. 281 Attribute parseFloatAttr(Type type, bool isNegative); 282 283 /// Parse a decimal or a hexadecimal literal, which can be either an integer 284 /// or a float attribute. 285 Attribute parseDecOrHexAttr(Type type, bool isNegative); 286 287 /// Parse an opaque elements attribute. 288 Attribute parseOpaqueElementsAttr(Type attrType); 289 290 /// Parse a dense elements attribute. 291 Attribute parseDenseElementsAttr(Type attrType); 292 ShapedType parseElementsLiteralType(Type type); 293 294 /// Parse a sparse elements attribute. 
295 Attribute parseSparseElementsAttr(Type attrType); 296 297 //===--------------------------------------------------------------------===// 298 // Location Parsing 299 //===--------------------------------------------------------------------===// 300 301 /// Parse an inline location. 302 ParseResult parseLocation(LocationAttr &loc); 303 304 /// Parse a raw location instance. 305 ParseResult parseLocationInstance(LocationAttr &loc); 306 307 /// Parse a callsite location instance. 308 ParseResult parseCallSiteLocation(LocationAttr &loc); 309 310 /// Parse a fused location instance. 311 ParseResult parseFusedLocation(LocationAttr &loc); 312 313 /// Parse a name or FileLineCol location instance. 314 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 315 316 /// Parse an optional trailing location. 317 /// 318 /// trailing-location ::= (`loc` `(` location `)`)? 319 /// 320 ParseResult parseOptionalTrailingLocation(Location &loc) { 321 // If there is a 'loc' we parse a trailing location. 322 if (!getToken().is(Token::kw_loc)) 323 return success(); 324 325 // Parse the location. 326 LocationAttr directLoc; 327 if (parseLocation(directLoc)) 328 return failure(); 329 loc = directLoc; 330 return success(); 331 } 332 333 //===--------------------------------------------------------------------===// 334 // Affine Parsing 335 //===--------------------------------------------------------------------===// 336 337 /// Parse a reference to either an affine map, or an integer set. 338 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 339 IntegerSet &set); 340 ParseResult parseAffineMapReference(AffineMap &map); 341 ParseResult parseIntegerSetReference(IntegerSet &set); 342 343 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 
344 ParseResult 345 parseAffineMapOfSSAIds(AffineMap &map, 346 function_ref<ParseResult(bool)> parseElement, 347 OpAsmParser::Delimiter delimiter); 348 349 private: 350 /// The Parser is subclassed and reinstantiated. Do not add additional 351 /// non-trivial state here, add it to the ParserState class. 352 ParserState &state; 353 }; 354 } // end anonymous namespace 355 356 //===----------------------------------------------------------------------===// 357 // Helper methods. 358 //===----------------------------------------------------------------------===// 359 360 /// Parse a comma separated list of elements that must have at least one entry 361 /// in it. 362 ParseResult Parser::parseCommaSeparatedList( 363 const std::function<ParseResult()> &parseElement) { 364 // Non-empty case starts with an element. 365 if (parseElement()) 366 return failure(); 367 368 // Otherwise we have a list of comma separated elements. 369 while (consumeIf(Token::comma)) { 370 if (parseElement()) 371 return failure(); 372 } 373 return success(); 374 } 375 376 /// Parse a comma-separated list of elements, terminated with an arbitrary 377 /// token. This allows empty lists if allowEmptyList is true. 378 /// 379 /// abstract-list ::= rightToken // if allowEmptyList == true 380 /// abstract-list ::= element (',' element)* rightToken 381 /// 382 ParseResult Parser::parseCommaSeparatedListUntil( 383 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 384 bool allowEmptyList) { 385 // Handle the empty case. 
386 if (getToken().is(rightToken)) { 387 if (!allowEmptyList) 388 return emitError("expected list element"); 389 consumeToken(rightToken); 390 return success(); 391 } 392 393 if (parseCommaSeparatedList(parseElement) || 394 parseToken(rightToken, "expected ',' or '" + 395 Token::getTokenSpelling(rightToken) + "'")) 396 return failure(); 397 398 return success(); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // DialectAsmParser 403 //===----------------------------------------------------------------------===// 404 405 namespace { 406 /// This class provides the main implementation of the DialectAsmParser that 407 /// allows for dialects to parse attributes and types. This allows for dialect 408 /// hooking into the main MLIR parsing logic. 409 class CustomDialectAsmParser : public DialectAsmParser { 410 public: 411 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 412 : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), 413 parser(parser) {} 414 ~CustomDialectAsmParser() override {} 415 416 /// Emit a diagnostic at the specified location and return failure. 417 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 418 return parser.emitError(loc, message); 419 } 420 421 /// Return a builder which provides useful access to MLIRContext, global 422 /// objects like types and attributes. 423 Builder &getBuilder() const override { return parser.builder; } 424 425 /// Get the location of the next token and store it into the argument. This 426 /// always succeeds. 427 llvm::SMLoc getCurrentLocation() override { 428 return parser.getToken().getLoc(); 429 } 430 431 /// Return the location of the original name token. 432 llvm::SMLoc getNameLoc() const override { return nameLoc; } 433 434 /// Re-encode the given source location as an MLIR location and return it. 
435 Location getEncodedSourceLoc(llvm::SMLoc loc) override { 436 return parser.getEncodedSourceLocation(loc); 437 } 438 439 /// Returns the full specification of the symbol being parsed. This allows 440 /// for using a separate parser if necessary. 441 StringRef getFullSymbolSpec() const override { return fullSpec; } 442 443 /// Parse a floating point value from the stream. 444 ParseResult parseFloat(double &result) override { 445 bool negative = parser.consumeIf(Token::minus); 446 Token curTok = parser.getToken(); 447 448 // Check for a floating point value. 449 if (curTok.is(Token::floatliteral)) { 450 auto val = curTok.getFloatingPointValue(); 451 if (!val.hasValue()) 452 return emitError(curTok.getLoc(), "floating point value too large"); 453 parser.consumeToken(Token::floatliteral); 454 result = negative ? -*val : *val; 455 return success(); 456 } 457 458 // TODO(riverriddle) support hex floating point values. 459 return emitError(getCurrentLocation(), "expected floating point literal"); 460 } 461 462 /// Parse an optional integer value from the stream. 463 OptionalParseResult parseOptionalInteger(uint64_t &result) override { 464 Token curToken = parser.getToken(); 465 if (curToken.isNot(Token::integer, Token::minus)) 466 return llvm::None; 467 468 bool negative = parser.consumeIf(Token::minus); 469 Token curTok = parser.getToken(); 470 if (parser.parseToken(Token::integer, "expected integer value")) 471 return failure(); 472 473 auto val = curTok.getUInt64IntegerValue(); 474 if (!val) 475 return emitError(curTok.getLoc(), "integer value too large"); 476 result = negative ? -*val : *val; 477 return success(); 478 } 479 480 //===--------------------------------------------------------------------===// 481 // Token Parsing 482 //===--------------------------------------------------------------------===// 483 484 /// Parse a `->` token. 
485 ParseResult parseArrow() override { 486 return parser.parseToken(Token::arrow, "expected '->'"); 487 } 488 489 /// Parses a `->` if present. 490 ParseResult parseOptionalArrow() override { 491 return success(parser.consumeIf(Token::arrow)); 492 } 493 494 /// Parse a '{' token. 495 ParseResult parseLBrace() override { 496 return parser.parseToken(Token::l_brace, "expected '{'"); 497 } 498 499 /// Parse a '{' token if present 500 ParseResult parseOptionalLBrace() override { 501 return success(parser.consumeIf(Token::l_brace)); 502 } 503 504 /// Parse a `}` token. 505 ParseResult parseRBrace() override { 506 return parser.parseToken(Token::r_brace, "expected '}'"); 507 } 508 509 /// Parse a `}` token if present 510 ParseResult parseOptionalRBrace() override { 511 return success(parser.consumeIf(Token::r_brace)); 512 } 513 514 /// Parse a `:` token. 515 ParseResult parseColon() override { 516 return parser.parseToken(Token::colon, "expected ':'"); 517 } 518 519 /// Parse a `:` token if present. 520 ParseResult parseOptionalColon() override { 521 return success(parser.consumeIf(Token::colon)); 522 } 523 524 /// Parse a `,` token. 525 ParseResult parseComma() override { 526 return parser.parseToken(Token::comma, "expected ','"); 527 } 528 529 /// Parse a `,` token if present. 530 ParseResult parseOptionalComma() override { 531 return success(parser.consumeIf(Token::comma)); 532 } 533 534 /// Parses a `...` if present. 535 ParseResult parseOptionalEllipsis() override { 536 return success(parser.consumeIf(Token::ellipsis)); 537 } 538 539 /// Parse a `=` token. 540 ParseResult parseEqual() override { 541 return parser.parseToken(Token::equal, "expected '='"); 542 } 543 544 /// Parse a '<' token. 545 ParseResult parseLess() override { 546 return parser.parseToken(Token::less, "expected '<'"); 547 } 548 549 /// Parse a `<` token if present. 
550 ParseResult parseOptionalLess() override { 551 return success(parser.consumeIf(Token::less)); 552 } 553 554 /// Parse a '>' token. 555 ParseResult parseGreater() override { 556 return parser.parseToken(Token::greater, "expected '>'"); 557 } 558 559 /// Parse a `>` token if present. 560 ParseResult parseOptionalGreater() override { 561 return success(parser.consumeIf(Token::greater)); 562 } 563 564 /// Parse a `(` token. 565 ParseResult parseLParen() override { 566 return parser.parseToken(Token::l_paren, "expected '('"); 567 } 568 569 /// Parses a '(' if present. 570 ParseResult parseOptionalLParen() override { 571 return success(parser.consumeIf(Token::l_paren)); 572 } 573 574 /// Parse a `)` token. 575 ParseResult parseRParen() override { 576 return parser.parseToken(Token::r_paren, "expected ')'"); 577 } 578 579 /// Parses a ')' if present. 580 ParseResult parseOptionalRParen() override { 581 return success(parser.consumeIf(Token::r_paren)); 582 } 583 584 /// Parse a `[` token. 585 ParseResult parseLSquare() override { 586 return parser.parseToken(Token::l_square, "expected '['"); 587 } 588 589 /// Parses a '[' if present. 590 ParseResult parseOptionalLSquare() override { 591 return success(parser.consumeIf(Token::l_square)); 592 } 593 594 /// Parse a `]` token. 595 ParseResult parseRSquare() override { 596 return parser.parseToken(Token::r_square, "expected ']'"); 597 } 598 599 /// Parses a ']' if present. 600 ParseResult parseOptionalRSquare() override { 601 return success(parser.consumeIf(Token::r_square)); 602 } 603 604 /// Parses a '?' if present. 605 ParseResult parseOptionalQuestion() override { 606 return success(parser.consumeIf(Token::question)); 607 } 608 609 /// Parses a '*' if present. 610 ParseResult parseOptionalStar() override { 611 return success(parser.consumeIf(Token::star)); 612 } 613 614 /// Returns if the current token corresponds to a keyword. 
615 bool isCurrentTokenAKeyword() const { 616 return parser.getToken().is(Token::bare_identifier) || 617 parser.getToken().isKeyword(); 618 } 619 620 /// Parse the given keyword if present. 621 ParseResult parseOptionalKeyword(StringRef keyword) override { 622 // Check that the current token has the same spelling. 623 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 624 return failure(); 625 parser.consumeToken(); 626 return success(); 627 } 628 629 /// Parse a keyword, if present, into 'keyword'. 630 ParseResult parseOptionalKeyword(StringRef *keyword) override { 631 // Check that the current token is a keyword. 632 if (!isCurrentTokenAKeyword()) 633 return failure(); 634 635 *keyword = parser.getTokenSpelling(); 636 parser.consumeToken(); 637 return success(); 638 } 639 640 //===--------------------------------------------------------------------===// 641 // Attribute Parsing 642 //===--------------------------------------------------------------------===// 643 644 /// Parse an arbitrary attribute and return it in result. 645 ParseResult parseAttribute(Attribute &result, Type type) override { 646 result = parser.parseAttribute(type); 647 return success(static_cast<bool>(result)); 648 } 649 650 /// Parse an affine map instance into 'map'. 651 ParseResult parseAffineMap(AffineMap &map) override { 652 return parser.parseAffineMapReference(map); 653 } 654 655 /// Parse an integer set instance into 'set'. 
656 ParseResult printIntegerSet(IntegerSet &set) override { 657 return parser.parseIntegerSetReference(set); 658 } 659 660 //===--------------------------------------------------------------------===// 661 // Type Parsing 662 //===--------------------------------------------------------------------===// 663 664 ParseResult parseType(Type &result) override { 665 result = parser.parseType(); 666 return success(static_cast<bool>(result)); 667 } 668 669 ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 670 bool allowDynamic) override { 671 return parser.parseDimensionListRanked(dimensions, allowDynamic); 672 } 673 674 private: 675 /// The full symbol specification. 676 StringRef fullSpec; 677 678 /// The source location of the dialect symbol. 679 SMLoc nameLoc; 680 681 /// The main parser. 682 Parser &parser; 683 }; 684 } // namespace 685 686 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 687 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 688 /// the entire pretty name. 689 /// 690 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 691 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 692 /// | '(' pretty-dialect-sym-contents+ ')' 693 /// | '[' pretty-dialect-sym-contents+ ']' 694 /// | '{' pretty-dialect-sym-contents+ '}' 695 /// | '[^[<({>\])}\0]+' 696 /// 697 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 698 // Pretty symbol names are a relatively unstructured format that contains a 699 // series of properly nested punctuation, with anything else in the middle. 700 // Scan ahead to find it and consume it if successful, otherwise emit an 701 // error. 702 auto *curPtr = getTokenSpelling().data(); 703 704 SmallVector<char, 8> nestedPunctuation; 705 706 // Scan over the nested punctuation, bailing out on error and consuming until 707 // we find the end. 
We know that we're currently looking at the '<', so we 708 // can go until we find the matching '>' character. 709 assert(*curPtr == '<'); 710 do { 711 char c = *curPtr++; 712 switch (c) { 713 case '\0': 714 // This also handles the EOF case. 715 return emitError("unexpected nul or EOF in pretty dialect name"); 716 case '<': 717 case '[': 718 case '(': 719 case '{': 720 nestedPunctuation.push_back(c); 721 continue; 722 723 case '-': 724 // The sequence `->` is treated as special token. 725 if (*curPtr == '>') 726 ++curPtr; 727 continue; 728 729 case '>': 730 if (nestedPunctuation.pop_back_val() != '<') 731 return emitError("unbalanced '>' character in pretty dialect name"); 732 break; 733 case ']': 734 if (nestedPunctuation.pop_back_val() != '[') 735 return emitError("unbalanced ']' character in pretty dialect name"); 736 break; 737 case ')': 738 if (nestedPunctuation.pop_back_val() != '(') 739 return emitError("unbalanced ')' character in pretty dialect name"); 740 break; 741 case '}': 742 if (nestedPunctuation.pop_back_val() != '{') 743 return emitError("unbalanced '}' character in pretty dialect name"); 744 break; 745 746 default: 747 continue; 748 } 749 } while (!nestedPunctuation.empty()); 750 751 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 752 // consuming all this stuff, and return. 753 state.lex.resetPointer(curPtr); 754 755 unsigned length = curPtr - prettyName.begin(); 756 prettyName = StringRef(prettyName.begin(), length); 757 consumeToken(); 758 return success(); 759 } 760 761 /// Parse an extended dialect symbol. 762 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 763 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 764 SymbolAliasMap &aliases, 765 CreateFn &&createSymbol) { 766 // Parse the dialect namespace. 
767 StringRef identifier = p.getTokenSpelling().drop_front(); 768 auto loc = p.getToken().getLoc(); 769 p.consumeToken(identifierTok); 770 771 // If there is no '<' token following this, and if the typename contains no 772 // dot, then we are parsing a symbol alias. 773 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 774 // Check for an alias for this type. 775 auto aliasIt = aliases.find(identifier); 776 if (aliasIt == aliases.end()) 777 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 778 nullptr); 779 return aliasIt->second; 780 } 781 782 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 783 // a dot, then this is the "pretty" form. If not, it is the verbose form that 784 // looks like <"...">. 785 std::string symbolData; 786 auto dialectName = identifier; 787 788 // Handle the verbose form, where "identifier" is a simple dialect name. 789 if (!identifier.contains('.')) { 790 // Consume the '<'. 791 if (p.parseToken(Token::less, "expected '<' in dialect type")) 792 return nullptr; 793 794 // Parse the symbol specific data. 795 if (p.getToken().isNot(Token::string)) 796 return (p.emitError("expected string literal data in dialect symbol"), 797 nullptr); 798 symbolData = p.getToken().getStringValue(); 799 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); 800 p.consumeToken(Token::string); 801 802 // Consume the '>'. 803 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 804 return nullptr; 805 } else { 806 // Ok, the dialect name is the part of the identifier before the dot, the 807 // part after the dot is the dialect's symbol, or the start thereof. 808 auto dotHalves = identifier.split('.'); 809 dialectName = dotHalves.first; 810 auto prettyName = dotHalves.second; 811 loc = llvm::SMLoc::getFromPointer(prettyName.data()); 812 813 // If the dialect's symbol is followed immediately by a <, then lex the body 814 // of it into prettyName. 
815 if (p.getToken().is(Token::less) && 816 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 817 if (p.parsePrettyDialectSymbolName(prettyName)) 818 return nullptr; 819 } 820 821 symbolData = prettyName.str(); 822 } 823 824 // Record the name location of the type remapped to the top level buffer. 825 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); 826 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); 827 828 // Call into the provided symbol construction function. 829 Symbol sym = createSymbol(dialectName, symbolData, loc); 830 831 // Pop the last parser location. 832 p.getState().symbols.nestedParserLocs.pop_back(); 833 return sym; 834 } 835 836 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 837 /// parsing failed, nullptr is returned. The number of bytes read from the input 838 /// string is returned in 'numRead'. 839 template <typename T, typename ParserFn> 840 static T parseSymbol(StringRef inputStr, MLIRContext *context, 841 SymbolState &symbolState, ParserFn &&parserFn, 842 size_t *numRead = nullptr) { 843 SourceMgr sourceMgr; 844 auto memBuffer = MemoryBuffer::getMemBuffer( 845 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 846 /*RequiresNullTerminator=*/false); 847 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 848 ParserState state(sourceMgr, context, symbolState); 849 Parser parser(state); 850 851 Token startTok = parser.getToken(); 852 T symbol = parserFn(parser); 853 if (!symbol) 854 return T(); 855 856 // If 'numRead' is valid, then provide the number of bytes that were read. 857 Token endTok = parser.getToken(); 858 if (numRead) { 859 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 860 startTok.getLoc().getPointer()); 861 862 // Otherwise, ensure that all of the tokens were parsed. 
863 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { 864 parser.emitError(endTok.getLoc(), "encountered unexpected token"); 865 return T(); 866 } 867 return symbol; 868 } 869 870 //===----------------------------------------------------------------------===// 871 // Error Handling 872 //===----------------------------------------------------------------------===// 873 874 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 875 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 876 877 // If we hit a parse error in response to a lexer error, then the lexer 878 // already reported the error. 879 if (getToken().is(Token::error)) 880 diag.abandon(); 881 return diag; 882 } 883 884 //===----------------------------------------------------------------------===// 885 // Token Parsing 886 //===----------------------------------------------------------------------===// 887 888 /// Consume the specified token if present and return success. On failure, 889 /// output a diagnostic and return failure. 890 ParseResult Parser::parseToken(Token::Kind expectedToken, 891 const Twine &message) { 892 if (consumeIf(expectedToken)) 893 return success(); 894 return emitError(message); 895 } 896 897 //===----------------------------------------------------------------------===// 898 // Type Parsing 899 //===----------------------------------------------------------------------===// 900 901 /// Parse an arbitrary type. 902 /// 903 /// type ::= function-type 904 /// | non-function-type 905 /// 906 Type Parser::parseType() { 907 if (getToken().is(Token::l_paren)) 908 return parseFunctionType(); 909 return parseNonFunctionType(); 910 } 911 912 /// Parse a function result type. 
913 /// 914 /// function-result-type ::= type-list-parens 915 /// | non-function-type 916 /// 917 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 918 if (getToken().is(Token::l_paren)) 919 return parseTypeListParens(elements); 920 921 Type t = parseNonFunctionType(); 922 if (!t) 923 return failure(); 924 elements.push_back(t); 925 return success(); 926 } 927 928 /// Parse a list of types without an enclosing parenthesis. The list must have 929 /// at least one member. 930 /// 931 /// type-list-no-parens ::= type (`,` type)* 932 /// 933 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 934 auto parseElt = [&]() -> ParseResult { 935 auto elt = parseType(); 936 elements.push_back(elt); 937 return elt ? success() : failure(); 938 }; 939 940 return parseCommaSeparatedList(parseElt); 941 } 942 943 /// Parse a parenthesized list of types. 944 /// 945 /// type-list-parens ::= `(` `)` 946 /// | `(` type-list-no-parens `)` 947 /// 948 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 949 if (parseToken(Token::l_paren, "expected '('")) 950 return failure(); 951 952 // Handle empty lists. 953 if (getToken().is(Token::r_paren)) 954 return consumeToken(), success(); 955 956 if (parseTypeListNoParens(elements) || 957 parseToken(Token::r_paren, "expected ')'")) 958 return failure(); 959 return success(); 960 } 961 962 /// Parse a complex type. 963 /// 964 /// complex-type ::= `complex` `<` type `>` 965 /// 966 Type Parser::parseComplexType() { 967 consumeToken(Token::kw_complex); 968 969 // Parse the '<'. 
970 if (parseToken(Token::less, "expected '<' in complex type")) 971 return nullptr; 972 973 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 974 auto elementType = parseType(); 975 if (!elementType || 976 parseToken(Token::greater, "expected '>' in complex type")) 977 return nullptr; 978 979 return ComplexType::getChecked(elementType, typeLocation); 980 } 981 982 /// Parse an extended type. 983 /// 984 /// extended-type ::= (dialect-type | type-alias) 985 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 986 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 987 /// type-alias ::= `!` alias-name 988 /// 989 Type Parser::parseExtendedType() { 990 return parseExtendedSymbol<Type>( 991 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 992 [&](StringRef dialectName, StringRef symbolData, 993 llvm::SMLoc loc) -> Type { 994 // If we found a registered dialect, then ask it to parse the type. 995 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 996 return parseSymbol<Type>( 997 symbolData, state.context, state.symbols, [&](Parser &parser) { 998 CustomDialectAsmParser customParser(symbolData, parser); 999 return dialect->parseType(customParser); 1000 }); 1001 } 1002 1003 // Otherwise, form a new opaque type. 1004 return OpaqueType::getChecked( 1005 Identifier::get(dialectName, state.context), symbolData, 1006 state.context, getEncodedSourceLocation(loc)); 1007 }); 1008 } 1009 1010 /// Parse a function type. 
1011 /// 1012 /// function-type ::= type-list-parens `->` function-result-type 1013 /// 1014 Type Parser::parseFunctionType() { 1015 assert(getToken().is(Token::l_paren)); 1016 1017 SmallVector<Type, 4> arguments, results; 1018 if (parseTypeListParens(arguments) || 1019 parseToken(Token::arrow, "expected '->' in function type") || 1020 parseFunctionResultTypes(results)) 1021 return nullptr; 1022 1023 return builder.getFunctionType(arguments, results); 1024 } 1025 1026 /// Parse the offset and strides from a strided layout specification. 1027 /// 1028 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1029 /// 1030 ParseResult Parser::parseStridedLayout(int64_t &offset, 1031 SmallVectorImpl<int64_t> &strides) { 1032 // Parse offset. 1033 consumeToken(Token::kw_offset); 1034 if (!consumeIf(Token::colon)) 1035 return emitError("expected colon after `offset` keyword"); 1036 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1037 bool question = getToken().is(Token::question); 1038 if (!maybeOffset && !question) 1039 return emitError("invalid offset"); 1040 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1041 : MemRefType::getDynamicStrideOrOffset(); 1042 consumeToken(); 1043 1044 if (!consumeIf(Token::comma)) 1045 return emitError("expected comma after offset value"); 1046 1047 // Parse stride list. 1048 if (!consumeIf(Token::kw_strides)) 1049 return emitError("expected `strides` keyword after offset specification"); 1050 if (!consumeIf(Token::colon)) 1051 return emitError("expected colon after `strides` keyword"); 1052 if (failed(parseStrideList(strides))) 1053 return emitError("invalid braces-enclosed stride list"); 1054 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1055 return emitError("invalid memref stride"); 1056 1057 return success(); 1058 } 1059 1060 /// Parse a memref type. 
1061 /// 1062 /// memref-type ::= ranked-memref-type | unranked-memref-type 1063 /// 1064 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1065 /// (`,` semi-affine-map-composition)? (`,` 1066 /// memory-space)? `>` 1067 /// 1068 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1069 /// 1070 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1071 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1072 /// 1073 Type Parser::parseMemRefType() { 1074 consumeToken(Token::kw_memref); 1075 1076 if (parseToken(Token::less, "expected '<' in memref type")) 1077 return nullptr; 1078 1079 bool isUnranked; 1080 SmallVector<int64_t, 4> dimensions; 1081 1082 if (consumeIf(Token::star)) { 1083 // This is an unranked memref type. 1084 isUnranked = true; 1085 if (parseXInDimensionList()) 1086 return nullptr; 1087 1088 } else { 1089 isUnranked = false; 1090 if (parseDimensionListRanked(dimensions)) 1091 return nullptr; 1092 } 1093 1094 // Parse the element type. 1095 auto typeLoc = getToken().getLoc(); 1096 auto elementType = parseType(); 1097 if (!elementType) 1098 return nullptr; 1099 1100 // Parse semi-affine-map-composition. 1101 SmallVector<AffineMap, 2> affineMapComposition; 1102 unsigned memorySpace = 0; 1103 bool parsedMemorySpace = false; 1104 1105 auto parseElt = [&]() -> ParseResult { 1106 if (getToken().is(Token::integer)) { 1107 // Parse memory space. 
1108 if (parsedMemorySpace) 1109 return emitError("multiple memory spaces specified in memref type"); 1110 auto v = getToken().getUnsignedIntegerValue(); 1111 if (!v.hasValue()) 1112 return emitError("invalid memory space in memref type"); 1113 memorySpace = v.getValue(); 1114 consumeToken(Token::integer); 1115 parsedMemorySpace = true; 1116 } else { 1117 if (isUnranked) 1118 return emitError("cannot have affine map for unranked memref type"); 1119 if (parsedMemorySpace) 1120 return emitError("expected memory space to be last in memref type"); 1121 if (getToken().is(Token::kw_offset)) { 1122 int64_t offset; 1123 SmallVector<int64_t, 4> strides; 1124 if (failed(parseStridedLayout(offset, strides))) 1125 return failure(); 1126 // Construct strided affine map. 1127 auto map = makeStridedLinearLayoutMap(strides, offset, 1128 elementType.getContext()); 1129 affineMapComposition.push_back(map); 1130 } else { 1131 // Parse affine map. 1132 auto affineMap = parseAttribute(); 1133 if (!affineMap) 1134 return failure(); 1135 // Verify that the parsed attribute is an affine map. 1136 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 1137 affineMapComposition.push_back(affineMapAttr.getValue()); 1138 else 1139 return emitError("expected affine map in memref type"); 1140 } 1141 } 1142 return success(); 1143 }; 1144 1145 // Parse a list of mappings and address space if present. 1146 if (consumeIf(Token::comma)) { 1147 // Parse comma separated list of affine maps, followed by memory space. 
1148 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 1149 /*allowEmptyList=*/false)) { 1150 return nullptr; 1151 } 1152 } else { 1153 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 1154 return nullptr; 1155 } 1156 1157 if (isUnranked) 1158 return UnrankedMemRefType::getChecked(elementType, memorySpace, 1159 getEncodedSourceLocation(typeLoc)); 1160 1161 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 1162 memorySpace, getEncodedSourceLocation(typeLoc)); 1163 } 1164 1165 /// Parse any type except the function type. 1166 /// 1167 /// non-function-type ::= integer-type 1168 /// | index-type 1169 /// | float-type 1170 /// | extended-type 1171 /// | vector-type 1172 /// | tensor-type 1173 /// | memref-type 1174 /// | complex-type 1175 /// | tuple-type 1176 /// | none-type 1177 /// 1178 /// index-type ::= `index` 1179 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1180 /// none-type ::= `none` 1181 /// 1182 Type Parser::parseNonFunctionType() { 1183 switch (getToken().getKind()) { 1184 default: 1185 return (emitError("expected non-function type"), nullptr); 1186 case Token::kw_memref: 1187 return parseMemRefType(); 1188 case Token::kw_tensor: 1189 return parseTensorType(); 1190 case Token::kw_complex: 1191 return parseComplexType(); 1192 case Token::kw_tuple: 1193 return parseTupleType(); 1194 case Token::kw_vector: 1195 return parseVectorType(); 1196 // integer-type 1197 case Token::inttype: { 1198 auto width = getToken().getIntTypeBitwidth(); 1199 if (!width.hasValue()) 1200 return (emitError("invalid integer width"), nullptr); 1201 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1202 consumeToken(Token::inttype); 1203 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 1204 } 1205 1206 // float-type 1207 case Token::kw_bf16: 1208 consumeToken(Token::kw_bf16); 1209 return builder.getBF16Type(); 1210 case Token::kw_f16: 1211 consumeToken(Token::kw_f16); 1212 return 
builder.getF16Type(); 1213 case Token::kw_f32: 1214 consumeToken(Token::kw_f32); 1215 return builder.getF32Type(); 1216 case Token::kw_f64: 1217 consumeToken(Token::kw_f64); 1218 return builder.getF64Type(); 1219 1220 // index-type 1221 case Token::kw_index: 1222 consumeToken(Token::kw_index); 1223 return builder.getIndexType(); 1224 1225 // none-type 1226 case Token::kw_none: 1227 consumeToken(Token::kw_none); 1228 return builder.getNoneType(); 1229 1230 // extended type 1231 case Token::exclamation_identifier: 1232 return parseExtendedType(); 1233 } 1234 } 1235 1236 /// Parse a tensor type. 1237 /// 1238 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1239 /// dimension-list ::= dimension-list-ranked | `*x` 1240 /// 1241 Type Parser::parseTensorType() { 1242 consumeToken(Token::kw_tensor); 1243 1244 if (parseToken(Token::less, "expected '<' in tensor type")) 1245 return nullptr; 1246 1247 bool isUnranked; 1248 SmallVector<int64_t, 4> dimensions; 1249 1250 if (consumeIf(Token::star)) { 1251 // This is an unranked tensor type. 1252 isUnranked = true; 1253 1254 if (parseXInDimensionList()) 1255 return nullptr; 1256 1257 } else { 1258 isUnranked = false; 1259 if (parseDimensionListRanked(dimensions)) 1260 return nullptr; 1261 } 1262 1263 // Parse the element type. 1264 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 1265 auto elementType = parseType(); 1266 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1267 return nullptr; 1268 1269 if (isUnranked) 1270 return UnrankedTensorType::getChecked(elementType, typeLocation); 1271 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 1272 } 1273 1274 /// Parse a tuple type. 1275 /// 1276 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1277 /// 1278 Type Parser::parseTupleType() { 1279 consumeToken(Token::kw_tuple); 1280 1281 // Parse the '<'. 
1282 if (parseToken(Token::less, "expected '<' in tuple type")) 1283 return nullptr; 1284 1285 // Check for an empty tuple by directly parsing '>'. 1286 if (consumeIf(Token::greater)) 1287 return TupleType::get(getContext()); 1288 1289 // Parse the element types and the '>'. 1290 SmallVector<Type, 4> types; 1291 if (parseTypeListNoParens(types) || 1292 parseToken(Token::greater, "expected '>' in tuple type")) 1293 return nullptr; 1294 1295 return TupleType::get(types, getContext()); 1296 } 1297 1298 /// Parse a vector type. 1299 /// 1300 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1301 /// non-empty-static-dimension-list ::= decimal-literal `x` 1302 /// static-dimension-list 1303 /// static-dimension-list ::= (decimal-literal `x`)* 1304 /// 1305 VectorType Parser::parseVectorType() { 1306 consumeToken(Token::kw_vector); 1307 1308 if (parseToken(Token::less, "expected '<' in vector type")) 1309 return nullptr; 1310 1311 SmallVector<int64_t, 4> dimensions; 1312 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1313 return nullptr; 1314 if (dimensions.empty()) 1315 return (emitError("expected dimension size in vector type"), nullptr); 1316 1317 // Parse the element type. 1318 auto typeLoc = getToken().getLoc(); 1319 auto elementType = parseType(); 1320 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1321 return nullptr; 1322 1323 return VectorType::getChecked(dimensions, elementType, 1324 getEncodedSourceLocation(typeLoc)); 1325 } 1326 1327 /// Parse a dimension list of a tensor or memref type. This populates the 1328 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1329 /// errors out on `?` otherwise. 
1330 /// 1331 /// dimension-list-ranked ::= (dimension `x`)* 1332 /// dimension ::= `?` | decimal-literal 1333 /// 1334 /// When `allowDynamic` is not set, this is used to parse: 1335 /// 1336 /// static-dimension-list ::= (decimal-literal `x`)* 1337 ParseResult 1338 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1339 bool allowDynamic) { 1340 while (getToken().isAny(Token::integer, Token::question)) { 1341 if (consumeIf(Token::question)) { 1342 if (!allowDynamic) 1343 return emitError("expected static shape"); 1344 dimensions.push_back(-1); 1345 } else { 1346 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1347 // aggregate type declarations. Therefore, `0xf32` should be processed as 1348 // a sequence of separate elements `0`, `x`, `f32`. 1349 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1350 // We can get here only if the token is an integer literal. Hexadecimal 1351 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1352 // literal, just `1` would, at which point we don't get into this 1353 // branch). 1354 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1355 dimensions.push_back(0); 1356 state.lex.resetPointer(getTokenSpelling().data() + 1); 1357 consumeToken(); 1358 } else { 1359 // Make sure this integer value is in bound and valid. 1360 auto dimension = getToken().getUnsignedIntegerValue(); 1361 if (!dimension.hasValue()) 1362 return emitError("invalid dimension"); 1363 dimensions.push_back((int64_t)dimension.getValue()); 1364 consumeToken(Token::integer); 1365 } 1366 } 1367 1368 // Make sure we have an 'x' or something like 'xbf32'. 1369 if (parseXInDimensionList()) 1370 return failure(); 1371 } 1372 1373 return success(); 1374 } 1375 1376 /// Parse an 'x' token in a dimension list, handling the case where the x is 1377 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1378 /// token. 
1379 ParseResult Parser::parseXInDimensionList() { 1380 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1381 return emitError("expected 'x' in dimension list"); 1382 1383 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1384 if (getTokenSpelling().size() != 1) 1385 state.lex.resetPointer(getTokenSpelling().data() + 1); 1386 1387 // Consume the 'x'. 1388 consumeToken(Token::bare_identifier); 1389 1390 return success(); 1391 } 1392 1393 // Parse a comma-separated list of dimensions, possibly empty: 1394 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1395 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1396 if (!consumeIf(Token::l_square)) 1397 return failure(); 1398 // Empty list early exit. 1399 if (consumeIf(Token::r_square)) 1400 return success(); 1401 while (true) { 1402 if (consumeIf(Token::question)) { 1403 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1404 } else { 1405 // This must be an integer value. 1406 int64_t val; 1407 if (getToken().getSpelling().getAsInteger(10, val)) 1408 return emitError("invalid integer value: ") << getToken().getSpelling(); 1409 // Make sure it is not the one value for `?`. 1410 if (ShapedType::isDynamic(val)) 1411 return emitError("invalid integer value: ") 1412 << getToken().getSpelling() 1413 << ", use `?` to specify a dynamic dimension"; 1414 dimensions.push_back(val); 1415 consumeToken(Token::integer); 1416 } 1417 if (!consumeIf(Token::comma)) 1418 break; 1419 } 1420 if (!consumeIf(Token::r_square)) 1421 return failure(); 1422 return success(); 1423 } 1424 1425 //===----------------------------------------------------------------------===// 1426 // Attribute parsing. 1427 //===----------------------------------------------------------------------===// 1428 1429 /// Return the symbol reference referred to by the given token, that is known to 1430 /// be an @-identifier. 
1431 static std::string extractSymbolReference(Token tok) { 1432 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1433 StringRef nameStr = tok.getSpelling().drop_front(); 1434 1435 // Check to see if the reference is a string literal, or a bare identifier. 1436 if (nameStr.front() == '"') 1437 return tok.getStringValue(); 1438 return std::string(nameStr); 1439 } 1440 1441 /// Parse an arbitrary attribute. 1442 /// 1443 /// attribute-value ::= `unit` 1444 /// | bool-literal 1445 /// | integer-literal (`:` (index-type | integer-type))? 1446 /// | float-literal (`:` float-type)? 1447 /// | string-literal (`:` type)? 1448 /// | type 1449 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1450 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1451 /// | symbol-ref-id (`::` symbol-ref-id)* 1452 /// | `dense` `<` attribute-value `>` `:` 1453 /// (tensor-type | vector-type) 1454 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1455 /// `:` (tensor-type | vector-type) 1456 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1457 /// `>` `:` (tensor-type | vector-type) 1458 /// | extended-attribute 1459 /// 1460 Attribute Parser::parseAttribute(Type type) { 1461 switch (getToken().getKind()) { 1462 // Parse an AffineMap or IntegerSet attribute. 
1463 case Token::kw_affine_map: { 1464 consumeToken(Token::kw_affine_map); 1465 1466 AffineMap map; 1467 if (parseToken(Token::less, "expected '<' in affine map") || 1468 parseAffineMapReference(map) || 1469 parseToken(Token::greater, "expected '>' in affine map")) 1470 return Attribute(); 1471 return AffineMapAttr::get(map); 1472 } 1473 case Token::kw_affine_set: { 1474 consumeToken(Token::kw_affine_set); 1475 1476 IntegerSet set; 1477 if (parseToken(Token::less, "expected '<' in integer set") || 1478 parseIntegerSetReference(set) || 1479 parseToken(Token::greater, "expected '>' in integer set")) 1480 return Attribute(); 1481 return IntegerSetAttr::get(set); 1482 } 1483 1484 // Parse an array attribute. 1485 case Token::l_square: { 1486 consumeToken(Token::l_square); 1487 1488 SmallVector<Attribute, 4> elements; 1489 auto parseElt = [&]() -> ParseResult { 1490 elements.push_back(parseAttribute()); 1491 return elements.back() ? success() : failure(); 1492 }; 1493 1494 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1495 return nullptr; 1496 return builder.getArrayAttr(elements); 1497 } 1498 1499 // Parse a boolean attribute. 1500 case Token::kw_false: 1501 consumeToken(Token::kw_false); 1502 return builder.getBoolAttr(false); 1503 case Token::kw_true: 1504 consumeToken(Token::kw_true); 1505 return builder.getBoolAttr(true); 1506 1507 // Parse a dense elements attribute. 1508 case Token::kw_dense: 1509 return parseDenseElementsAttr(type); 1510 1511 // Parse a dictionary attribute. 1512 case Token::l_brace: { 1513 SmallVector<NamedAttribute, 4> elements; 1514 if (parseAttributeDict(elements)) 1515 return nullptr; 1516 return builder.getDictionaryAttr(elements); 1517 } 1518 1519 // Parse an extended attribute, i.e. alias or dialect attribute. 1520 case Token::hash_identifier: 1521 return parseExtendedAttr(type); 1522 1523 // Parse floating point and integer attributes. 
1524 case Token::floatliteral: 1525 return parseFloatAttr(type, /*isNegative=*/false); 1526 case Token::integer: 1527 return parseDecOrHexAttr(type, /*isNegative=*/false); 1528 case Token::minus: { 1529 consumeToken(Token::minus); 1530 if (getToken().is(Token::integer)) 1531 return parseDecOrHexAttr(type, /*isNegative=*/true); 1532 if (getToken().is(Token::floatliteral)) 1533 return parseFloatAttr(type, /*isNegative=*/true); 1534 1535 return (emitError("expected constant integer or floating point value"), 1536 nullptr); 1537 } 1538 1539 // Parse a location attribute. 1540 case Token::kw_loc: { 1541 LocationAttr attr; 1542 return failed(parseLocation(attr)) ? Attribute() : attr; 1543 } 1544 1545 // Parse an opaque elements attribute. 1546 case Token::kw_opaque: 1547 return parseOpaqueElementsAttr(type); 1548 1549 // Parse a sparse elements attribute. 1550 case Token::kw_sparse: 1551 return parseSparseElementsAttr(type); 1552 1553 // Parse a string attribute. 1554 case Token::string: { 1555 auto val = getToken().getStringValue(); 1556 consumeToken(Token::string); 1557 // Parse the optional trailing colon type if one wasn't explicitly provided. 1558 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1559 return Attribute(); 1560 1561 return type ? StringAttr::get(val, type) 1562 : StringAttr::get(val, getContext()); 1563 } 1564 1565 // Parse a symbol reference attribute. 1566 case Token::at_identifier: { 1567 std::string nameStr = extractSymbolReference(getToken()); 1568 consumeToken(Token::at_identifier); 1569 1570 // Parse any nested references. 1571 std::vector<FlatSymbolRefAttr> nestedRefs; 1572 while (getToken().is(Token::colon)) { 1573 // Check for the '::' prefix. 1574 const char *curPointer = getToken().getLoc().getPointer(); 1575 consumeToken(Token::colon); 1576 if (!consumeIf(Token::colon)) { 1577 state.lex.resetPointer(curPointer); 1578 consumeToken(); 1579 break; 1580 } 1581 // Parse the reference itself. 
1582 auto curLoc = getToken().getLoc(); 1583 if (getToken().isNot(Token::at_identifier)) { 1584 emitError(curLoc, "expected nested symbol reference identifier"); 1585 return Attribute(); 1586 } 1587 1588 std::string nameStr = extractSymbolReference(getToken()); 1589 consumeToken(Token::at_identifier); 1590 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1591 } 1592 1593 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1594 } 1595 1596 // Parse a 'unit' attribute. 1597 case Token::kw_unit: 1598 consumeToken(Token::kw_unit); 1599 return builder.getUnitAttr(); 1600 1601 default: 1602 // Parse a type attribute. 1603 if (Type type = parseType()) 1604 return TypeAttr::get(type); 1605 return nullptr; 1606 } 1607 } 1608 1609 /// Attribute dictionary. 1610 /// 1611 /// attribute-dict ::= `{` `}` 1612 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1613 /// attribute-entry ::= bare-id `=` attribute-value 1614 /// 1615 ParseResult 1616 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1617 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1618 return failure(); 1619 1620 auto parseElt = [&]() -> ParseResult { 1621 // We allow keywords as attribute names. 1622 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1623 !getToken().isKeyword()) 1624 return emitError("expected attribute name"); 1625 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1626 consumeToken(); 1627 1628 // Try to parse the '=' for the attribute value. 1629 if (!consumeIf(Token::equal)) { 1630 // If there is no '=', we treat this as a unit attribute. 
1631 attributes.push_back({nameId, builder.getUnitAttr()}); 1632 return success(); 1633 } 1634 1635 auto attr = parseAttribute(); 1636 if (!attr) 1637 return failure(); 1638 1639 attributes.push_back({nameId, attr}); 1640 return success(); 1641 }; 1642 1643 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1644 return failure(); 1645 1646 return success(); 1647 } 1648 1649 /// Parse an extended attribute. 1650 /// 1651 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1652 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1653 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1654 /// attribute-alias ::= `#` alias-name 1655 /// 1656 Attribute Parser::parseExtendedAttr(Type type) { 1657 Attribute attr = parseExtendedSymbol<Attribute>( 1658 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1659 [&](StringRef dialectName, StringRef symbolData, 1660 llvm::SMLoc loc) -> Attribute { 1661 // Parse an optional trailing colon type. 1662 Type attrType = type; 1663 if (consumeIf(Token::colon) && !(attrType = parseType())) 1664 return Attribute(); 1665 1666 // If we found a registered dialect, then ask it to parse the attribute. 1667 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1668 return parseSymbol<Attribute>( 1669 symbolData, state.context, state.symbols, [&](Parser &parser) { 1670 CustomDialectAsmParser customParser(symbolData, parser); 1671 return dialect->parseAttribute(customParser, attrType); 1672 }); 1673 } 1674 1675 // Otherwise, form a new opaque attribute. 1676 return OpaqueAttr::getChecked( 1677 Identifier::get(dialectName, state.context), symbolData, 1678 attrType ? attrType : NoneType::get(state.context), 1679 getEncodedSourceLocation(loc)); 1680 }); 1681 1682 // Ensure that the attribute has the same type as requested. 
1683 if (attr && type && attr.getType() != type) { 1684 emitError("attribute type different than expected: expected ") 1685 << type << ", but got " << attr.getType(); 1686 return nullptr; 1687 } 1688 return attr; 1689 } 1690 1691 /// Parse a float attribute. 1692 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1693 auto val = getToken().getFloatingPointValue(); 1694 if (!val.hasValue()) 1695 return (emitError("floating point value too large for attribute"), nullptr); 1696 consumeToken(Token::floatliteral); 1697 if (!type) { 1698 // Default to F64 when no type is specified. 1699 if (!consumeIf(Token::colon)) 1700 type = builder.getF64Type(); 1701 else if (!(type = parseType())) 1702 return nullptr; 1703 } 1704 if (!type.isa<FloatType>()) 1705 return (emitError("floating point value not valid for specified type"), 1706 nullptr); 1707 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1708 } 1709 1710 /// Construct a float attribute bitwise equivalent to the integer literal. 1711 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1712 uint64_t value) { 1713 // FIXME: bfloat is currently stored as a double internally because it doesn't 1714 // have valid APFloat semantics. 1715 if (type.isF64() || type.isBF16()) { 1716 APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1717 return p->builder.getFloatAttr(type, apFloat); 1718 } 1719 1720 APInt apInt(type.getWidth(), value); 1721 if (apInt != value) { 1722 p->emitError("hexadecimal float constant out of range for type"); 1723 return nullptr; 1724 } 1725 APFloat apFloat(type.getFloatSemantics(), apInt); 1726 return p->builder.getFloatAttr(type, apFloat); 1727 } 1728 1729 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1730 /// or a float attribute. 
1731 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1732 auto val = getToken().getUInt64IntegerValue(); 1733 if (!val.hasValue()) 1734 return (emitError("integer constant out of range for attribute"), nullptr); 1735 1736 // Remember if the literal is hexadecimal. 1737 StringRef spelling = getToken().getSpelling(); 1738 auto loc = state.curToken.getLoc(); 1739 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1740 1741 consumeToken(Token::integer); 1742 if (!type) { 1743 // Default to i64 if not type is specified. 1744 if (!consumeIf(Token::colon)) 1745 type = builder.getIntegerType(64); 1746 else if (!(type = parseType())) 1747 return nullptr; 1748 } 1749 1750 if (auto floatType = type.dyn_cast<FloatType>()) { 1751 if (isNegative) 1752 return emitError( 1753 loc, 1754 "hexadecimal float literal should not have a leading minus"), 1755 nullptr; 1756 if (!isHex) { 1757 emitError(loc, "unexpected decimal integer literal for a float attribute") 1758 .attachNote() 1759 << "add a trailing dot to make the literal a float"; 1760 return nullptr; 1761 } 1762 1763 // Construct a float attribute bitwise equivalent to the integer literal. 1764 return buildHexadecimalFloatLiteral(this, floatType, *val); 1765 } 1766 1767 if (!type.isIntOrIndex()) 1768 return emitError(loc, "integer literal not valid for specified type"), 1769 nullptr; 1770 1771 // Parse the integer literal. 1772 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1773 APInt apInt(width, *val, isNegative); 1774 if (apInt != *val) 1775 return emitError(loc, "integer constant out of range for attribute"), 1776 nullptr; 1777 1778 // Otherwise construct an integer attribute. 1779 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1780 return emitError(loc, "integer constant out of range for attribute"), 1781 nullptr; 1782 1783 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1784 } 1785 1786 /// Parse an opaque elements attribute. 
1787 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1788 consumeToken(Token::kw_opaque); 1789 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1790 return nullptr; 1791 1792 if (getToken().isNot(Token::string)) 1793 return (emitError("expected dialect namespace"), nullptr); 1794 1795 auto name = getToken().getStringValue(); 1796 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1797 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1798 // attribute. Otherwise, it can't be roundtripped without having the dialect 1799 // registered. 1800 if (!dialect) 1801 return (emitError("no registered dialect with namespace '" + name + "'"), 1802 nullptr); 1803 1804 consumeToken(Token::string); 1805 if (parseToken(Token::comma, "expected ','")) 1806 return nullptr; 1807 1808 if (getToken().getKind() != Token::string) 1809 return (emitError("opaque string should start with '0x'"), nullptr); 1810 1811 auto val = getToken().getStringValue(); 1812 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1813 return (emitError("opaque string should start with '0x'"), nullptr); 1814 1815 val = val.substr(2); 1816 if (!llvm::all_of(val, llvm::isHexDigit)) 1817 return (emitError("opaque string only contains hex digits"), nullptr); 1818 1819 consumeToken(Token::string); 1820 if (parseToken(Token::greater, "expected '>'")) 1821 return nullptr; 1822 1823 auto type = parseElementsLiteralType(attrType); 1824 if (!type) 1825 return nullptr; 1826 1827 return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); 1828 } 1829 1830 namespace { 1831 class TensorLiteralParser { 1832 public: 1833 TensorLiteralParser(Parser &p) : p(p) {} 1834 1835 ParseResult parse() { 1836 if (p.getToken().is(Token::l_square)) 1837 return parseList(shape); 1838 return parseElement(); 1839 } 1840 1841 /// Build a dense attribute instance with the parsed elements and the given 1842 /// shaped type. 
1843 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1844 1845 ArrayRef<int64_t> getShape() const { return shape; } 1846 1847 private: 1848 enum class ElementKind { Boolean, Integer, Float }; 1849 1850 /// Return a string to represent the given element kind. 1851 const char *getElementKindStr(ElementKind kind) { 1852 switch (kind) { 1853 case ElementKind::Boolean: 1854 return "'boolean'"; 1855 case ElementKind::Integer: 1856 return "'integer'"; 1857 case ElementKind::Float: 1858 return "'float'"; 1859 } 1860 llvm_unreachable("unknown element kind"); 1861 } 1862 1863 /// Build a Dense Integer attribute for the given type. 1864 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1865 IntegerType eltTy); 1866 1867 /// Build a Dense Float attribute for the given type. 1868 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1869 FloatType eltTy); 1870 1871 /// Parse a single element, returning failure if it isn't a valid element 1872 /// literal. For example: 1873 /// parseElement(1) -> Success, 1 1874 /// parseElement([1]) -> Failure 1875 ParseResult parseElement(); 1876 1877 /// Parse a list of either lists or elements, returning the dimensions of the 1878 /// parsed sub-tensors in dims. For example: 1879 /// parseList([1, 2, 3]) -> Success, [3] 1880 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1881 /// parseList([[1, 2], 3]) -> Failure 1882 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1883 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1884 1885 Parser &p; 1886 1887 /// The shape inferred from the parsed elements. 1888 SmallVector<int64_t, 4> shape; 1889 1890 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1891 std::vector<std::pair<bool, Token>> storage; 1892 1893 /// A flag that indicates the type of elements that have been parsed. 
1894 Optional<ElementKind> knownEltKind; 1895 }; 1896 } // namespace 1897 1898 /// Build a dense attribute instance with the parsed elements and the given 1899 /// shaped type. 1900 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1901 ShapedType type) { 1902 // Check that the parsed storage size has the same number of elements to the 1903 // type, or is a known splat. 1904 if (!shape.empty() && getShape() != type.getShape()) { 1905 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1906 << "]) does not match type ([" << type.getShape() << "])"; 1907 return nullptr; 1908 } 1909 1910 // If the type is an integer, build a set of APInt values from the storage 1911 // with the correct bitwidth. 1912 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1913 return getIntAttr(loc, type, intTy); 1914 1915 // Otherwise, this must be a floating point type. 1916 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1917 if (!floatTy) { 1918 p.emitError(loc) << "expected floating-point or integer element type, got " 1919 << type.getElementType(); 1920 return nullptr; 1921 } 1922 return getFloatAttr(loc, type, floatTy); 1923 } 1924 1925 /// Build a Dense Integer attribute for the given type. 1926 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1927 ShapedType type, 1928 IntegerType eltTy) { 1929 std::vector<APInt> intElements; 1930 intElements.reserve(storage.size()); 1931 for (const auto &signAndToken : storage) { 1932 bool isNegative = signAndToken.first; 1933 const Token &token = signAndToken.second; 1934 1935 // Check to see if floating point values were parsed. 
1936 if (token.is(Token::floatliteral)) { 1937 p.emitError() << "expected integer elements, but parsed floating-point"; 1938 return nullptr; 1939 } 1940 1941 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1942 "unexpected token type"); 1943 if (token.isAny(Token::kw_true, Token::kw_false)) { 1944 if (!eltTy.isInteger(1)) 1945 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1946 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1947 /*isSigned=*/false); 1948 intElements.push_back(apInt); 1949 continue; 1950 } 1951 1952 // Create APInt values for each element with the correct bitwidth. 1953 auto val = token.getUInt64IntegerValue(); 1954 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1955 : (int64_t)val.getValue() < 0)) { 1956 p.emitError(token.getLoc(), 1957 "integer constant out of range for attribute"); 1958 return nullptr; 1959 } 1960 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1961 if (apInt != val.getValue()) 1962 return (p.emitError("integer constant out of range for type"), nullptr); 1963 intElements.push_back(isNegative ? -apInt : apInt); 1964 } 1965 1966 return DenseElementsAttr::get(type, intElements); 1967 } 1968 1969 /// Build a Dense Float attribute for the given type. 1970 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1971 ShapedType type, 1972 FloatType eltTy) { 1973 std::vector<Attribute> floatValues; 1974 floatValues.reserve(storage.size()); 1975 for (const auto &signAndToken : storage) { 1976 bool isNegative = signAndToken.first; 1977 const Token &token = signAndToken.second; 1978 1979 // Handle hexadecimal float literals. 
1980 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1981 if (isNegative) { 1982 p.emitError(token.getLoc()) 1983 << "hexadecimal float literal should not have a leading minus"; 1984 return nullptr; 1985 } 1986 auto val = token.getUInt64IntegerValue(); 1987 if (!val.hasValue()) { 1988 p.emitError("hexadecimal float constant out of range for attribute"); 1989 return nullptr; 1990 } 1991 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1992 if (!attr) 1993 return nullptr; 1994 floatValues.push_back(attr); 1995 continue; 1996 } 1997 1998 // Check to see if any decimal integers or booleans were parsed. 1999 if (!token.is(Token::floatliteral)) { 2000 p.emitError() << "expected floating-point elements, but parsed integer"; 2001 return nullptr; 2002 } 2003 2004 // Build the float values from tokens. 2005 auto val = token.getFloatingPointValue(); 2006 if (!val.hasValue()) { 2007 p.emitError("floating point value too large for attribute"); 2008 return nullptr; 2009 } 2010 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 2011 } 2012 2013 return DenseElementsAttr::get(type, floatValues); 2014 } 2015 2016 ParseResult TensorLiteralParser::parseElement() { 2017 switch (p.getToken().getKind()) { 2018 // Parse a boolean element. 2019 case Token::kw_true: 2020 case Token::kw_false: 2021 case Token::floatliteral: 2022 case Token::integer: 2023 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2024 p.consumeToken(); 2025 break; 2026 2027 // Parse a signed integer or a negative floating-point element. 
2028 case Token::minus: 2029 p.consumeToken(Token::minus); 2030 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2031 return p.emitError("expected integer or floating point literal"); 2032 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2033 p.consumeToken(); 2034 break; 2035 2036 default: 2037 return p.emitError("expected element literal of primitive type"); 2038 } 2039 2040 return success(); 2041 } 2042 2043 /// Parse a list of either lists or elements, returning the dimensions of the 2044 /// parsed sub-tensors in dims. For example: 2045 /// parseList([1, 2, 3]) -> Success, [3] 2046 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2047 /// parseList([[1, 2], 3]) -> Failure 2048 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2049 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2050 p.consumeToken(Token::l_square); 2051 2052 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2053 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2054 if (prevDims == newDims) 2055 return success(); 2056 return p.emitError("tensor literal is invalid; ranks are not consistent " 2057 "between elements"); 2058 }; 2059 2060 bool first = true; 2061 SmallVector<int64_t, 4> newDims; 2062 unsigned size = 0; 2063 auto parseCommaSeparatedList = [&]() -> ParseResult { 2064 SmallVector<int64_t, 4> thisDims; 2065 if (p.getToken().getKind() == Token::l_square) { 2066 if (parseList(thisDims)) 2067 return failure(); 2068 } else if (parseElement()) { 2069 return failure(); 2070 } 2071 ++size; 2072 if (!first) 2073 return checkDims(newDims, thisDims); 2074 newDims = thisDims; 2075 first = false; 2076 return success(); 2077 }; 2078 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2079 return failure(); 2080 2081 // Return the sublists' dimensions with 'size' prepended. 
2082 dims.clear(); 2083 dims.push_back(size); 2084 dims.append(newDims.begin(), newDims.end()); 2085 return success(); 2086 } 2087 2088 /// Parse a dense elements attribute. 2089 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2090 consumeToken(Token::kw_dense); 2091 if (parseToken(Token::less, "expected '<' after 'dense'")) 2092 return nullptr; 2093 2094 // Parse the literal data. 2095 TensorLiteralParser literalParser(*this); 2096 if (literalParser.parse()) 2097 return nullptr; 2098 2099 if (parseToken(Token::greater, "expected '>'")) 2100 return nullptr; 2101 2102 auto typeLoc = getToken().getLoc(); 2103 auto type = parseElementsLiteralType(attrType); 2104 if (!type) 2105 return nullptr; 2106 return literalParser.getAttr(typeLoc, type); 2107 } 2108 2109 /// Shaped type for elements attribute. 2110 /// 2111 /// elements-literal-type ::= vector-type | ranked-tensor-type 2112 /// 2113 /// This method also checks the type has static shape. 2114 ShapedType Parser::parseElementsLiteralType(Type type) { 2115 // If the user didn't provide a type, parse the colon type for the literal. 2116 if (!type) { 2117 if (parseToken(Token::colon, "expected ':'")) 2118 return nullptr; 2119 if (!(type = parseType())) 2120 return nullptr; 2121 } 2122 2123 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2124 emitError("elements literal must be a ranked tensor or vector type"); 2125 return nullptr; 2126 } 2127 2128 auto sType = type.cast<ShapedType>(); 2129 if (!sType.hasStaticShape()) 2130 return (emitError("elements literal type must have static shape"), nullptr); 2131 2132 return sType; 2133 } 2134 2135 /// Parse a sparse elements attribute. 
2136 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2137 consumeToken(Token::kw_sparse); 2138 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2139 return nullptr; 2140 2141 /// Parse indices 2142 auto indicesLoc = getToken().getLoc(); 2143 TensorLiteralParser indiceParser(*this); 2144 if (indiceParser.parse()) 2145 return nullptr; 2146 2147 if (parseToken(Token::comma, "expected ','")) 2148 return nullptr; 2149 2150 /// Parse values. 2151 auto valuesLoc = getToken().getLoc(); 2152 TensorLiteralParser valuesParser(*this); 2153 if (valuesParser.parse()) 2154 return nullptr; 2155 2156 if (parseToken(Token::greater, "expected '>'")) 2157 return nullptr; 2158 2159 auto type = parseElementsLiteralType(attrType); 2160 if (!type) 2161 return nullptr; 2162 2163 // If the indices are a splat, i.e. the literal parser parsed an element and 2164 // not a list, we set the shape explicitly. The indices are represented by a 2165 // 2-dimensional shape where the second dimension is the rank of the type. 2166 // Given that the parsed indices is a splat, we know that we only have one 2167 // indice and thus one for the first dimension. 2168 auto indiceEltType = builder.getIntegerType(64); 2169 ShapedType indicesType; 2170 if (indiceParser.getShape().empty()) { 2171 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2172 } else { 2173 // Otherwise, set the shape to the one parsed by the literal parser. 2174 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2175 } 2176 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2177 2178 // If the values are a splat, set the shape explicitly based on the number of 2179 // indices. The number of indices is encoded in the first dimension of the 2180 // indice shape type. 2181 auto valuesEltType = type.getElementType(); 2182 ShapedType valuesType = 2183 valuesParser.getShape().empty() 2184 ? 
RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2185 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2186 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2187 2188 /// Sanity check. 2189 if (valuesType.getRank() != 1) 2190 return (emitError("expected 1-d tensor for values"), nullptr); 2191 2192 auto sameShape = (indicesType.getRank() == 1) || 2193 (type.getRank() == indicesType.getDimSize(1)); 2194 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2195 if (!sameShape || !sameElementNum) { 2196 emitError() << "expected shape ([" << type.getShape() 2197 << "]); inferred shape of indices literal ([" 2198 << indicesType.getShape() 2199 << "]); inferred shape of values literal ([" 2200 << valuesType.getShape() << "])"; 2201 return nullptr; 2202 } 2203 2204 // Build the sparse elements attribute by the indices and values. 2205 return SparseElementsAttr::get(type, indices, values); 2206 } 2207 2208 //===----------------------------------------------------------------------===// 2209 // Location parsing. 2210 //===----------------------------------------------------------------------===// 2211 2212 /// Parse a location. 2213 /// 2214 /// location ::= `loc` inline-location 2215 /// inline-location ::= '(' location-inst ')' 2216 /// 2217 ParseResult Parser::parseLocation(LocationAttr &loc) { 2218 // Check for 'loc' identifier. 2219 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2220 return emitError(); 2221 2222 // Parse the inline-location. 2223 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2224 parseLocationInstance(loc) || 2225 parseToken(Token::r_paren, "expected ')' in inline location")) 2226 return failure(); 2227 return success(); 2228 } 2229 2230 /// Specific location instances. 
2231 /// 2232 /// location-inst ::= filelinecol-location | 2233 /// name-location | 2234 /// callsite-location | 2235 /// fused-location | 2236 /// unknown-location 2237 /// filelinecol-location ::= string-literal ':' integer-literal 2238 /// ':' integer-literal 2239 /// name-location ::= string-literal 2240 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2241 /// fused-location ::= fused ('<' attribute-value '>')? 2242 /// '[' location-inst (location-inst ',')* ']' 2243 /// unknown-location ::= 'unknown' 2244 /// 2245 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2246 consumeToken(Token::bare_identifier); 2247 2248 // Parse the '('. 2249 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2250 return failure(); 2251 2252 // Parse the callee location. 2253 LocationAttr calleeLoc; 2254 if (parseLocationInstance(calleeLoc)) 2255 return failure(); 2256 2257 // Parse the 'at'. 2258 if (getToken().isNot(Token::bare_identifier) || 2259 getToken().getSpelling() != "at") 2260 return emitError("expected 'at' in callsite location"); 2261 consumeToken(Token::bare_identifier); 2262 2263 // Parse the caller location. 2264 LocationAttr callerLoc; 2265 if (parseLocationInstance(callerLoc)) 2266 return failure(); 2267 2268 // Parse the ')'. 2269 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2270 return failure(); 2271 2272 // Return the callsite location. 2273 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2274 return success(); 2275 } 2276 2277 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2278 consumeToken(Token::bare_identifier); 2279 2280 // Try to parse the optional metadata. 2281 Attribute metadata; 2282 if (consumeIf(Token::less)) { 2283 metadata = parseAttribute(); 2284 if (!metadata) 2285 return emitError("expected valid attribute metadata"); 2286 // Parse the '>' token. 
2287 if (parseToken(Token::greater, 2288 "expected '>' after fused location metadata")) 2289 return failure(); 2290 } 2291 2292 SmallVector<Location, 4> locations; 2293 auto parseElt = [&] { 2294 LocationAttr newLoc; 2295 if (parseLocationInstance(newLoc)) 2296 return failure(); 2297 locations.push_back(newLoc); 2298 return success(); 2299 }; 2300 2301 if (parseToken(Token::l_square, "expected '[' in fused location") || 2302 parseCommaSeparatedList(parseElt) || 2303 parseToken(Token::r_square, "expected ']' in fused location")) 2304 return failure(); 2305 2306 // Return the fused location. 2307 loc = FusedLoc::get(locations, metadata, getContext()); 2308 return success(); 2309 } 2310 2311 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2312 auto *ctx = getContext(); 2313 auto str = getToken().getStringValue(); 2314 consumeToken(Token::string); 2315 2316 // If the next token is ':' this is a filelinecol location. 2317 if (consumeIf(Token::colon)) { 2318 // Parse the line number. 2319 if (getToken().isNot(Token::integer)) 2320 return emitError("expected integer line number in FileLineColLoc"); 2321 auto line = getToken().getUnsignedIntegerValue(); 2322 if (!line.hasValue()) 2323 return emitError("expected integer line number in FileLineColLoc"); 2324 consumeToken(Token::integer); 2325 2326 // Parse the ':'. 2327 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2328 return failure(); 2329 2330 // Parse the column number. 2331 if (getToken().isNot(Token::integer)) 2332 return emitError("expected integer column number in FileLineColLoc"); 2333 auto column = getToken().getUnsignedIntegerValue(); 2334 if (!column.hasValue()) 2335 return emitError("expected integer column number in FileLineColLoc"); 2336 consumeToken(Token::integer); 2337 2338 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2339 return success(); 2340 } 2341 2342 // Otherwise, this is a NameLoc. 2343 2344 // Check for a child location. 
2345 if (consumeIf(Token::l_paren)) { 2346 auto childSourceLoc = getToken().getLoc(); 2347 2348 // Parse the child location. 2349 LocationAttr childLoc; 2350 if (parseLocationInstance(childLoc)) 2351 return failure(); 2352 2353 // The child must not be another NameLoc. 2354 if (childLoc.isa<NameLoc>()) 2355 return emitError(childSourceLoc, 2356 "child of NameLoc cannot be another NameLoc"); 2357 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2358 2359 // Parse the closing ')'. 2360 if (parseToken(Token::r_paren, 2361 "expected ')' after child location of NameLoc")) 2362 return failure(); 2363 } else { 2364 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2365 } 2366 2367 return success(); 2368 } 2369 2370 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2371 // Handle either name or filelinecol locations. 2372 if (getToken().is(Token::string)) 2373 return parseNameOrFileLineColLocation(loc); 2374 2375 // Bare tokens required for other cases. 2376 if (!getToken().is(Token::bare_identifier)) 2377 return emitError("expected location instance"); 2378 2379 // Check for the 'callsite' signifying a callsite location. 2380 if (getToken().getSpelling() == "callsite") 2381 return parseCallSiteLocation(loc); 2382 2383 // If the token is 'fused', then this is a fused location. 2384 if (getToken().getSpelling() == "fused") 2385 return parseFusedLocation(loc); 2386 2387 // Check for a 'unknown' for an unknown location. 2388 if (getToken().getSpelling() == "unknown") { 2389 consumeToken(Token::bare_identifier); 2390 loc = UnknownLoc::get(getContext()); 2391 return success(); 2392 } 2393 2394 return emitError("expected location instance"); 2395 } 2396 2397 //===----------------------------------------------------------------------===// 2398 // Affine parsing. 2399 //===----------------------------------------------------------------------===// 2400 2401 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 2402 /// the boolean sense. 2403 enum AffineLowPrecOp { 2404 /// Null value. 2405 LNoOp, 2406 Add, 2407 Sub 2408 }; 2409 2410 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2411 /// in the boolean sense. 2412 enum AffineHighPrecOp { 2413 /// Null value. 2414 HNoOp, 2415 Mul, 2416 FloorDiv, 2417 CeilDiv, 2418 Mod 2419 }; 2420 2421 namespace { 2422 /// This is a specialized parser for affine structures (affine maps, affine 2423 /// expressions, and integer sets), maintaining the state transient to their 2424 /// bodies. 2425 class AffineParser : public Parser { 2426 public: 2427 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2428 function_ref<ParseResult(bool)> parseElement = nullptr) 2429 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2430 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2431 2432 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2433 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2434 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2435 ParseResult parseAffineMapOfSSAIds(AffineMap &map, 2436 OpAsmParser::Delimiter delimiter); 2437 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2438 unsigned &numDims); 2439 2440 private: 2441 // Binary affine op parsing. 2442 AffineLowPrecOp consumeIfLowPrecOp(); 2443 AffineHighPrecOp consumeIfHighPrecOp(); 2444 2445 // Identifier lists for polyhedral structures. 
2446 ParseResult parseDimIdList(unsigned &numDims); 2447 ParseResult parseSymbolIdList(unsigned &numSymbols); 2448 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2449 unsigned &numSymbols); 2450 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2451 2452 AffineExpr parseAffineExpr(); 2453 AffineExpr parseParentheticalExpr(); 2454 AffineExpr parseNegateExpression(AffineExpr lhs); 2455 AffineExpr parseIntegerExpr(); 2456 AffineExpr parseBareIdExpr(); 2457 AffineExpr parseSSAIdExpr(bool isSymbol); 2458 AffineExpr parseSymbolSSAIdExpr(); 2459 2460 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2461 AffineExpr rhs, SMLoc opLoc); 2462 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2463 AffineExpr rhs); 2464 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2465 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2466 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2467 SMLoc llhsOpLoc); 2468 AffineExpr parseAffineConstraint(bool *isEq); 2469 2470 private: 2471 bool allowParsingSSAIds; 2472 function_ref<ParseResult(bool)> parseElement; 2473 unsigned numDimOperands; 2474 unsigned numSymbolOperands; 2475 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2476 }; 2477 } // end anonymous namespace 2478 2479 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2480 /// opLoc is the location of the op token to be used to report errors 2481 /// for non-conforming expressions. 2482 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2483 AffineExpr lhs, AffineExpr rhs, 2484 SMLoc opLoc) { 2485 // TODO: make the error location info accurate. 
2486 switch (op) { 2487 case Mul: 2488 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2489 emitError(opLoc, "non-affine expression: at least one of the multiply " 2490 "operands has to be either a constant or symbolic"); 2491 return nullptr; 2492 } 2493 return lhs * rhs; 2494 case FloorDiv: 2495 if (!rhs.isSymbolicOrConstant()) { 2496 emitError(opLoc, "non-affine expression: right operand of floordiv " 2497 "has to be either a constant or symbolic"); 2498 return nullptr; 2499 } 2500 return lhs.floorDiv(rhs); 2501 case CeilDiv: 2502 if (!rhs.isSymbolicOrConstant()) { 2503 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2504 "has to be either a constant or symbolic"); 2505 return nullptr; 2506 } 2507 return lhs.ceilDiv(rhs); 2508 case Mod: 2509 if (!rhs.isSymbolicOrConstant()) { 2510 emitError(opLoc, "non-affine expression: right operand of mod " 2511 "has to be either a constant or symbolic"); 2512 return nullptr; 2513 } 2514 return lhs % rhs; 2515 case HNoOp: 2516 llvm_unreachable("can't create affine expression for null high prec op"); 2517 return nullptr; 2518 } 2519 llvm_unreachable("Unknown AffineHighPrecOp"); 2520 } 2521 2522 /// Create an affine binary low precedence op expression (add, sub). 2523 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2524 AffineExpr lhs, AffineExpr rhs) { 2525 switch (op) { 2526 case AffineLowPrecOp::Add: 2527 return lhs + rhs; 2528 case AffineLowPrecOp::Sub: 2529 return lhs - rhs; 2530 case AffineLowPrecOp::LNoOp: 2531 llvm_unreachable("can't create affine expression for null low prec op"); 2532 return nullptr; 2533 } 2534 llvm_unreachable("Unknown AffineLowPrecOp"); 2535 } 2536 2537 /// Consume this token if it is a lower precedence affine op (there are only 2538 /// two precedence levels). 
2539 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2540 switch (getToken().getKind()) { 2541 case Token::plus: 2542 consumeToken(Token::plus); 2543 return AffineLowPrecOp::Add; 2544 case Token::minus: 2545 consumeToken(Token::minus); 2546 return AffineLowPrecOp::Sub; 2547 default: 2548 return AffineLowPrecOp::LNoOp; 2549 } 2550 } 2551 2552 /// Consume this token if it is a higher precedence affine op (there are only 2553 /// two precedence levels) 2554 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2555 switch (getToken().getKind()) { 2556 case Token::star: 2557 consumeToken(Token::star); 2558 return Mul; 2559 case Token::kw_floordiv: 2560 consumeToken(Token::kw_floordiv); 2561 return FloorDiv; 2562 case Token::kw_ceildiv: 2563 consumeToken(Token::kw_ceildiv); 2564 return CeilDiv; 2565 case Token::kw_mod: 2566 consumeToken(Token::kw_mod); 2567 return Mod; 2568 default: 2569 return HNoOp; 2570 } 2571 } 2572 2573 /// Parse a high precedence op expression list: mul, div, and mod are high 2574 /// precedence binary ops, i.e., parse a 2575 /// expr_1 op_1 expr_2 op_2 ... expr_n 2576 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2577 /// All affine binary ops are left associative. 2578 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2579 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2580 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2581 /// report an error for non-conforming expressions. 2582 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2583 AffineHighPrecOp llhsOp, 2584 SMLoc llhsOpLoc) { 2585 AffineExpr lhs = parseAffineOperandExpr(llhs); 2586 if (!lhs) 2587 return nullptr; 2588 2589 // Found an LHS. Parse the remaining expression. 
2590 auto opLoc = getToken().getLoc(); 2591 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2592 if (llhs) { 2593 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2594 if (!expr) 2595 return nullptr; 2596 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2597 } 2598 // No LLHS, get RHS 2599 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2600 } 2601 2602 // This is the last operand in this expression. 2603 if (llhs) 2604 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2605 2606 // No llhs, 'lhs' itself is the expression. 2607 return lhs; 2608 } 2609 2610 /// Parse an affine expression inside parentheses. 2611 /// 2612 /// affine-expr ::= `(` affine-expr `)` 2613 AffineExpr AffineParser::parseParentheticalExpr() { 2614 if (parseToken(Token::l_paren, "expected '('")) 2615 return nullptr; 2616 if (getToken().is(Token::r_paren)) 2617 return (emitError("no expression inside parentheses"), nullptr); 2618 2619 auto expr = parseAffineExpr(); 2620 if (!expr) 2621 return nullptr; 2622 if (parseToken(Token::r_paren, "expected ')'")) 2623 return nullptr; 2624 2625 return expr; 2626 } 2627 2628 /// Parse the negation expression. 2629 /// 2630 /// affine-expr ::= `-` affine-expr 2631 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2632 if (parseToken(Token::minus, "expected '-'")) 2633 return nullptr; 2634 2635 AffineExpr operand = parseAffineOperandExpr(lhs); 2636 // Since negation has the highest precedence of all ops (including high 2637 // precedence ops) but lower than parentheses, we are only going to use 2638 // parseAffineOperandExpr instead of parseAffineExpr here. 2639 if (!operand) 2640 // Extra error message although parseAffineOperandExpr would have 2641 // complained. Leads to a better diagnostic. 2642 return (emitError("missing operand of negation"), nullptr); 2643 return (-1) * operand; 2644 } 2645 2646 /// Parse a bare id that may appear in an affine expression. 
2647 /// 2648 /// affine-expr ::= bare-id 2649 AffineExpr AffineParser::parseBareIdExpr() { 2650 if (getToken().isNot(Token::bare_identifier)) 2651 return (emitError("expected bare identifier"), nullptr); 2652 2653 StringRef sRef = getTokenSpelling(); 2654 for (auto entry : dimsAndSymbols) { 2655 if (entry.first == sRef) { 2656 consumeToken(Token::bare_identifier); 2657 return entry.second; 2658 } 2659 } 2660 2661 return (emitError("use of undeclared identifier"), nullptr); 2662 } 2663 2664 /// Parse an SSA id which may appear in an affine expression. 2665 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2666 if (!allowParsingSSAIds) 2667 return (emitError("unexpected ssa identifier"), nullptr); 2668 if (getToken().isNot(Token::percent_identifier)) 2669 return (emitError("expected ssa identifier"), nullptr); 2670 auto name = getTokenSpelling(); 2671 // Check if we already parsed this SSA id. 2672 for (auto entry : dimsAndSymbols) { 2673 if (entry.first == name) { 2674 consumeToken(Token::percent_identifier); 2675 return entry.second; 2676 } 2677 } 2678 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2679 if (parseElement(isSymbol)) 2680 return (emitError("failed to parse ssa identifier"), nullptr); 2681 auto idExpr = isSymbol 2682 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2683 : getAffineDimExpr(numDimOperands++, getContext()); 2684 dimsAndSymbols.push_back({name, idExpr}); 2685 return idExpr; 2686 } 2687 2688 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2689 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2690 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2691 return nullptr; 2692 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2693 if (!symbolExpr) 2694 return nullptr; 2695 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2696 return nullptr; 2697 return symbolExpr; 2698 } 2699 2700 /// Parse a positive integral constant appearing in an affine expression. 2701 /// 2702 /// affine-expr ::= integer-literal 2703 AffineExpr AffineParser::parseIntegerExpr() { 2704 auto val = getToken().getUInt64IntegerValue(); 2705 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2706 return (emitError("constant too large for index"), nullptr); 2707 2708 consumeToken(Token::integer); 2709 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2710 } 2711 2712 /// Parses an expression that can be a valid operand of an affine expression. 2713 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2714 /// operator, the rhs of which is being parsed. This is used to determine 2715 /// whether an error should be emitted for a missing right operand. 2716 // Eg: for an expression without parentheses (like i + j + k + l), each 2717 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2718 // operand expression, it's an op expression and will be parsed via 2719 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2720 // -l are valid operands that will be parsed by this function. 
2721 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2722 switch (getToken().getKind()) { 2723 case Token::bare_identifier: 2724 return parseBareIdExpr(); 2725 case Token::kw_symbol: 2726 return parseSymbolSSAIdExpr(); 2727 case Token::percent_identifier: 2728 return parseSSAIdExpr(/*isSymbol=*/false); 2729 case Token::integer: 2730 return parseIntegerExpr(); 2731 case Token::l_paren: 2732 return parseParentheticalExpr(); 2733 case Token::minus: 2734 return parseNegateExpression(lhs); 2735 case Token::kw_ceildiv: 2736 case Token::kw_floordiv: 2737 case Token::kw_mod: 2738 case Token::plus: 2739 case Token::star: 2740 if (lhs) 2741 emitError("missing right operand of binary operator"); 2742 else 2743 emitError("missing left operand of binary operator"); 2744 return nullptr; 2745 default: 2746 if (lhs) 2747 emitError("missing right operand of binary operator"); 2748 else 2749 emitError("expected affine expression"); 2750 return nullptr; 2751 } 2752 } 2753 2754 /// Parse affine expressions that are bare-id's, integer constants, 2755 /// parenthetical affine expressions, and affine op expressions that are a 2756 /// composition of those. 2757 /// 2758 /// All binary op's associate from left to right. 2759 /// 2760 /// {add, sub} have lower precedence than {mul, div, and mod}. 2761 /// 2762 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2763 /// ceildiv, and mod are at the same higher precedence level. Negation has 2764 /// higher precedence than any binary op. 2765 /// 2766 /// llhs: the affine expression appearing on the left of the one being parsed. 2767 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2768 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2769 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2770 /// associativity. 
2771 /// 2772 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2773 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2774 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2775 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2776 AffineLowPrecOp llhsOp) { 2777 AffineExpr lhs; 2778 if (!(lhs = parseAffineOperandExpr(llhs))) 2779 return nullptr; 2780 2781 // Found an LHS. Deal with the ops. 2782 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2783 if (llhs) { 2784 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2785 return parseAffineLowPrecOpExpr(sum, lOp); 2786 } 2787 // No LLHS, get RHS and form the expression. 2788 return parseAffineLowPrecOpExpr(lhs, lOp); 2789 } 2790 auto opLoc = getToken().getLoc(); 2791 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2792 // We have a higher precedence op here. Get the rhs operand for the llhs 2793 // through parseAffineHighPrecOpExpr. 2794 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2795 if (!highRes) 2796 return nullptr; 2797 2798 // If llhs is null, the product forms the first operand of the yet to be 2799 // found expression. If non-null, the op to associate with llhs is llhsOp. 2800 AffineExpr expr = 2801 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2802 2803 // Recurse for subsequent low prec op's after the affine high prec op 2804 // expression. 2805 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2806 return parseAffineLowPrecOpExpr(expr, nextOp); 2807 return expr; 2808 } 2809 // Last operand in the expression list. 2810 if (llhs) 2811 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2812 // No llhs, 'lhs' itself is the expression. 2813 return lhs; 2814 } 2815 2816 /// Parse an affine expression. 
2817 /// affine-expr ::= `(` affine-expr `)` 2818 /// | `-` affine-expr 2819 /// | affine-expr `+` affine-expr 2820 /// | affine-expr `-` affine-expr 2821 /// | affine-expr `*` affine-expr 2822 /// | affine-expr `floordiv` affine-expr 2823 /// | affine-expr `ceildiv` affine-expr 2824 /// | affine-expr `mod` affine-expr 2825 /// | bare-id 2826 /// | integer-literal 2827 /// 2828 /// Additional conditions are checked depending on the production. For eg., 2829 /// one of the operands for `*` has to be either constant/symbolic; the second 2830 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2831 AffineExpr AffineParser::parseAffineExpr() { 2832 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2833 } 2834 2835 /// Parse a dim or symbol from the lists appearing before the actual 2836 /// expressions of the affine map. Update our state to store the 2837 /// dimensional/symbolic identifier. 2838 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2839 if (getToken().isNot(Token::bare_identifier)) 2840 return emitError("expected bare identifier"); 2841 2842 auto name = getTokenSpelling(); 2843 for (auto entry : dimsAndSymbols) { 2844 if (entry.first == name) 2845 return emitError("redefinition of identifier '" + name + "'"); 2846 } 2847 consumeToken(Token::bare_identifier); 2848 2849 dimsAndSymbols.push_back({name, idExpr}); 2850 return success(); 2851 } 2852 2853 /// Parse the list of dimensional identifiers to an affine map. 
2854 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2855 if (parseToken(Token::l_paren, 2856 "expected '(' at start of dimensional identifiers list")) { 2857 return failure(); 2858 } 2859 2860 auto parseElt = [&]() -> ParseResult { 2861 auto dimension = getAffineDimExpr(numDims++, getContext()); 2862 return parseIdentifierDefinition(dimension); 2863 }; 2864 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2865 } 2866 2867 /// Parse the list of symbolic identifiers to an affine map. 2868 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2869 consumeToken(Token::l_square); 2870 auto parseElt = [&]() -> ParseResult { 2871 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2872 return parseIdentifierDefinition(symbol); 2873 }; 2874 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2875 } 2876 2877 /// Parse the list of symbolic identifiers to an affine map. 2878 ParseResult 2879 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2880 unsigned &numSymbols) { 2881 if (parseDimIdList(numDims)) { 2882 return failure(); 2883 } 2884 if (!getToken().is(Token::l_square)) { 2885 numSymbols = 0; 2886 return success(); 2887 } 2888 return parseSymbolIdList(numSymbols); 2889 } 2890 2891 /// Parses an ambiguous affine map or integer set definition inline. 2892 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2893 IntegerSet &set) { 2894 unsigned numDims = 0, numSymbols = 0; 2895 2896 // List of dimensional and optional symbol identifiers. 2897 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2898 return failure(); 2899 } 2900 2901 // This is needed for parsing attributes as we wouldn't know whether we would 2902 // be parsing an integer set attribute or an affine map attribute. 
2903 bool isArrow = getToken().is(Token::arrow); 2904 bool isColon = getToken().is(Token::colon); 2905 if (!isArrow && !isColon) { 2906 return emitError("expected '->' or ':'"); 2907 } else if (isArrow) { 2908 parseToken(Token::arrow, "expected '->' or '['"); 2909 map = parseAffineMapRange(numDims, numSymbols); 2910 return map ? success() : failure(); 2911 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2912 return failure(); 2913 } 2914 2915 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2916 return success(); 2917 2918 return failure(); 2919 } 2920 2921 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2922 ParseResult 2923 AffineParser::parseAffineMapOfSSAIds(AffineMap &map, 2924 OpAsmParser::Delimiter delimiter) { 2925 Token::Kind rightToken; 2926 switch (delimiter) { 2927 case OpAsmParser::Delimiter::Square: 2928 if (parseToken(Token::l_square, "expected '['")) 2929 return failure(); 2930 rightToken = Token::r_square; 2931 break; 2932 case OpAsmParser::Delimiter::Paren: 2933 if (parseToken(Token::l_paren, "expected '('")) 2934 return failure(); 2935 rightToken = Token::r_paren; 2936 break; 2937 default: 2938 return emitError("unexpected delimiter"); 2939 } 2940 2941 SmallVector<AffineExpr, 4> exprs; 2942 auto parseElt = [&]() -> ParseResult { 2943 auto elt = parseAffineExpr(); 2944 exprs.push_back(elt); 2945 return elt ? success() : failure(); 2946 }; 2947 2948 // Parse a multi-dimensional affine expression (a comma-separated list of 2949 // 1-d affine expressions); the list cannot be empty. Grammar: 2950 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2951 if (parseCommaSeparatedListUntil(rightToken, parseElt, 2952 /*allowEmptyList=*/true)) 2953 return failure(); 2954 // Parsed a valid affine map. 
2955 if (exprs.empty()) 2956 map = AffineMap::get(getContext()); 2957 else 2958 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 2959 exprs); 2960 return success(); 2961 } 2962 2963 /// Parse the range and sizes affine map definition inline. 2964 /// 2965 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2966 /// 2967 /// multi-dim-affine-expr ::= `(` `)` 2968 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 2969 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2970 unsigned numSymbols) { 2971 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2972 2973 SmallVector<AffineExpr, 4> exprs; 2974 auto parseElt = [&]() -> ParseResult { 2975 auto elt = parseAffineExpr(); 2976 ParseResult res = elt ? success() : failure(); 2977 exprs.push_back(elt); 2978 return res; 2979 }; 2980 2981 // Parse a multi-dimensional affine expression (a comma-separated list of 2982 // 1-d affine expressions); the list cannot be empty. Grammar: 2983 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2984 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2985 return AffineMap(); 2986 2987 if (exprs.empty()) 2988 return AffineMap::get(getContext()); 2989 2990 // Parsed a valid affine map. 2991 return AffineMap::get(numDims, numSymbols, exprs); 2992 } 2993 2994 /// Parse an affine constraint. 2995 /// affine-constraint ::= affine-expr `>=` `0` 2996 /// | affine-expr `==` `0` 2997 /// 2998 /// isEq is set to true if the parsed constraint is an equality, false if it 2999 /// is an inequality (greater than or equal). 
3000 /// 3001 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 3002 AffineExpr expr = parseAffineExpr(); 3003 if (!expr) 3004 return nullptr; 3005 3006 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 3007 getToken().is(Token::integer)) { 3008 auto dim = getToken().getUnsignedIntegerValue(); 3009 if (dim.hasValue() && dim.getValue() == 0) { 3010 consumeToken(Token::integer); 3011 *isEq = false; 3012 return expr; 3013 } 3014 return (emitError("expected '0' after '>='"), nullptr); 3015 } 3016 3017 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3018 getToken().is(Token::integer)) { 3019 auto dim = getToken().getUnsignedIntegerValue(); 3020 if (dim.hasValue() && dim.getValue() == 0) { 3021 consumeToken(Token::integer); 3022 *isEq = true; 3023 return expr; 3024 } 3025 return (emitError("expected '0' after '=='"), nullptr); 3026 } 3027 3028 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3029 nullptr); 3030 } 3031 3032 /// Parse the constraints that are part of an integer set definition. 3033 /// integer-set-inline 3034 /// ::= dim-and-symbol-id-lists `:` 3035 /// '(' affine-constraint-conjunction? ')' 3036 /// affine-constraint-conjunction ::= affine-constraint (`,` 3037 /// affine-constraint)* 3038 /// 3039 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3040 unsigned numSymbols) { 3041 if (parseToken(Token::l_paren, 3042 "expected '(' at start of integer set constraint list")) 3043 return IntegerSet(); 3044 3045 SmallVector<AffineExpr, 4> constraints; 3046 SmallVector<bool, 4> isEqs; 3047 auto parseElt = [&]() -> ParseResult { 3048 bool isEq; 3049 auto elt = parseAffineConstraint(&isEq); 3050 ParseResult res = elt ? success() : failure(); 3051 if (elt) { 3052 constraints.push_back(elt); 3053 isEqs.push_back(isEq); 3054 } 3055 return res; 3056 }; 3057 3058 // Parse a list of affine constraints (comma-separated). 
3059 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3060 return IntegerSet(); 3061 3062 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3063 if (constraints.empty()) { 3064 /* 0 == 0 */ 3065 auto zero = getAffineConstantExpr(0, getContext()); 3066 return IntegerSet::get(numDims, numSymbols, zero, true); 3067 } 3068 3069 // Parsed a valid integer set. 3070 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3071 } 3072 3073 /// Parse an ambiguous reference to either and affine map or an integer set. 3074 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3075 IntegerSet &set) { 3076 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3077 } 3078 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3079 llvm::SMLoc curLoc = getToken().getLoc(); 3080 IntegerSet set; 3081 if (parseAffineMapOrIntegerSetReference(map, set)) 3082 return failure(); 3083 if (set) 3084 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3085 return success(); 3086 } 3087 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3088 llvm::SMLoc curLoc = getToken().getLoc(); 3089 AffineMap map; 3090 if (parseAffineMapOrIntegerSetReference(map, set)) 3091 return failure(); 3092 if (map) 3093 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3094 return success(); 3095 } 3096 3097 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3098 /// parse SSA value uses encountered while parsing affine expressions. 
3099 ParseResult 3100 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3101 function_ref<ParseResult(bool)> parseElement, 3102 OpAsmParser::Delimiter delimiter) { 3103 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3104 .parseAffineMapOfSSAIds(map, delimiter); 3105 } 3106 3107 //===----------------------------------------------------------------------===// 3108 // OperationParser 3109 //===----------------------------------------------------------------------===// 3110 3111 namespace { 3112 /// This class provides support for parsing operations and regions of 3113 /// operations. 3114 class OperationParser : public Parser { 3115 public: 3116 OperationParser(ParserState &state, ModuleOp moduleOp) 3117 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3118 } 3119 3120 ~OperationParser(); 3121 3122 /// After parsing is finished, this function must be called to see if there 3123 /// are any remaining issues. 3124 ParseResult finalize(); 3125 3126 //===--------------------------------------------------------------------===// 3127 // SSA Value Handling 3128 //===--------------------------------------------------------------------===// 3129 3130 /// This represents a use of an SSA value in the program. The first two 3131 /// entries in the tuple are the name and result number of a reference. The 3132 /// third is the location of the reference, which is used in case this ends 3133 /// up being a use of an undefined value. 3134 struct SSAUseInfo { 3135 StringRef name; // Value name, e.g. %42 or %abc 3136 unsigned number; // Number, specified with #12 3137 SMLoc loc; // Location of first definition or use. 3138 }; 3139 3140 /// Push a new SSA name scope to the parser. 3141 void pushSSANameScope(bool isIsolated); 3142 3143 /// Pop the last SSA name scope from the parser. 3144 ParseResult popSSANameScope(); 3145 3146 /// Register a definition of a value with the symbol table. 
3147 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3148 3149 /// Parse an optional list of SSA uses into 'results'. 3150 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3151 3152 /// Parse a single SSA use into 'result'. 3153 ParseResult parseSSAUse(SSAUseInfo &result); 3154 3155 /// Given a reference to an SSA value and its type, return a reference. This 3156 /// returns null on failure. 3157 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3158 3159 ParseResult parseSSADefOrUseAndType( 3160 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3161 3162 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3163 3164 /// Return the location of the value identified by its name and number if it 3165 /// has been already reference. 3166 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3167 auto &values = isolatedNameScopes.back().values; 3168 if (!values.count(name) || number >= values[name].size()) 3169 return {}; 3170 if (values[name][number].first) 3171 return values[name][number].second; 3172 return {}; 3173 } 3174 3175 //===--------------------------------------------------------------------===// 3176 // Operation Parsing 3177 //===--------------------------------------------------------------------===// 3178 3179 /// Parse an operation instance. 3180 ParseResult parseOperation(); 3181 3182 /// Parse a single operation successor and its operand list. 3183 ParseResult parseSuccessorAndUseList(Block *&dest, 3184 SmallVectorImpl<Value> &operands); 3185 3186 /// Parse a comma-separated list of operation successors in brackets. 3187 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations, 3188 SmallVectorImpl<SmallVector<Value, 4>> &operands); 3189 3190 /// Parse an operation instance that is in the generic form. 
3191 Operation *parseGenericOperation(); 3192 3193 /// Parse an operation instance that is in the generic form and insert it at 3194 /// the provided insertion point. 3195 Operation *parseGenericOperation(Block *insertBlock, 3196 Block::iterator insertPt); 3197 3198 /// Parse an operation instance that is in the op-defined custom form. 3199 Operation *parseCustomOperation(); 3200 3201 //===--------------------------------------------------------------------===// 3202 // Region Parsing 3203 //===--------------------------------------------------------------------===// 3204 3205 /// Parse a region into 'region' with the provided entry block arguments. 3206 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3207 /// isolated from those above. 3208 ParseResult parseRegion(Region ®ion, 3209 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3210 bool isIsolatedNameScope = false); 3211 3212 /// Parse a region body into 'region'. 3213 ParseResult parseRegionBody(Region ®ion); 3214 3215 //===--------------------------------------------------------------------===// 3216 // Block Parsing 3217 //===--------------------------------------------------------------------===// 3218 3219 /// Parse a new block into 'block'. 3220 ParseResult parseBlock(Block *&block); 3221 3222 /// Parse a list of operations into 'block'. 3223 ParseResult parseBlockBody(Block *block); 3224 3225 /// Parse a (possibly empty) list of block arguments. 3226 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3227 Block *owner); 3228 3229 /// Get the block with the specified name, creating it if it doesn't 3230 /// already exist. The location specified is the point of use, which allows 3231 /// us to diagnose references to blocks that are not defined precisely. 3232 Block *getBlockNamed(StringRef name, SMLoc loc); 3233 3234 /// Define the block with the specified name. Returns the Block* or nullptr in 3235 /// the case of redefinition. 
3236 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3237 3238 private: 3239 /// Returns the info for a block at the current scope for the given name. 3240 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3241 return blocksByName.back()[name]; 3242 } 3243 3244 /// Insert a new forward reference to the given block. 3245 void insertForwardRef(Block *block, SMLoc loc) { 3246 forwardRef.back().try_emplace(block, loc); 3247 } 3248 3249 /// Erase any forward reference to the given block. 3250 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3251 3252 /// Record that a definition was added at the current scope. 3253 void recordDefinition(StringRef def); 3254 3255 /// Get the value entry for the given SSA name. 3256 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3257 3258 /// Create a forward reference placeholder value with the given location and 3259 /// result type. 3260 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3261 3262 /// Return true if this is a forward reference. 3263 bool isForwardRefPlaceholder(Value value) { 3264 return forwardRefPlaceholders.count(value); 3265 } 3266 3267 /// This struct represents an isolated SSA name scope. This scope may contain 3268 /// other nested non-isolated scopes. These scopes are used for operations 3269 /// that are known to be isolated to allow for reusing names within their 3270 /// regions, even if those names are used above. 3271 struct IsolatedSSANameScope { 3272 /// Record that a definition was added at the current scope. 3273 void recordDefinition(StringRef def) { 3274 definitionsPerScope.back().insert(def); 3275 } 3276 3277 /// Push a nested name scope. 3278 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3279 3280 /// Pop a nested name scope. 
3281 void popSSANameScope() { 3282 for (auto &def : definitionsPerScope.pop_back_val()) 3283 values.erase(def.getKey()); 3284 } 3285 3286 /// This keeps track of all of the SSA values we are tracking for each name 3287 /// scope, indexed by their name. This has one entry per result number. 3288 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3289 3290 /// This keeps track of all of the values defined by a specific name scope. 3291 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3292 }; 3293 3294 /// A list of isolated name scopes. 3295 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3296 3297 /// This keeps track of the block names as well as the location of the first 3298 /// reference for each nested name scope. This is used to diagnose invalid 3299 /// block references and memorize them. 3300 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3301 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3302 3303 /// These are all of the placeholders we've made along with the location of 3304 /// their first reference, to allow checking for use of undefined values. 3305 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3306 3307 /// The builder used when creating parsed operation instances. 3308 OpBuilder opBuilder; 3309 3310 /// The top level module operation. 3311 ModuleOp moduleOp; 3312 }; 3313 } // end anonymous namespace 3314 3315 OperationParser::~OperationParser() { 3316 for (auto &fwd : forwardRefPlaceholders) { 3317 // Drop all uses of undefined forward declared reference and destroy 3318 // defining operation. 3319 fwd.first.dropAllUses(); 3320 fwd.first.getDefiningOp()->destroy(); 3321 } 3322 } 3323 3324 /// After parsing is finished, this function must be called to see if there are 3325 /// any remaining issues. 3326 ParseResult OperationParser::finalize() { 3327 // Check for any forward references that are left. If we find any, error 3328 // out. 
3329 if (!forwardRefPlaceholders.empty()) { 3330 SmallVector<std::pair<const char *, Value>, 4> errors; 3331 // Iteration over the map isn't deterministic, so sort by source location. 3332 for (auto entry : forwardRefPlaceholders) 3333 errors.push_back({entry.second.getPointer(), entry.first}); 3334 llvm::array_pod_sort(errors.begin(), errors.end()); 3335 3336 for (auto entry : errors) { 3337 auto loc = SMLoc::getFromPointer(entry.first); 3338 emitError(loc, "use of undeclared SSA value name"); 3339 } 3340 return failure(); 3341 } 3342 3343 return success(); 3344 } 3345 3346 //===----------------------------------------------------------------------===// 3347 // SSA Value Handling 3348 //===----------------------------------------------------------------------===// 3349 3350 void OperationParser::pushSSANameScope(bool isIsolated) { 3351 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3352 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3353 3354 // Push back a new name definition scope. 3355 if (isIsolated) 3356 isolatedNameScopes.push_back({}); 3357 isolatedNameScopes.back().pushSSANameScope(); 3358 } 3359 3360 ParseResult OperationParser::popSSANameScope() { 3361 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3362 3363 // Verify that all referenced blocks were defined. 3364 if (!forwardRefInCurrentScope.empty()) { 3365 SmallVector<std::pair<const char *, Block *>, 4> errors; 3366 // Iteration over the map isn't deterministic, so sort by source location. 3367 for (auto entry : forwardRefInCurrentScope) { 3368 errors.push_back({entry.second.getPointer(), entry.first}); 3369 // Add this block to the top-level region to allow for automatic cleanup. 
3370 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3371 } 3372 llvm::array_pod_sort(errors.begin(), errors.end()); 3373 3374 for (auto entry : errors) { 3375 auto loc = SMLoc::getFromPointer(entry.first); 3376 emitError(loc, "reference to an undefined block"); 3377 } 3378 return failure(); 3379 } 3380 3381 // Pop the next nested namescope. If there is only one internal namescope, 3382 // just pop the isolated scope. 3383 auto ¤tNameScope = isolatedNameScopes.back(); 3384 if (currentNameScope.definitionsPerScope.size() == 1) 3385 isolatedNameScopes.pop_back(); 3386 else 3387 currentNameScope.popSSANameScope(); 3388 3389 blocksByName.pop_back(); 3390 return success(); 3391 } 3392 3393 /// Register a definition of a value with the symbol table. 3394 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3395 auto &entries = getSSAValueEntry(useInfo.name); 3396 3397 // Make sure there is a slot for this value. 3398 if (entries.size() <= useInfo.number) 3399 entries.resize(useInfo.number + 1); 3400 3401 // If we already have an entry for this, check to see if it was a definition 3402 // or a forward reference. 3403 if (auto existing = entries[useInfo.number].first) { 3404 if (!isForwardRefPlaceholder(existing)) { 3405 return emitError(useInfo.loc) 3406 .append("redefinition of SSA value '", useInfo.name, "'") 3407 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3408 .append("previously defined here"); 3409 } 3410 3411 // If it was a forward reference, update everything that used it to use 3412 // the actual definition instead, delete the forward ref, and remove it 3413 // from our set of forward references we track. 3414 existing.replaceAllUsesWith(value); 3415 existing.getDefiningOp()->destroy(); 3416 forwardRefPlaceholders.erase(existing); 3417 } 3418 3419 /// Record this definition for the current scope. 
3420 entries[useInfo.number] = {value, useInfo.loc}; 3421 recordDefinition(useInfo.name); 3422 return success(); 3423 } 3424 3425 /// Parse a (possibly empty) list of SSA operands. 3426 /// 3427 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3428 /// ssa-use-list-opt ::= ssa-use-list? 3429 /// 3430 ParseResult 3431 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3432 if (getToken().isNot(Token::percent_identifier)) 3433 return success(); 3434 return parseCommaSeparatedList([&]() -> ParseResult { 3435 SSAUseInfo result; 3436 if (parseSSAUse(result)) 3437 return failure(); 3438 results.push_back(result); 3439 return success(); 3440 }); 3441 } 3442 3443 /// Parse a SSA operand for an operation. 3444 /// 3445 /// ssa-use ::= ssa-id 3446 /// 3447 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3448 result.name = getTokenSpelling(); 3449 result.number = 0; 3450 result.loc = getToken().getLoc(); 3451 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3452 return failure(); 3453 3454 // If we have an attribute ID, it is a result number. 3455 if (getToken().is(Token::hash_identifier)) { 3456 if (auto value = getToken().getHashIdentifierNumber()) 3457 result.number = value.getValue(); 3458 else 3459 return emitError("invalid SSA value result number"); 3460 consumeToken(Token::hash_identifier); 3461 } 3462 3463 return success(); 3464 } 3465 3466 /// Given an unbound reference to an SSA value and its type, return the value 3467 /// it specifies. This returns null on failure. 3468 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3469 auto &entries = getSSAValueEntry(useInfo.name); 3470 3471 // If we have already seen a value of this name, return it. 3472 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3473 auto result = entries[useInfo.number].first; 3474 // Check that the type matches the other uses. 
3475 if (result.getType() == type) 3476 return result; 3477 3478 emitError(useInfo.loc, "use of value '") 3479 .append(useInfo.name, 3480 "' expects different type than prior uses: ", type, " vs ", 3481 result.getType()) 3482 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3483 .append("prior use here"); 3484 return nullptr; 3485 } 3486 3487 // Make sure we have enough slots for this. 3488 if (entries.size() <= useInfo.number) 3489 entries.resize(useInfo.number + 1); 3490 3491 // If the value has already been defined and this is an overly large result 3492 // number, diagnose that. 3493 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3494 return (emitError(useInfo.loc, "reference to invalid result number"), 3495 nullptr); 3496 3497 // Otherwise, this is a forward reference. Create a placeholder and remember 3498 // that we did so. 3499 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3500 entries[useInfo.number].first = result; 3501 entries[useInfo.number].second = useInfo.loc; 3502 return result; 3503 } 3504 3505 /// Parse an SSA use with an associated type. 3506 /// 3507 /// ssa-use-and-type ::= ssa-use `:` type 3508 ParseResult OperationParser::parseSSADefOrUseAndType( 3509 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3510 SSAUseInfo useInfo; 3511 if (parseSSAUse(useInfo) || 3512 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3513 return failure(); 3514 3515 auto type = parseType(); 3516 if (!type) 3517 return failure(); 3518 3519 return action(useInfo, type); 3520 } 3521 3522 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3523 /// followed by a type list. 
3524 /// 3525 /// ssa-use-and-type-list 3526 /// ::= ssa-use-list ':' type-list-no-parens 3527 /// 3528 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3529 SmallVectorImpl<Value> &results) { 3530 SmallVector<SSAUseInfo, 4> valueIDs; 3531 if (parseOptionalSSAUseList(valueIDs)) 3532 return failure(); 3533 3534 // If there were no operands, then there is no colon or type lists. 3535 if (valueIDs.empty()) 3536 return success(); 3537 3538 SmallVector<Type, 4> types; 3539 if (parseToken(Token::colon, "expected ':' in operand list") || 3540 parseTypeListNoParens(types)) 3541 return failure(); 3542 3543 if (valueIDs.size() != types.size()) 3544 return emitError("expected ") 3545 << valueIDs.size() << " types to match operand list"; 3546 3547 results.reserve(valueIDs.size()); 3548 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3549 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3550 results.push_back(value); 3551 else 3552 return failure(); 3553 } 3554 3555 return success(); 3556 } 3557 3558 /// Record that a definition was added at the current scope. 3559 void OperationParser::recordDefinition(StringRef def) { 3560 isolatedNameScopes.back().recordDefinition(def); 3561 } 3562 3563 /// Get the value entry for the given SSA name. 3564 SmallVectorImpl<std::pair<Value, SMLoc>> & 3565 OperationParser::getSSAValueEntry(StringRef name) { 3566 return isolatedNameScopes.back().values[name]; 3567 } 3568 3569 /// Create and remember a new placeholder for a forward reference. 3570 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3571 // Forward references are always created as operations, because we just need 3572 // something with a def/use chain. 3573 // 3574 // We create these placeholders as having an empty name, which we know 3575 // cannot be created through normal user input, allowing us to distinguish 3576 // them. 
3577 auto name = OperationName("placeholder", getContext()); 3578 auto *op = Operation::create( 3579 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3580 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3581 /*resizableOperandList=*/false); 3582 forwardRefPlaceholders[op->getResult(0)] = loc; 3583 return op->getResult(0); 3584 } 3585 3586 //===----------------------------------------------------------------------===// 3587 // Operation Parsing 3588 //===----------------------------------------------------------------------===// 3589 3590 /// Parse an operation. 3591 /// 3592 /// operation ::= op-result-list? 3593 /// (generic-operation | custom-operation) 3594 /// trailing-location? 3595 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3596 /// successor-list? (`(` region-list `)`)? 3597 /// attribute-dict? `:` function-type 3598 /// custom-operation ::= bare-id custom-operation-format 3599 /// op-result-list ::= op-result (`,` op-result)* `=` 3600 /// op-result ::= ssa-id (`:` integer-literal) 3601 /// 3602 ParseResult OperationParser::parseOperation() { 3603 auto loc = getToken().getLoc(); 3604 SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; 3605 size_t numExpectedResults = 0; 3606 if (getToken().is(Token::percent_identifier)) { 3607 // Parse the group of result ids. 3608 auto parseNextResult = [&]() -> ParseResult { 3609 // Parse the next result id. 3610 if (!getToken().is(Token::percent_identifier)) 3611 return emitError("expected valid ssa identifier"); 3612 3613 Token nameTok = getToken(); 3614 consumeToken(Token::percent_identifier); 3615 3616 // If the next token is a ':', we parse the expected result count. 3617 size_t expectedSubResults = 1; 3618 if (consumeIf(Token::colon)) { 3619 // Check that the next token is an integer. 3620 if (!getToken().is(Token::integer)) 3621 return emitError("expected integer number of results"); 3622 3623 // Check that number of results is > 0. 
3624 auto val = getToken().getUInt64IntegerValue(); 3625 if (!val.hasValue() || val.getValue() < 1) 3626 return emitError("expected named operation to have atleast 1 result"); 3627 consumeToken(Token::integer); 3628 expectedSubResults = *val; 3629 } 3630 3631 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3632 nameTok.getLoc()); 3633 numExpectedResults += expectedSubResults; 3634 return success(); 3635 }; 3636 if (parseCommaSeparatedList(parseNextResult)) 3637 return failure(); 3638 3639 if (parseToken(Token::equal, "expected '=' after SSA name")) 3640 return failure(); 3641 } 3642 3643 Operation *op; 3644 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3645 op = parseCustomOperation(); 3646 else if (getToken().is(Token::string)) 3647 op = parseGenericOperation(); 3648 else 3649 return emitError("expected operation name in quotes"); 3650 3651 // If parsing of the basic operation failed, then this whole thing fails. 3652 if (!op) 3653 return failure(); 3654 3655 // If the operation had a name, register it. 3656 if (!resultIDs.empty()) { 3657 if (op->getNumResults() == 0) 3658 return emitError(loc, "cannot name an operation with no results"); 3659 if (numExpectedResults != op->getNumResults()) 3660 return emitError(loc, "operation defines ") 3661 << op->getNumResults() << " results but was provided " 3662 << numExpectedResults << " to bind"; 3663 3664 // Add definitions for each of the result groups. 3665 unsigned opResI = 0; 3666 for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { 3667 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3668 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3669 op->getResult(opResI++))) 3670 return failure(); 3671 } 3672 } 3673 } 3674 3675 return success(); 3676 } 3677 3678 /// Parse a single operation successor and its operand list. 3679 /// 3680 /// successor ::= block-id branch-use-list? 
3681 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3682 /// 3683 ParseResult 3684 OperationParser::parseSuccessorAndUseList(Block *&dest, 3685 SmallVectorImpl<Value> &operands) { 3686 // Verify branch is identifier and get the matching block. 3687 if (!getToken().is(Token::caret_identifier)) 3688 return emitError("expected block name"); 3689 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3690 consumeToken(); 3691 3692 // Handle optional arguments. 3693 if (consumeIf(Token::l_paren) && 3694 (parseOptionalSSAUseAndTypeList(operands) || 3695 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3696 return failure(); 3697 } 3698 3699 return success(); 3700 } 3701 3702 /// Parse a comma-separated list of operation successors in brackets. 3703 /// 3704 /// successor-list ::= `[` successor (`,` successor )* `]` 3705 /// 3706 ParseResult OperationParser::parseSuccessors( 3707 SmallVectorImpl<Block *> &destinations, 3708 SmallVectorImpl<SmallVector<Value, 4>> &operands) { 3709 if (parseToken(Token::l_square, "expected '['")) 3710 return failure(); 3711 3712 auto parseElt = [this, &destinations, &operands]() { 3713 Block *dest; 3714 SmallVector<Value, 4> destOperands; 3715 auto res = parseSuccessorAndUseList(dest, destOperands); 3716 destinations.push_back(dest); 3717 operands.push_back(destOperands); 3718 return res; 3719 }; 3720 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3721 /*allowEmptyList=*/false); 3722 } 3723 3724 namespace { 3725 // RAII-style guard for cleaning up the regions in the operation state before 3726 // deleting them. Within the parser, regions may get deleted if parsing failed, 3727 // and other errors may be present, in particular undominated uses. This makes 3728 // sure such uses are deleted. 
3729 struct CleanupOpStateRegions { 3730 ~CleanupOpStateRegions() { 3731 SmallVector<Region *, 4> regionsToClean; 3732 regionsToClean.reserve(state.regions.size()); 3733 for (auto ®ion : state.regions) 3734 if (region) 3735 for (auto &block : *region) 3736 block.dropAllDefinedValueUses(); 3737 } 3738 OperationState &state; 3739 }; 3740 } // namespace 3741 3742 Operation *OperationParser::parseGenericOperation() { 3743 // Get location information for the operation. 3744 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3745 3746 auto name = getToken().getStringValue(); 3747 if (name.empty()) 3748 return (emitError("empty operation name is invalid"), nullptr); 3749 if (name.find('\0') != StringRef::npos) 3750 return (emitError("null character not allowed in operation name"), nullptr); 3751 3752 consumeToken(Token::string); 3753 3754 OperationState result(srcLocation, name); 3755 3756 // Generic operations have a resizable operation list. 3757 result.setOperandListToResizable(); 3758 3759 // Parse the operand list. 3760 SmallVector<SSAUseInfo, 8> operandInfos; 3761 3762 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3763 parseOptionalSSAUseList(operandInfos) || 3764 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3765 return nullptr; 3766 } 3767 3768 // Parse the successor list but don't add successors to the result yet to 3769 // avoid messing up with the argument order. 3770 SmallVector<Block *, 2> successors; 3771 SmallVector<SmallVector<Value, 4>, 2> successorOperands; 3772 if (getToken().is(Token::l_square)) { 3773 // Check if the operation is a known terminator. 3774 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3775 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3776 return emitError("successors in non-terminator"), nullptr; 3777 if (parseSuccessors(successors, successorOperands)) 3778 return nullptr; 3779 } 3780 3781 // Parse the region list. 
3782 CleanupOpStateRegions guard{result}; 3783 if (consumeIf(Token::l_paren)) { 3784 do { 3785 // Create temporary regions with the top level region as parent. 3786 result.regions.emplace_back(new Region(moduleOp)); 3787 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3788 return nullptr; 3789 } while (consumeIf(Token::comma)); 3790 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3791 return nullptr; 3792 } 3793 3794 if (getToken().is(Token::l_brace)) { 3795 if (parseAttributeDict(result.attributes)) 3796 return nullptr; 3797 } 3798 3799 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3800 return nullptr; 3801 3802 auto typeLoc = getToken().getLoc(); 3803 auto type = parseType(); 3804 if (!type) 3805 return nullptr; 3806 auto fnType = type.dyn_cast<FunctionType>(); 3807 if (!fnType) 3808 return (emitError(typeLoc, "expected function type"), nullptr); 3809 3810 result.addTypes(fnType.getResults()); 3811 3812 // Check that we have the right number of types for the operands. 3813 auto operandTypes = fnType.getInputs(); 3814 if (operandTypes.size() != operandInfos.size()) { 3815 auto plural = "s"[operandInfos.size() == 1]; 3816 return (emitError(typeLoc, "expected ") 3817 << operandInfos.size() << " operand type" << plural 3818 << " but had " << operandTypes.size(), 3819 nullptr); 3820 } 3821 3822 // Resolve all of the operands. 3823 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3824 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3825 if (!result.operands.back()) 3826 return nullptr; 3827 } 3828 3829 // Add the successors, and their operands after the proper operands. 3830 for (auto succ : llvm::zip(successors, successorOperands)) { 3831 Block *successor = std::get<0>(succ); 3832 const SmallVector<Value, 4> &operands = std::get<1>(succ); 3833 result.addSuccessor(successor, operands); 3834 } 3835 3836 // Parse a location if one is present. 
3837 if (parseOptionalTrailingLocation(result.location)) 3838 return nullptr; 3839 3840 return opBuilder.createOperation(result); 3841 } 3842 3843 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3844 Block::iterator insertPt) { 3845 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3846 opBuilder.setInsertionPoint(insertBlock, insertPt); 3847 return parseGenericOperation(); 3848 } 3849 3850 namespace { 3851 class CustomOpAsmParser : public OpAsmParser { 3852 public: 3853 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3854 OperationParser &parser) 3855 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3856 3857 /// Parse an instance of the operation described by 'opDefinition' into the 3858 /// provided operation state. 3859 ParseResult parseOperation(OperationState &opState) { 3860 if (opDefinition->parseAssembly(*this, opState)) 3861 return failure(); 3862 return success(); 3863 } 3864 3865 Operation *parseGenericOperation(Block *insertBlock, 3866 Block::iterator insertPt) final { 3867 return parser.parseGenericOperation(insertBlock, insertPt); 3868 } 3869 3870 //===--------------------------------------------------------------------===// 3871 // Utilities 3872 //===--------------------------------------------------------------------===// 3873 3874 /// Return if any errors were emitted during parsing. 3875 bool didEmitError() const { return emittedError; } 3876 3877 /// Emit a diagnostic at the specified location and return failure. 
3878 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3879 emittedError = true; 3880 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3881 message); 3882 } 3883 3884 llvm::SMLoc getCurrentLocation() override { 3885 return parser.getToken().getLoc(); 3886 } 3887 3888 Builder &getBuilder() const override { return parser.builder; } 3889 3890 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3891 3892 //===--------------------------------------------------------------------===// 3893 // Token Parsing 3894 //===--------------------------------------------------------------------===// 3895 3896 /// Parse a `->` token. 3897 ParseResult parseArrow() override { 3898 return parser.parseToken(Token::arrow, "expected '->'"); 3899 } 3900 3901 /// Parses a `->` if present. 3902 ParseResult parseOptionalArrow() override { 3903 return success(parser.consumeIf(Token::arrow)); 3904 } 3905 3906 /// Parse a `:` token. 3907 ParseResult parseColon() override { 3908 return parser.parseToken(Token::colon, "expected ':'"); 3909 } 3910 3911 /// Parse a `:` token if present. 3912 ParseResult parseOptionalColon() override { 3913 return success(parser.consumeIf(Token::colon)); 3914 } 3915 3916 /// Parse a `,` token. 3917 ParseResult parseComma() override { 3918 return parser.parseToken(Token::comma, "expected ','"); 3919 } 3920 3921 /// Parse a `,` token if present. 3922 ParseResult parseOptionalComma() override { 3923 return success(parser.consumeIf(Token::comma)); 3924 } 3925 3926 /// Parses a `...` if present. 3927 ParseResult parseOptionalEllipsis() override { 3928 return success(parser.consumeIf(Token::ellipsis)); 3929 } 3930 3931 /// Parse a `=` token. 3932 ParseResult parseEqual() override { 3933 return parser.parseToken(Token::equal, "expected '='"); 3934 } 3935 3936 /// Parse a '<' token. 
3937 ParseResult parseLess() override { 3938 return parser.parseToken(Token::less, "expected '<'"); 3939 } 3940 3941 /// Parse a '>' token. 3942 ParseResult parseGreater() override { 3943 return parser.parseToken(Token::greater, "expected '>'"); 3944 } 3945 3946 /// Parse a `(` token. 3947 ParseResult parseLParen() override { 3948 return parser.parseToken(Token::l_paren, "expected '('"); 3949 } 3950 3951 /// Parses a '(' if present. 3952 ParseResult parseOptionalLParen() override { 3953 return success(parser.consumeIf(Token::l_paren)); 3954 } 3955 3956 /// Parse a `)` token. 3957 ParseResult parseRParen() override { 3958 return parser.parseToken(Token::r_paren, "expected ')'"); 3959 } 3960 3961 /// Parses a ')' if present. 3962 ParseResult parseOptionalRParen() override { 3963 return success(parser.consumeIf(Token::r_paren)); 3964 } 3965 3966 /// Parse a `[` token. 3967 ParseResult parseLSquare() override { 3968 return parser.parseToken(Token::l_square, "expected '['"); 3969 } 3970 3971 /// Parses a '[' if present. 3972 ParseResult parseOptionalLSquare() override { 3973 return success(parser.consumeIf(Token::l_square)); 3974 } 3975 3976 /// Parse a `]` token. 3977 ParseResult parseRSquare() override { 3978 return parser.parseToken(Token::r_square, "expected ']'"); 3979 } 3980 3981 /// Parses a ']' if present. 3982 ParseResult parseOptionalRSquare() override { 3983 return success(parser.consumeIf(Token::r_square)); 3984 } 3985 3986 //===--------------------------------------------------------------------===// 3987 // Attribute Parsing 3988 //===--------------------------------------------------------------------===// 3989 3990 /// Parse an arbitrary attribute of a given type and return it in result. This 3991 /// also adds the attribute to the specified attribute list with the specified 3992 /// name. 
3993 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3994 SmallVectorImpl<NamedAttribute> &attrs) override { 3995 result = parser.parseAttribute(type); 3996 if (!result) 3997 return failure(); 3998 3999 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 4000 return success(); 4001 } 4002 4003 /// Parse a named dictionary into 'result' if it is present. 4004 ParseResult 4005 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 4006 if (parser.getToken().isNot(Token::l_brace)) 4007 return success(); 4008 return parser.parseAttributeDict(result); 4009 } 4010 4011 /// Parse a named dictionary into 'result' if the `attributes` keyword is 4012 /// present. 4013 ParseResult parseOptionalAttrDictWithKeyword( 4014 SmallVectorImpl<NamedAttribute> &result) override { 4015 if (failed(parseOptionalKeyword("attributes"))) 4016 return success(); 4017 return parser.parseAttributeDict(result); 4018 } 4019 4020 /// Parse an affine map instance into 'map'. 4021 ParseResult parseAffineMap(AffineMap &map) override { 4022 return parser.parseAffineMapReference(map); 4023 } 4024 4025 /// Parse an integer set instance into 'set'. 4026 ParseResult printIntegerSet(IntegerSet &set) override { 4027 return parser.parseIntegerSetReference(set); 4028 } 4029 4030 //===--------------------------------------------------------------------===// 4031 // Identifier Parsing 4032 //===--------------------------------------------------------------------===// 4033 4034 /// Returns if the current token corresponds to a keyword. 4035 bool isCurrentTokenAKeyword() const { 4036 return parser.getToken().is(Token::bare_identifier) || 4037 parser.getToken().isKeyword(); 4038 } 4039 4040 /// Parse the given keyword if present. 4041 ParseResult parseOptionalKeyword(StringRef keyword) override { 4042 // Check that the current token has the same spelling. 
4043 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4044 return failure(); 4045 parser.consumeToken(); 4046 return success(); 4047 } 4048 4049 /// Parse a keyword, if present, into 'keyword'. 4050 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4051 // Check that the current token is a keyword. 4052 if (!isCurrentTokenAKeyword()) 4053 return failure(); 4054 4055 *keyword = parser.getTokenSpelling(); 4056 parser.consumeToken(); 4057 return success(); 4058 } 4059 4060 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4061 /// string attribute named 'attrName'. 4062 ParseResult 4063 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4064 SmallVectorImpl<NamedAttribute> &attrs) override { 4065 Token atToken = parser.getToken(); 4066 if (atToken.isNot(Token::at_identifier)) 4067 return failure(); 4068 4069 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4070 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4071 parser.consumeToken(); 4072 return success(); 4073 } 4074 4075 //===--------------------------------------------------------------------===// 4076 // Operand Parsing 4077 //===--------------------------------------------------------------------===// 4078 4079 /// Parse a single operand. 4080 ParseResult parseOperand(OperandType &result) override { 4081 OperationParser::SSAUseInfo useInfo; 4082 if (parser.parseSSAUse(useInfo)) 4083 return failure(); 4084 4085 result = {useInfo.loc, useInfo.name, useInfo.number}; 4086 return success(); 4087 } 4088 4089 /// Parse zero or more SSA comma-separated operand references with a specified 4090 /// surrounding delimiter, and an optional required operand count. 
  ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
                               int requiredOperandCount = -1,
                               Delimiter delimiter = Delimiter::None) override {
    return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
                                       requiredOperandCount, delimiter);
  }

  /// Parse zero or more SSA comma-separated operand or region arguments with
  /// optional surrounding delimiter and required operand count.
  ///
  /// 'isOperandList' selects parseOperand (true) or parseRegionArgument
  /// (false) for each element. A 'requiredOperandCount' of -1 accepts any
  /// count; otherwise the final size of 'result' must equal it. For the
  /// Optional* delimiters, an absent opening token returns success without
  /// parsing anything. Errors are reported through this class's emitError and
  /// so carry the custom-op prefix.
  ParseResult
  parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
                              bool isOperandList, int requiredOperandCount = -1,
                              Delimiter delimiter = Delimiter::None) {
    auto startLoc = parser.getToken().getLoc();

    // Handle delimiters.
    switch (delimiter) {
    case Delimiter::None:
      // Don't check for the absence of a delimiter if the number of operands
      // is unknown (and hence the operand list could be empty).
      // NOTE(review): with a required count of 0 and no operand following,
      // this path reports "invalid operand"; confirm that is intended.
      if (requiredOperandCount == -1)
        break;
      // Token already matches an identifier and so can't be a delimiter.
      if (parser.getToken().is(Token::percent_identifier))
        break;
      // Test against known delimiters.
      if (parser.getToken().is(Token::l_paren) ||
          parser.getToken().is(Token::l_square))
        return emitError(startLoc, "unexpected delimiter");
      return emitError(startLoc, "invalid operand");
    case Delimiter::OptionalParen:
      if (parser.getToken().isNot(Token::l_paren))
        return success();
      LLVM_FALLTHROUGH;
    case Delimiter::Paren:
      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
        return failure();
      break;
    case Delimiter::OptionalSquare:
      if (parser.getToken().isNot(Token::l_square))
        return success();
      LLVM_FALLTHROUGH;
    case Delimiter::Square:
      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
        return failure();
      break;
    }

    // Check for zero operands: elements are only parsed when the list starts
    // with a '%' identifier.
    if (parser.getToken().is(Token::percent_identifier)) {
      do {
        OperandType operandOrArg;
        if (isOperandList ? parseOperand(operandOrArg)
                          : parseRegionArgument(operandOrArg))
          return failure();
        result.push_back(operandOrArg);
      } while (parser.consumeIf(Token::comma));
    }

    // Handle delimiters. If we reach here, the optional delimiters were
    // present, so we need to parse their closing one.
    switch (delimiter) {
    case Delimiter::None:
      break;
    case Delimiter::OptionalParen:
    case Delimiter::Paren:
      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
        return failure();
      break;
    case Delimiter::OptionalSquare:
    case Delimiter::Square:
      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
        return failure();
      break;
    }

    // NOTE(review): 'result' is never cleared here, so this count check also
    // includes any entries the caller placed in it beforehand.
    if (requiredOperandCount != -1 &&
        result.size() != static_cast<size_t>(requiredOperandCount))
      return emitError(startLoc, "expected ")
             << requiredOperandCount << " operands";
    return success();
  }

  /// Parse zero or more trailing SSA comma-separated trailing operand
  /// references with a specified surrounding delimiter, and an optional
  /// required operand count. A leading comma is expected before the operands.
  ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
                                       int requiredOperandCount,
                                       Delimiter delimiter) override {
    if (parser.getToken().is(Token::comma)) {
      // The current token is known to be ',', so this cannot fail.
      parseComma();
      return parseOperandList(result, requiredOperandCount, delimiter);
    }
    // No leading comma: only acceptable when no specific count is required.
    if (requiredOperandCount != -1)
      return emitError(parser.getToken().getLoc(), "expected ")
             << requiredOperandCount << " operands";
    return success();
  }

  /// Resolve an operand to an SSA value, emitting an error on failure.
4191 ParseResult resolveOperand(const OperandType &operand, Type type, 4192 SmallVectorImpl<Value> &result) override { 4193 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4194 operand.location}; 4195 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4196 result.push_back(value); 4197 return success(); 4198 } 4199 return failure(); 4200 } 4201 4202 /// Parse an AffineMap of SSA ids. 4203 ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4204 Attribute &mapAttr, StringRef attrName, 4205 SmallVectorImpl<NamedAttribute> &attrs, 4206 Delimiter delimiter) override { 4207 SmallVector<OperandType, 2> dimOperands; 4208 SmallVector<OperandType, 1> symOperands; 4209 4210 auto parseElement = [&](bool isSymbol) -> ParseResult { 4211 OperandType operand; 4212 if (parseOperand(operand)) 4213 return failure(); 4214 if (isSymbol) 4215 symOperands.push_back(operand); 4216 else 4217 dimOperands.push_back(operand); 4218 return success(); 4219 }; 4220 4221 AffineMap map; 4222 if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) 4223 return failure(); 4224 // Add AffineMap attribute. 4225 if (map) { 4226 mapAttr = AffineMapAttr::get(map); 4227 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4228 } 4229 4230 // Add dim operands before symbol operands in 'operands'. 4231 operands.assign(dimOperands.begin(), dimOperands.end()); 4232 operands.append(symOperands.begin(), symOperands.end()); 4233 return success(); 4234 } 4235 4236 //===--------------------------------------------------------------------===// 4237 // Region Parsing 4238 //===--------------------------------------------------------------------===// 4239 4240 /// Parse a region that takes `arguments` of `argTypes` types. This 4241 /// effectively defines the SSA values of `arguments` and assigns their type. 
4242 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4243 ArrayRef<Type> argTypes, 4244 bool enableNameShadowing) override { 4245 assert(arguments.size() == argTypes.size() && 4246 "mismatching number of arguments and types"); 4247 4248 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4249 regionArguments; 4250 for (auto pair : llvm::zip(arguments, argTypes)) { 4251 const OperandType &operand = std::get<0>(pair); 4252 Type type = std::get<1>(pair); 4253 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4254 operand.location}; 4255 regionArguments.emplace_back(operandInfo, type); 4256 } 4257 4258 // Try to parse the region. 4259 assert((!enableNameShadowing || 4260 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4261 "name shadowing is only allowed on isolated regions"); 4262 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4263 return failure(); 4264 return success(); 4265 } 4266 4267 /// Parses a region if present. 4268 ParseResult parseOptionalRegion(Region ®ion, 4269 ArrayRef<OperandType> arguments, 4270 ArrayRef<Type> argTypes, 4271 bool enableNameShadowing) override { 4272 if (parser.getToken().isNot(Token::l_brace)) 4273 return success(); 4274 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4275 } 4276 4277 /// Parse a region argument. The type of the argument will be resolved later 4278 /// by a call to `parseRegion`. 4279 ParseResult parseRegionArgument(OperandType &argument) override { 4280 return parseOperand(argument); 4281 } 4282 4283 /// Parse a region argument if present. 
4284 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4285 if (parser.getToken().isNot(Token::percent_identifier)) 4286 return success(); 4287 return parseRegionArgument(argument); 4288 } 4289 4290 ParseResult 4291 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4292 int requiredOperandCount = -1, 4293 Delimiter delimiter = Delimiter::None) override { 4294 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4295 requiredOperandCount, delimiter); 4296 } 4297 4298 //===--------------------------------------------------------------------===// 4299 // Successor Parsing 4300 //===--------------------------------------------------------------------===// 4301 4302 /// Parse a single operation successor and its operand list. 4303 ParseResult 4304 parseSuccessorAndUseList(Block *&dest, 4305 SmallVectorImpl<Value> &operands) override { 4306 return parser.parseSuccessorAndUseList(dest, operands); 4307 } 4308 4309 //===--------------------------------------------------------------------===// 4310 // Type Parsing 4311 //===--------------------------------------------------------------------===// 4312 4313 /// Parse a type. 4314 ParseResult parseType(Type &result) override { 4315 return failure(!(result = parser.parseType())); 4316 } 4317 4318 /// Parse an optional arrow followed by a type list. 4319 ParseResult 4320 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4321 if (!parser.consumeIf(Token::arrow)) 4322 return success(); 4323 return parser.parseFunctionResultTypes(result); 4324 } 4325 4326 /// Parse a colon followed by a type. 4327 ParseResult parseColonType(Type &result) override { 4328 return failure(parser.parseToken(Token::colon, "expected ':'") || 4329 !(result = parser.parseType())); 4330 } 4331 4332 /// Parse a colon followed by a type list, which must have at least one type. 
4333 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4334 if (parser.parseToken(Token::colon, "expected ':'")) 4335 return failure(); 4336 return parser.parseTypeListNoParens(result); 4337 } 4338 4339 /// Parse an optional colon followed by a type list, which if present must 4340 /// have at least one type. 4341 ParseResult 4342 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4343 if (!parser.consumeIf(Token::colon)) 4344 return success(); 4345 return parser.parseTypeListNoParens(result); 4346 } 4347 4348 private: 4349 /// The source location of the operation name. 4350 SMLoc nameLoc; 4351 4352 /// The abstract information of the operation. 4353 const AbstractOperation *opDefinition; 4354 4355 /// The main operation parser. 4356 OperationParser &parser; 4357 4358 /// A flag that indicates if any errors were emitted during parsing. 4359 bool emittedError = false; 4360 }; 4361 } // end anonymous namespace. 4362 4363 Operation *OperationParser::parseCustomOperation() { 4364 auto opLoc = getToken().getLoc(); 4365 auto opName = getTokenSpelling(); 4366 4367 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4368 if (!opDefinition && !opName.contains('.')) { 4369 // If the operation name has no namespace prefix we treat it as a standard 4370 // operation and prefix it with "std". 4371 // TODO: Would it be better to just build a mapping of the registered 4372 // operations in the standard dialect? 4373 opDefinition = 4374 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4375 } 4376 4377 if (!opDefinition) { 4378 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4379 return nullptr; 4380 } 4381 4382 consumeToken(); 4383 4384 // If the custom op parser crashes, produce some indication to help 4385 // debugging. 
4386 std::string opNameStr = opName.str(); 4387 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4388 opNameStr.c_str()); 4389 4390 // Get location information for the operation. 4391 auto srcLocation = getEncodedSourceLocation(opLoc); 4392 4393 // Have the op implementation take a crack and parsing this. 4394 OperationState opState(srcLocation, opDefinition->name); 4395 CleanupOpStateRegions guard{opState}; 4396 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 4397 if (opAsmParser.parseOperation(opState)) 4398 return nullptr; 4399 4400 // If it emitted an error, we failed. 4401 if (opAsmParser.didEmitError()) 4402 return nullptr; 4403 4404 // Parse a location if one is present. 4405 if (parseOptionalTrailingLocation(opState.location)) 4406 return nullptr; 4407 4408 // Otherwise, we succeeded. Use the state it parsed as our op information. 4409 return opBuilder.createOperation(opState); 4410 } 4411 4412 //===----------------------------------------------------------------------===// 4413 // Region Parsing 4414 //===----------------------------------------------------------------------===// 4415 4416 /// Region. 4417 /// 4418 /// region ::= '{' region-body 4419 /// 4420 ParseResult OperationParser::parseRegion( 4421 Region ®ion, 4422 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4423 bool isIsolatedNameScope) { 4424 // Parse the '{'. 4425 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4426 return failure(); 4427 4428 // Check for an empty region. 4429 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4430 return success(); 4431 auto currentPt = opBuilder.saveInsertionPoint(); 4432 4433 // Push a new named value scope. 4434 pushSSANameScope(isIsolatedNameScope); 4435 4436 // Parse the first block directly to allow for it to be unnamed. 4437 Block *block = new Block(); 4438 4439 // Add arguments to the entry block. 
4440 if (!entryArguments.empty()) { 4441 for (auto &placeholderArgPair : entryArguments) { 4442 auto &argInfo = placeholderArgPair.first; 4443 // Ensure that the argument was not already defined. 4444 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4445 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4446 "' is already in use") 4447 .attachNote(getEncodedSourceLocation(*defLoc)) 4448 << "previously referenced here"; 4449 } 4450 if (addDefinition(placeholderArgPair.first, 4451 block->addArgument(placeholderArgPair.second))) { 4452 delete block; 4453 return failure(); 4454 } 4455 } 4456 4457 // If we had named arguments, then don't allow a block name. 4458 if (getToken().is(Token::caret_identifier)) 4459 return emitError("invalid block name in region with named arguments"); 4460 } 4461 4462 if (parseBlock(block)) { 4463 delete block; 4464 return failure(); 4465 } 4466 4467 // Verify that no other arguments were parsed. 4468 if (!entryArguments.empty() && 4469 block->getNumArguments() > entryArguments.size()) { 4470 delete block; 4471 return emitError("entry block arguments were already defined"); 4472 } 4473 4474 // Parse the rest of the region. 4475 region.push_back(block); 4476 if (parseRegionBody(region)) 4477 return failure(); 4478 4479 // Pop the SSA value scope for this region. 4480 if (popSSANameScope()) 4481 return failure(); 4482 4483 // Reset the original insertion point. 4484 opBuilder.restoreInsertionPoint(currentPt); 4485 return success(); 4486 } 4487 4488 /// Region. 4489 /// 4490 /// region-body ::= block* '}' 4491 /// 4492 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4493 // Parse the list of blocks. 
4494 while (!consumeIf(Token::r_brace)) { 4495 Block *newBlock = nullptr; 4496 if (parseBlock(newBlock)) 4497 return failure(); 4498 region.push_back(newBlock); 4499 } 4500 return success(); 4501 } 4502 4503 //===----------------------------------------------------------------------===// 4504 // Block Parsing 4505 //===----------------------------------------------------------------------===// 4506 4507 /// Block declaration. 4508 /// 4509 /// block ::= block-label? operation* 4510 /// block-label ::= block-id block-arg-list? `:` 4511 /// block-id ::= caret-id 4512 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4513 /// 4514 ParseResult OperationParser::parseBlock(Block *&block) { 4515 // The first block of a region may already exist, if it does the caret 4516 // identifier is optional. 4517 if (block && getToken().isNot(Token::caret_identifier)) 4518 return parseBlockBody(block); 4519 4520 SMLoc nameLoc = getToken().getLoc(); 4521 auto name = getTokenSpelling(); 4522 if (parseToken(Token::caret_identifier, "expected block name")) 4523 return failure(); 4524 4525 block = defineBlockNamed(name, nameLoc, block); 4526 4527 // Fail if the block was already defined. 4528 if (!block) 4529 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4530 4531 // If an argument list is present, parse it. 4532 if (consumeIf(Token::l_paren)) { 4533 SmallVector<BlockArgument, 8> bbArgs; 4534 if (parseOptionalBlockArgList(bbArgs, block) || 4535 parseToken(Token::r_paren, "expected ')' to end argument list")) 4536 return failure(); 4537 } 4538 4539 if (parseToken(Token::colon, "expected ':' after block name")) 4540 return failure(); 4541 4542 return parseBlockBody(block); 4543 } 4544 4545 ParseResult OperationParser::parseBlockBody(Block *block) { 4546 // Set the insertion point to the end of the block to parse. 4547 opBuilder.setInsertionPointToEnd(block); 4548 4549 // Parse the list of operations that make up the body of the block. 
4550 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4551 if (parseOperation()) 4552 return failure(); 4553 4554 return success(); 4555 } 4556 4557 /// Get the block with the specified name, creating it if it doesn't already 4558 /// exist. The location specified is the point of use, which allows 4559 /// us to diagnose references to blocks that are not defined precisely. 4560 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4561 auto &blockAndLoc = getBlockInfoByName(name); 4562 if (!blockAndLoc.first) { 4563 blockAndLoc = {new Block(), loc}; 4564 insertForwardRef(blockAndLoc.first, loc); 4565 } 4566 4567 return blockAndLoc.first; 4568 } 4569 4570 /// Define the block with the specified name. Returns the Block* or nullptr in 4571 /// the case of redefinition. 4572 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4573 Block *existing) { 4574 auto &blockAndLoc = getBlockInfoByName(name); 4575 if (!blockAndLoc.first) { 4576 // If the caller provided a block, use it. Otherwise create a new one. 4577 if (!existing) 4578 existing = new Block(); 4579 blockAndLoc.first = existing; 4580 blockAndLoc.second = loc; 4581 return blockAndLoc.first; 4582 } 4583 4584 // Forward declarations are removed once defined, so if we are defining a 4585 // existing block and it is not a forward declaration, then it is a 4586 // redeclaration. 4587 if (!eraseForwardRef(blockAndLoc.first)) 4588 return nullptr; 4589 return blockAndLoc.first; 4590 } 4591 4592 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4593 /// 4594 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4595 /// 4596 ParseResult OperationParser::parseOptionalBlockArgList( 4597 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4598 if (getToken().is(Token::r_brace)) 4599 return success(); 4600 4601 // If the block already has arguments, then we're handling the entry block. 
4602 // Parse and register the names for the arguments, but do not add them. 4603 bool definingExistingArgs = owner->getNumArguments() != 0; 4604 unsigned nextArgument = 0; 4605 4606 return parseCommaSeparatedList([&]() -> ParseResult { 4607 return parseSSADefOrUseAndType( 4608 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4609 // If this block did not have existing arguments, define a new one. 4610 if (!definingExistingArgs) 4611 return addDefinition(useInfo, owner->addArgument(type)); 4612 4613 // Otherwise, ensure that this argument has already been created. 4614 if (nextArgument >= owner->getNumArguments()) 4615 return emitError("too many arguments specified in argument list"); 4616 4617 // Finally, make sure the existing argument has the correct type. 4618 auto arg = owner->getArgument(nextArgument++); 4619 if (arg.getType() != type) 4620 return emitError("argument and block argument type mismatch"); 4621 return addDefinition(useInfo, arg); 4622 }); 4623 }); 4624 } 4625 4626 //===----------------------------------------------------------------------===// 4627 // Top-level entity parsing. 4628 //===----------------------------------------------------------------------===// 4629 4630 namespace { 4631 /// This parser handles entities that are only valid at the top level of the 4632 /// file. 4633 class ModuleParser : public Parser { 4634 public: 4635 explicit ModuleParser(ParserState &state) : Parser(state) {} 4636 4637 ParseResult parseModule(ModuleOp module); 4638 4639 private: 4640 /// Parse an attribute alias declaration. 4641 ParseResult parseAttributeAliasDef(); 4642 4643 /// Parse an attribute alias declaration. 4644 ParseResult parseTypeAliasDef(); 4645 }; 4646 } // end anonymous namespace 4647 4648 /// Parses an attribute alias declaration. 
4649 /// 4650 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4651 /// 4652 ParseResult ModuleParser::parseAttributeAliasDef() { 4653 assert(getToken().is(Token::hash_identifier)); 4654 StringRef aliasName = getTokenSpelling().drop_front(); 4655 4656 // Check for redefinitions. 4657 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4658 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4659 4660 // Make sure this isn't invading the dialect attribute namespace. 4661 if (aliasName.contains('.')) 4662 return emitError("attribute names with a '.' are reserved for " 4663 "dialect-defined names"); 4664 4665 consumeToken(Token::hash_identifier); 4666 4667 // Parse the '='. 4668 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4669 return failure(); 4670 4671 // Parse the attribute value. 4672 Attribute attr = parseAttribute(); 4673 if (!attr) 4674 return failure(); 4675 4676 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4677 return success(); 4678 } 4679 4680 /// Parse a type alias declaration. 4681 /// 4682 /// type-alias-def ::= '!' alias-name `=` 'type' type 4683 /// 4684 ParseResult ModuleParser::parseTypeAliasDef() { 4685 assert(getToken().is(Token::exclamation_identifier)); 4686 StringRef aliasName = getTokenSpelling().drop_front(); 4687 4688 // Check for redefinitions. 4689 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4690 return emitError("redefinition of type alias id '" + aliasName + "'"); 4691 4692 // Make sure this isn't invading the dialect type namespace. 4693 if (aliasName.contains('.')) 4694 return emitError("type names with a '.' are reserved for " 4695 "dialect-defined names"); 4696 4697 consumeToken(Token::exclamation_identifier); 4698 4699 // Parse the '=' and 'type'. 
4700 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4701 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4702 return failure(); 4703 4704 // Parse the type. 4705 Type aliasedType = parseType(); 4706 if (!aliasedType) 4707 return failure(); 4708 4709 // Register this alias with the parser state. 4710 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4711 return success(); 4712 } 4713 4714 /// This is the top-level module parser. 4715 ParseResult ModuleParser::parseModule(ModuleOp module) { 4716 OperationParser opParser(getState(), module); 4717 4718 // Module itself is a name scope. 4719 opParser.pushSSANameScope(/*isIsolated=*/true); 4720 4721 while (true) { 4722 switch (getToken().getKind()) { 4723 default: 4724 // Parse a top-level operation. 4725 if (opParser.parseOperation()) 4726 return failure(); 4727 break; 4728 4729 // If we got to the end of the file, then we're done. 4730 case Token::eof: { 4731 if (opParser.finalize()) 4732 return failure(); 4733 4734 // Handle the case where the top level module was explicitly defined. 4735 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4736 auto &operations = bodyBlocks.front().getOperations(); 4737 assert(!operations.empty() && "expected a valid module terminator"); 4738 4739 // Check that the first operation is a module, and it is the only 4740 // non-terminator operation. 4741 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4742 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4743 // Merge the data of the nested module operation into 'module'. 4744 module.setLoc(nested.getLoc()); 4745 module.setAttrs(nested.getOperation()->getAttrList()); 4746 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4747 4748 // Erase the original module body. 
4749 bodyBlocks.pop_front(); 4750 } 4751 4752 return opParser.popSSANameScope(); 4753 } 4754 4755 // If we got an error token, then the lexer already emitted an error, just 4756 // stop. Someday we could introduce error recovery if there was demand 4757 // for it. 4758 case Token::error: 4759 return failure(); 4760 4761 // Parse an attribute alias. 4762 case Token::hash_identifier: 4763 if (parseAttributeAliasDef()) 4764 return failure(); 4765 break; 4766 4767 // Parse a type alias. 4768 case Token::exclamation_identifier: 4769 if (parseTypeAliasDef()) 4770 return failure(); 4771 break; 4772 } 4773 } 4774 } 4775 4776 //===----------------------------------------------------------------------===// 4777 4778 /// This parses the file specified by the indicated SourceMgr and returns an 4779 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4780 /// null. 4781 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4782 MLIRContext *context) { 4783 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4784 4785 // This is the result module we are parsing into. 4786 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4787 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4788 4789 SymbolState aliasState; 4790 ParserState state(sourceMgr, context, aliasState); 4791 if (ModuleParser(state).parseModule(*module)) 4792 return nullptr; 4793 4794 // Make sure the parse module has no other structural problems detected by 4795 // the verifier. 4796 if (failed(verify(*module))) 4797 return nullptr; 4798 4799 return module; 4800 } 4801 4802 /// This parses the file specified by the indicated filename and returns an 4803 /// MLIR module if it was valid. If not, the error message is emitted through 4804 /// the error handler registered in the context, and a null pointer is returned. 
4805 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4806 MLIRContext *context) { 4807 llvm::SourceMgr sourceMgr; 4808 return parseSourceFile(filename, sourceMgr, context); 4809 } 4810 4811 /// This parses the file specified by the indicated filename using the provided 4812 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4813 /// message is emitted through the error handler registered in the context, and 4814 /// a null pointer is returned. 4815 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4816 llvm::SourceMgr &sourceMgr, 4817 MLIRContext *context) { 4818 if (sourceMgr.getNumBuffers() != 0) { 4819 // TODO(b/136086478): Extend to support multiple buffers. 4820 emitError(mlir::UnknownLoc::get(context), 4821 "only main buffer parsed at the moment"); 4822 return nullptr; 4823 } 4824 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4825 if (std::error_code error = file_or_err.getError()) { 4826 emitError(mlir::UnknownLoc::get(context), 4827 "could not open input file " + filename); 4828 return nullptr; 4829 } 4830 4831 // Load the MLIR module. 4832 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4833 return parseSourceFile(sourceMgr, context); 4834 } 4835 4836 /// This parses the program string to a MLIR module if it was valid. If not, 4837 /// it emits diagnostics and returns null. 4838 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4839 MLIRContext *context) { 4840 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4841 if (!memBuffer) 4842 return nullptr; 4843 4844 SourceMgr sourceMgr; 4845 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4846 return parseSourceFile(sourceMgr, context); 4847 } 4848 4849 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 4850 /// parsing failed, nullptr is returned. The number of bytes read from the input 4851 /// string is returned in 'numRead'. 
4852 template <typename T, typename ParserFn> 4853 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 4854 ParserFn &&parserFn) { 4855 SymbolState aliasState; 4856 return parseSymbol<T>( 4857 inputStr, context, aliasState, 4858 [&](Parser &parser) { 4859 SourceMgrDiagnosticHandler handler( 4860 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 4861 parser.getContext()); 4862 return parserFn(parser); 4863 }, 4864 &numRead); 4865 } 4866 4867 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 4868 size_t numRead = 0; 4869 return parseAttribute(attrStr, context, numRead); 4870 } 4871 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 4872 size_t numRead = 0; 4873 return parseAttribute(attrStr, type, numRead); 4874 } 4875 4876 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 4877 size_t &numRead) { 4878 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 4879 return parser.parseAttribute(); 4880 }); 4881 } 4882 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 4883 return parseSymbol<Attribute>( 4884 attrStr, type.getContext(), numRead, 4885 [type](Parser &parser) { return parser.parseAttribute(type); }); 4886 } 4887 4888 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 4889 size_t numRead = 0; 4890 return parseType(typeStr, context, numRead); 4891 } 4892 4893 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 4894 return parseSymbol<Type>(typeStr, context, numRead, 4895 [](Parser &parser) { return parser.parseType(); }); 4896 } 4897