//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR textual form.
//
//===----------------------------------------------------------------------===//

#include "mlir/Parser.h"
#include "Lexer.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include <algorithm>
using namespace mlir;
using llvm::MemoryBuffer;
using llvm::SMLoc;
using llvm::SourceMgr;

namespace {
class Parser;

//===----------------------------------------------------------------------===//
// SymbolState
//===----------------------------------------------------------------------===//

/// This class contains a record of any parsed top-level symbols. A single
/// instance is shared (by reference) between the top-level parser and every
/// nested parser created for dialect-specific symbol bodies.
struct SymbolState {
  // A map from attribute alias identifier to Attribute.
  llvm::StringMap<Attribute> attributeAliasDefinitions;

  // A map from type alias identifier to Type.
  llvm::StringMap<Type> typeAliasDefinitions;

  /// A stack of locations into the main parser memory buffer, one for each of
  /// the active nested parsers (pushed/popped around each nested symbol parse
  /// in parseExtendedSymbol). Given that some nested parsers, i.e. custom
  /// dialect parsers, operate on a temporary memory buffer, this provides an
  /// anchor point for emitting diagnostics.
  SmallVector<llvm::SMLoc, 1> nestedParserLocs;

  /// The top-level lexer that contains the original memory buffer provided by
  /// the user. This is used by nested parsers to get a properly encoded source
  /// location.
  Lexer *topLevelLexer = nullptr;
};

//===----------------------------------------------------------------------===//
// ParserState
//===----------------------------------------------------------------------===//

/// This class refers to all of the state maintained globally by the parser,
/// such as the current lexer position etc.
struct ParserState {
  /// Create a parser state over 'sourceMgr'. The first token is lexed eagerly,
  /// and the nesting depth is taken from the number of currently active nested
  /// parser locations in 'symbols'.
  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
              SymbolState &symbols)
      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
    // Set the top level lexer for the symbol state if one doesn't exist. The
    // first ParserState created for a SymbolState becomes the top level.
    if (!symbols.topLevelLexer)
      symbols.topLevelLexer = &lex;
  }
  ~ParserState() {
    // Reset the top level lexer if it refers to the lexer in our state, so the
    // shared SymbolState never holds a dangling pointer.
    if (symbols.topLevelLexer == &lex)
      symbols.topLevelLexer = nullptr;
  }
  ParserState(const ParserState &) = delete;
  void operator=(const ParserState &) = delete;

  /// The context we're parsing into.
  MLIRContext *const context;

  /// The lexer for the source file we're parsing.
  Lexer lex;

  /// This is the next token that hasn't been consumed yet.
  Token curToken;

  /// The current state for symbol parsing.
  SymbolState &symbols;

  /// The depth of this parser in the nested parsing stack (0 for the top-level
  /// parser).
  size_t parserDepth;
};

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// This class implements support for parsing global entities like types and
/// shared entities like SSA names.  It is intended to be subclassed by
/// specialized subparsers that include state, e.g. a local symbol table.
class Parser {
public:
  Builder builder;

  Parser(ParserState &state) : builder(state.context), state(state) {}

  // Helper methods to get stuff from the parser-global state.
  ParserState &getState() const { return state; }
  MLIRContext *getContext() const { return state.context; }
  const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }

  /// Parse a comma-separated list of elements up until the specified end token.
  ParseResult
  parseCommaSeparatedListUntil(Token::Kind rightToken,
                               const std::function<ParseResult()> &parseElement,
                               bool allowEmptyList = true);

  /// Parse a comma separated list of elements that must have at least one entry
  /// in it.
  ParseResult
  parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);

  /// Scan the '<'-delimited body of a pretty dialect symbol and extend
  /// 'prettyName' to cover it.
  ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);

  // We have two forms of parsing methods - those that return a non-null
  // pointer on success, and those that return a ParseResult to indicate whether
  // they returned a failure.  The second class fills in by-reference arguments
  // as the results of their action.

  //===--------------------------------------------------------------------===//
  // Error Handling
  //===--------------------------------------------------------------------===//

  /// Emit an error at the current token and return the in-flight diagnostic,
  /// which callers may return directly as a failure.
  InFlightDiagnostic emitError(const Twine &message = {}) {
    return emitError(state.curToken.getLoc(), message);
  }
  InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});

  /// Encode the specified source location information into an attribute for
  /// attachment to the IR.
  Location getEncodedSourceLocation(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can get the encoded source
    // location directly.
    if (state.parserDepth == 0)
      return state.lex.getEncodedSourceLocation(loc);
    // Otherwise, we need to re-encode it to point to the top level buffer.
    return state.symbols.topLevelLexer->getEncodedSourceLocation(
        remapLocationToTopLevelBuffer(loc));
  }

  /// Remaps the given SMLoc to the top level lexer of the parser. This is used
  /// to adjust locations of potentially nested parsers to ensure that they can
  /// be emitted properly as diagnostics.
  llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can return location directly.
    SymbolState &symbols = state.symbols;
    if (state.parserDepth == 0)
      return loc;
    assert(symbols.topLevelLexer && "expected valid top-level lexer");

    // Otherwise, we need to remap the location to the main parser. This is
    // simply offsetting the location onto the location of the last nested
    // parser: the offset of 'loc' within this parser's buffer is added to the
    // anchor recorded for this nesting depth.
    size_t offset = loc.getPointer() - state.lex.getBufferBegin();
    auto *rawLoc =
        symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
    return llvm::SMLoc::getFromPointer(rawLoc);
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//

  /// Return the current token the parser is inspecting.
  const Token &getToken() const { return state.curToken; }
  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (state.curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(state.curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    state.curToken = state.lex.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is.  This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(state.curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Consume the specified token if present and return success.  On failure,
  /// output a diagnostic and return failure.
  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);

  /// Parse an arbitrary type.
  Type parseType();

  /// Parse a complex type.
  Type parseComplexType();

  /// Parse an extended type.
  Type parseExtendedType();

  /// Parse a function type.
  Type parseFunctionType();

  /// Parse a memref type.
  Type parseMemRefType();

  /// Parse a non function type.
  Type parseNonFunctionType();

  /// Parse a tensor type.
  Type parseTensorType();

  /// Parse a tuple type.
  Type parseTupleType();

  /// Parse a vector type.
  VectorType parseVectorType();
  ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true);
  ParseResult parseXInDimensionList();

  /// Parse strided layout specification.
  ParseResult parseStridedLayout(int64_t &offset,
                                 SmallVectorImpl<int64_t> &strides);

  // Parse a brace-delimiter list of comma-separated integers with `?` as an
  // unknown marker.
  // NOTE(review): the exact delimiter is defined by the out-of-view
  // implementation; confirm this comment against it.
  ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute with an optional type.
  Attribute parseAttribute(Type type = {});

  /// Parse an attribute dictionary.
  ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);

  /// Parse an extended attribute.
  Attribute parseExtendedAttr(Type type);

  /// Parse a float attribute.
  Attribute parseFloatAttr(Type type, bool isNegative);

  /// Parse a decimal or a hexadecimal literal, which can be either an integer
  /// or a float attribute.
  Attribute parseDecOrHexAttr(Type type, bool isNegative);

  /// Parse an opaque elements attribute.
  Attribute parseOpaqueElementsAttr(Type attrType);

  /// Parse a dense elements attribute.
  Attribute parseDenseElementsAttr(Type attrType);
  ShapedType parseElementsLiteralType(Type type);

  /// Parse a sparse elements attribute.
  Attribute parseSparseElementsAttr(Type attrType);

  //===--------------------------------------------------------------------===//
  // Location Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an inline location.
  ParseResult parseLocation(LocationAttr &loc);

  /// Parse a raw location instance.
  ParseResult parseLocationInstance(LocationAttr &loc);

  /// Parse a callsite location instance.
  ParseResult parseCallSiteLocation(LocationAttr &loc);

  /// Parse a fused location instance.
  ParseResult parseFusedLocation(LocationAttr &loc);

  /// Parse a name or FileLineCol location instance.
  ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);

  /// Parse an optional trailing location. If no `loc` keyword is present,
  /// 'loc' is left untouched and success is returned.
  ///
  ///   trailing-location     ::= (`loc` `(` location `)`)?
  ///
  ParseResult parseOptionalTrailingLocation(Location &loc) {
    // If there is a 'loc' we parse a trailing location.
    if (!getToken().is(Token::kw_loc))
      return success();

    // Parse the location.
    LocationAttr directLoc;
    if (parseLocation(directLoc))
      return failure();
    loc = directLoc;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Affine Parsing
  //===--------------------------------------------------------------------===//

  /// Parse a reference to either an affine map, or an integer set.
  ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
                                                  IntegerSet &set);
  ParseResult parseAffineMapReference(AffineMap &map);
  ParseResult parseIntegerSetReference(IntegerSet &set);

  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
  ParseResult
  parseAffineMapOfSSAIds(AffineMap &map,
                         function_ref<ParseResult(bool)> parseElement);

private:
  /// The Parser is subclassed and reinstantiated.  Do not add additional
  /// non-trivial state here, add it to the ParserState class.
  ParserState &state;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Parse a comma separated list of elements that must have at least one entry
/// in it. Stops at the first element callback that fails.
ParseResult Parser::parseCommaSeparatedList(
    const std::function<ParseResult()> &parseElement) {
  // Non-empty case starts with an element.
  if (parseElement())
    return failure();

  // Otherwise we have a list of comma separated elements.
  while (consumeIf(Token::comma)) {
    if (parseElement())
      return failure();
  }
  return success();
}

/// Parse a comma-separated list of elements, terminated with an arbitrary
/// token.  This allows empty lists if allowEmptyList is true.
///
///   abstract-list ::= rightToken                  // if allowEmptyList == true
///   abstract-list ::= element (',' element)* rightToken
///
ParseResult Parser::parseCommaSeparatedListUntil(
    Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
    bool allowEmptyList) {
  // Handle the empty case.
  if (getToken().is(rightToken)) {
    if (!allowEmptyList)
      return emitError("expected list element");
    consumeToken(rightToken);
    return success();
  }

  if (parseCommaSeparatedList(parseElement) ||
      parseToken(rightToken, "expected ',' or '" +
                                 Token::getTokenSpelling(rightToken) + "'"))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// DialectAsmParser
//===----------------------------------------------------------------------===//

namespace {
/// This class provides the main implementation of the DialectAsmParser that
/// allows for dialects to parse attributes and types. This allows for dialect
/// hooking into the main MLIR parsing logic. Every operation forwards to the
/// wrapped (nested) Parser.
class CustomDialectAsmParser : public DialectAsmParser {
public:
  /// 'fullSpec' is the complete symbol body handed to the dialect; the name
  /// location is captured from the parser's current token at construction.
  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
        parser(parser) {}
  ~CustomDialectAsmParser() override {}

  /// Emit a diagnostic at the specified location and return failure.
  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
    return parser.emitError(loc, message);
  }

  /// Return a builder which provides useful access to MLIRContext, global
  /// objects like types and attributes.
  Builder &getBuilder() const override { return parser.builder; }

  /// Return the location of the next (not yet consumed) token.
  llvm::SMLoc getCurrentLocation() override {
    return parser.getToken().getLoc();
  }

  /// Return the location of the original name token.
  llvm::SMLoc getNameLoc() const override { return nameLoc; }

  /// Re-encode the given source location as an MLIR location and return it.
  Location getEncodedSourceLoc(llvm::SMLoc loc) override {
    return parser.getEncodedSourceLocation(loc);
  }

  /// Returns the full specification of the symbol being parsed. This allows
  /// for using a separate parser if necessary.
  StringRef getFullSymbolSpec() const override { return fullSpec; }

  /// Parse a floating point value from the stream, accepting an optional
  /// leading '-'. Only decimal float literals are handled.
  ParseResult parseFloat(double &result) override {
    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();

    // Check for a floating point value.
    if (curTok.is(Token::floatliteral)) {
      auto val = curTok.getFloatingPointValue();
      if (!val.hasValue())
        return emitError(curTok.getLoc(), "floating point value too large");
      parser.consumeToken(Token::floatliteral);
      result = negative ? -*val : *val;
      return success();
    }

    // TODO(riverriddle) support hex floating point values.
    return emitError(getCurrentLocation(), "expected floating point literal");
  }

  /// Parse an optional integer value from the stream. Returns None (consuming
  /// nothing) unless the current token is an integer or '-'. A leading '-' is
  /// applied by unsigned negation of the uint64_t value, i.e. the result wraps
  /// modulo 2^64.
  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
    Token curToken = parser.getToken();
    if (curToken.isNot(Token::integer, Token::minus))
      return llvm::None;

    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();
    if (parser.parseToken(Token::integer, "expected integer value"))
      return failure();

    auto val = curTok.getUInt64IntegerValue();
    if (!val)
      return emitError(curTok.getLoc(), "integer value too large");
    result = negative ? -*val : *val;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//

  /// Parse a `->` token.
  ParseResult parseArrow() override {
    return parser.parseToken(Token::arrow, "expected '->'");
  }

  /// Parses a `->` if present.
  ParseResult parseOptionalArrow() override {
    return success(parser.consumeIf(Token::arrow));
  }

  /// Parse a '{' token.
  ParseResult parseLBrace() override {
    return parser.parseToken(Token::l_brace, "expected '{'");
  }

  /// Parse a '{' token if present.
  ParseResult parseOptionalLBrace() override {
    return success(parser.consumeIf(Token::l_brace));
  }

  /// Parse a `}` token.
  ParseResult parseRBrace() override {
    return parser.parseToken(Token::r_brace, "expected '}'");
  }

  /// Parse a `}` token if present.
  ParseResult parseOptionalRBrace() override {
    return success(parser.consumeIf(Token::r_brace));
  }

  /// Parse a `:` token.
  ParseResult parseColon() override {
    return parser.parseToken(Token::colon, "expected ':'");
  }

  /// Parse a `:` token if present.
  ParseResult parseOptionalColon() override {
    return success(parser.consumeIf(Token::colon));
  }

  /// Parse a `,` token.
  ParseResult parseComma() override {
    return parser.parseToken(Token::comma, "expected ','");
  }

  /// Parse a `,` token if present.
  ParseResult parseOptionalComma() override {
    return success(parser.consumeIf(Token::comma));
  }

  /// Parses a `...` if present.
  ParseResult parseOptionalEllipsis() override {
    return success(parser.consumeIf(Token::ellipsis));
  }

  /// Parse a `=` token.
  ParseResult parseEqual() override {
    return parser.parseToken(Token::equal, "expected '='");
  }

  /// Parse a '<' token.
  ParseResult parseLess() override {
    return parser.parseToken(Token::less, "expected '<'");
  }

  /// Parse a `<` token if present.
  ParseResult parseOptionalLess() override {
    return success(parser.consumeIf(Token::less));
  }

  /// Parse a '>' token.
  ParseResult parseGreater() override {
    return parser.parseToken(Token::greater, "expected '>'");
  }

  /// Parse a `>` token if present.
  ParseResult parseOptionalGreater() override {
    return success(parser.consumeIf(Token::greater));
  }

  /// Parse a `(` token.
  ParseResult parseLParen() override {
    return parser.parseToken(Token::l_paren, "expected '('");
  }

  /// Parses a '(' if present.
  ParseResult parseOptionalLParen() override {
    return success(parser.consumeIf(Token::l_paren));
  }

  /// Parse a `)` token.
  ParseResult parseRParen() override {
    return parser.parseToken(Token::r_paren, "expected ')'");
  }

  /// Parses a ')' if present.
  ParseResult parseOptionalRParen() override {
    return success(parser.consumeIf(Token::r_paren));
  }

  /// Parse a `[` token.
  ParseResult parseLSquare() override {
    return parser.parseToken(Token::l_square, "expected '['");
  }

  /// Parses a '[' if present.
  ParseResult parseOptionalLSquare() override {
    return success(parser.consumeIf(Token::l_square));
  }

  /// Parse a `]` token.
  ParseResult parseRSquare() override {
    return parser.parseToken(Token::r_square, "expected ']'");
  }

  /// Parses a ']' if present.
  ParseResult parseOptionalRSquare() override {
    return success(parser.consumeIf(Token::r_square));
  }

  /// Parses a '?' if present.
  ParseResult parseOptionalQuestion() override {
    return success(parser.consumeIf(Token::question));
  }

  /// Parses a '*' if present.
  ParseResult parseOptionalStar() override {
    return success(parser.consumeIf(Token::star));
  }

  /// Returns true if the current token is a bare identifier or a language
  /// keyword token; both are accepted as dialect "keywords".
  bool isCurrentTokenAKeyword() const {
    return parser.getToken().is(Token::bare_identifier) ||
           parser.getToken().isKeyword();
  }

  /// Parse the given keyword if present. On mismatch, returns failure without
  /// consuming anything or emitting a diagnostic.
  ParseResult parseOptionalKeyword(StringRef keyword) override {
    // Check that the current token has the same spelling.
    if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
      return failure();
    parser.consumeToken();
    return success();
  }

  /// Parse a keyword, if present, into 'keyword'. On failure 'keyword' is left
  /// untouched and nothing is consumed.
  ParseResult parseOptionalKeyword(StringRef *keyword) override {
    // Check that the current token is a keyword.
    if (!isCurrentTokenAKeyword())
      return failure();

    *keyword = parser.getTokenSpelling();
    parser.consumeToken();
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute and return it in result.
  ParseResult parseAttribute(Attribute &result, Type type) override {
    result = parser.parseAttribute(type);
    return success(static_cast<bool>(result));
  }

  /// Parse an affine map instance into 'map'.
  ParseResult parseAffineMap(AffineMap &map) override {
    return parser.parseAffineMapReference(map);
  }

  /// Parse an integer set instance into 'set'.
  /// NOTE: despite its name this *parses* an integer set; the name has to match
  /// the DialectAsmParser virtual method it overrides, so it cannot be renamed
  /// here alone.
  ParseResult printIntegerSet(IntegerSet &set) override {
    return parser.parseIntegerSetReference(set);
  }

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary type into 'result'; fails if the type is null.
  ParseResult parseType(Type &result) override {
    result = parser.parseType();
    return success(static_cast<bool>(result));
  }

  /// Parse a ranked dimension list, optionally allowing dynamic dimensions.
  ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic) override {
    return parser.parseDimensionListRanked(dimensions, allowDynamic);
  }

private:
  /// The full symbol specification.
  StringRef fullSpec;

  /// The source location of the dialect symbol.
  SMLoc nameLoc;

  /// The main parser.
  Parser &parser;
};
} // namespace

/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
/// and may be recursive.  Return with the 'prettyName' StringRef encompassing
/// the entire pretty name.
///
///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
///                                  | '(' pretty-dialect-sym-contents+ ')'
///                                  | '[' pretty-dialect-sym-contents+ ']'
///                                  | '{' pretty-dialect-sym-contents+ '}'
///                                  | '[^[<({>\])}\0]+'
///
/// 'prettyName' is extended from its existing start up to the end of the
/// scanned body, so it is expected to point into the same buffer as the
/// current '<' token.
ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
  // Pretty symbol names are a relatively unstructured format that contains a
  // series of properly nested punctuation, with anything else in the middle.
  // Scan ahead to find it and consume it if successful, otherwise emit an
  // error.
  auto *curPtr = getTokenSpelling().data();

  SmallVector<char, 8> nestedPunctuation;

  // Scan over the nested punctuation, bailing out on error and consuming until
  // we find the end.  We know that we're currently looking at the '<', so we
  // can go until we find the matching '>' character. Because the first
  // character pushes onto the stack and the loop exits as soon as the stack
  // empties, pop_back_val() below never runs on an empty stack.
  assert(*curPtr == '<');
  do {
    char c = *curPtr++;
    switch (c) {
    case '\0':
      // This also handles the EOF case.
      return emitError("unexpected nul or EOF in pretty dialect name");
    case '<':
    case '[':
    case '(':
    case '{':
      nestedPunctuation.push_back(c);
      continue;

    case '-':
      // The sequence `->` is treated as special token: skip its '>' so it is
      // not mistaken for a closing angle bracket.
      if (*curPtr == '>')
        ++curPtr;
      continue;

    case '>':
      if (nestedPunctuation.pop_back_val() != '<')
        return emitError("unbalanced '>' character in pretty dialect name");
      break;
    case ']':
      if (nestedPunctuation.pop_back_val() != '[')
        return emitError("unbalanced ']' character in pretty dialect name");
      break;
    case ')':
      if (nestedPunctuation.pop_back_val() != '(')
        return emitError("unbalanced ')' character in pretty dialect name");
      break;
    case '}':
      if (nestedPunctuation.pop_back_val() != '{')
        return emitError("unbalanced '}' character in pretty dialect name");
      break;

    default:
      continue;
    }
  } while (!nestedPunctuation.empty());

  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
  // consuming all this stuff, and return.
  state.lex.resetPointer(curPtr);

  unsigned length = curPtr - prettyName.begin();
  prettyName = StringRef(prettyName.begin(), length);
  consumeToken();
  return success();
}

/// Parse an extended dialect symbol (an alias reference, the verbose
/// `dialect<"data">` form, or the pretty `dialect.name<...>` form) and hand
/// the dialect name and symbol data to 'createSymbol'. While 'createSymbol'
/// runs, the symbol's location (remapped to the top-level buffer) is pushed
/// onto SymbolState::nestedParserLocs so nested parsers can anchor
/// diagnostics.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
                                  SymbolAliasMap &aliases,
                                  CreateFn &&createSymbol) {
  // Parse the dialect namespace.
  StringRef identifier = p.getTokenSpelling().drop_front();
  auto loc = p.getToken().getLoc();
  p.consumeToken(identifierTok);

  // If there is no '<' token following this, and if the typename contains no
  // dot, then we are parsing a symbol alias.
  if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
    // Check for an alias for this type.
    auto aliasIt = aliases.find(identifier);
    if (aliasIt == aliases.end())
      return (p.emitError("undefined symbol alias id '" + identifier + "'"),
              nullptr);
    return aliasIt->second;
  }

  // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
  // a dot, then this is the "pretty" form.  If not, it is the verbose form that
  // looks like <"...">.
  std::string symbolData;
  auto dialectName = identifier;

  // Handle the verbose form, where "identifier" is a simple dialect name.
  if (!identifier.contains('.')) {
    // Consume the '<'.
    if (p.parseToken(Token::less, "expected '<' in dialect type"))
      return nullptr;

    // Parse the symbol specific data.
    if (p.getToken().isNot(Token::string))
      return (p.emitError("expected string literal data in dialect symbol"),
              nullptr);
    symbolData = p.getToken().getStringValue();
    // Presumably points just past the opening quote of the string literal.
    loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
    p.consumeToken(Token::string);

    // Consume the '>'.
    if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
      return nullptr;
  } else {
    // Ok, the dialect name is the part of the identifier before the dot, the
    // part after the dot is the dialect's symbol, or the start thereof.
    auto dotHalves = identifier.split('.');
    dialectName = dotHalves.first;
    auto prettyName = dotHalves.second;
    loc = llvm::SMLoc::getFromPointer(prettyName.data());

    // If the dialect's symbol is followed immediately by a <, then lex the body
    // of it into prettyName.
    if (p.getToken().is(Token::less) &&
        prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
      if (p.parsePrettyDialectSymbolName(prettyName))
        return nullptr;
    }

    symbolData = prettyName.str();
  }

  // Record the name location of the type remapped to the top level buffer.
  llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
  p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);

  // Call into the provided symbol construction function.
  Symbol sym = createSymbol(dialectName, symbolData, loc);

  // Pop the last parser location.
  p.getState().symbols.nestedParserLocs.pop_back();
  return sym;
}

/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
/// parsing failed, a null 'T' is returned. The number of bytes read from the
/// input string is returned in 'numRead'.
///
/// A fresh SourceMgr/ParserState is created over 'inputStr' but shares
/// 'symbolState', so the new parser's depth equals the number of nested parser
/// locations currently recorded.
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
                     SymbolState &symbolState, ParserFn &&parserFn,
                     size_t *numRead = nullptr) {
  SourceMgr sourceMgr;
  auto memBuffer = MemoryBuffer::getMemBuffer(
      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
      /*RequiresNullTerminator=*/false);
  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  ParserState state(sourceMgr, context, symbolState);
  Parser parser(state);

  Token startTok = parser.getToken();
  T symbol = parserFn(parser);
  if (!symbol)
    return T();

  // If 'numRead' is valid, then provide the number of bytes that were read.
  Token endTok = parser.getToken();
  if (numRead) {
    *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
                                   startTok.getLoc().getPointer());

    // Otherwise, ensure that all of the tokens were parsed. If parserFn did
    // not advance the parser at all (start and end token at the same
    // location), trailing input is not diagnosed here.
  } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
    parser.emitError(endTok.getLoc(), "encountered unexpected token");
    return T();
  }
  return symbol;
}

//===----------------------------------------------------------------------===//
// Error Handling
//===----------------------------------------------------------------------===//

/// Emit an error at 'loc' (re-encoded relative to the top-level buffer when
/// nested) and return the in-flight diagnostic.
InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
  auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);

  // If we hit a parse error in response to a lexer error, then the lexer
  // already reported the error.
  if (getToken().is(Token::error))
    diag.abandon();
  return diag;
}

//===----------------------------------------------------------------------===//
// Token Parsing
//===----------------------------------------------------------------------===//

/// Consume the specified token if present and return success.  On failure,
/// output a diagnostic and return failure.
ParseResult Parser::parseToken(Token::Kind expectedToken,
                               const Twine &message) {
  if (consumeIf(expectedToken))
    return success();
  return emitError(message);
}

//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//

/// Parse an arbitrary type.
///
///   type ::= function-type
///          | non-function-type
///
Type Parser::parseType() {
  if (getToken().is(Token::l_paren))
    return parseFunctionType();
  return parseNonFunctionType();
}

/// Parse a function result type.
912 /// 913 /// function-result-type ::= type-list-parens 914 /// | non-function-type 915 /// 916 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 917 if (getToken().is(Token::l_paren)) 918 return parseTypeListParens(elements); 919 920 Type t = parseNonFunctionType(); 921 if (!t) 922 return failure(); 923 elements.push_back(t); 924 return success(); 925 } 926 927 /// Parse a list of types without an enclosing parenthesis. The list must have 928 /// at least one member. 929 /// 930 /// type-list-no-parens ::= type (`,` type)* 931 /// 932 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 933 auto parseElt = [&]() -> ParseResult { 934 auto elt = parseType(); 935 elements.push_back(elt); 936 return elt ? success() : failure(); 937 }; 938 939 return parseCommaSeparatedList(parseElt); 940 } 941 942 /// Parse a parenthesized list of types. 943 /// 944 /// type-list-parens ::= `(` `)` 945 /// | `(` type-list-no-parens `)` 946 /// 947 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 948 if (parseToken(Token::l_paren, "expected '('")) 949 return failure(); 950 951 // Handle empty lists. 952 if (getToken().is(Token::r_paren)) 953 return consumeToken(), success(); 954 955 if (parseTypeListNoParens(elements) || 956 parseToken(Token::r_paren, "expected ')'")) 957 return failure(); 958 return success(); 959 } 960 961 /// Parse a complex type. 962 /// 963 /// complex-type ::= `complex` `<` type `>` 964 /// 965 Type Parser::parseComplexType() { 966 consumeToken(Token::kw_complex); 967 968 // Parse the '<'. 
969 if (parseToken(Token::less, "expected '<' in complex type")) 970 return nullptr; 971 972 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 973 auto elementType = parseType(); 974 if (!elementType || 975 parseToken(Token::greater, "expected '>' in complex type")) 976 return nullptr; 977 978 return ComplexType::getChecked(elementType, typeLocation); 979 } 980 981 /// Parse an extended type. 982 /// 983 /// extended-type ::= (dialect-type | type-alias) 984 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 985 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 986 /// type-alias ::= `!` alias-name 987 /// 988 Type Parser::parseExtendedType() { 989 return parseExtendedSymbol<Type>( 990 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 991 [&](StringRef dialectName, StringRef symbolData, 992 llvm::SMLoc loc) -> Type { 993 // If we found a registered dialect, then ask it to parse the type. 994 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 995 return parseSymbol<Type>( 996 symbolData, state.context, state.symbols, [&](Parser &parser) { 997 CustomDialectAsmParser customParser(symbolData, parser); 998 return dialect->parseType(customParser); 999 }); 1000 } 1001 1002 // Otherwise, form a new opaque type. 1003 return OpaqueType::getChecked( 1004 Identifier::get(dialectName, state.context), symbolData, 1005 state.context, getEncodedSourceLocation(loc)); 1006 }); 1007 } 1008 1009 /// Parse a function type. 
1010 /// 1011 /// function-type ::= type-list-parens `->` function-result-type 1012 /// 1013 Type Parser::parseFunctionType() { 1014 assert(getToken().is(Token::l_paren)); 1015 1016 SmallVector<Type, 4> arguments, results; 1017 if (parseTypeListParens(arguments) || 1018 parseToken(Token::arrow, "expected '->' in function type") || 1019 parseFunctionResultTypes(results)) 1020 return nullptr; 1021 1022 return builder.getFunctionType(arguments, results); 1023 } 1024 1025 /// Parse the offset and strides from a strided layout specification. 1026 /// 1027 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1028 /// 1029 ParseResult Parser::parseStridedLayout(int64_t &offset, 1030 SmallVectorImpl<int64_t> &strides) { 1031 // Parse offset. 1032 consumeToken(Token::kw_offset); 1033 if (!consumeIf(Token::colon)) 1034 return emitError("expected colon after `offset` keyword"); 1035 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1036 bool question = getToken().is(Token::question); 1037 if (!maybeOffset && !question) 1038 return emitError("invalid offset"); 1039 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1040 : MemRefType::getDynamicStrideOrOffset(); 1041 consumeToken(); 1042 1043 if (!consumeIf(Token::comma)) 1044 return emitError("expected comma after offset value"); 1045 1046 // Parse stride list. 1047 if (!consumeIf(Token::kw_strides)) 1048 return emitError("expected `strides` keyword after offset specification"); 1049 if (!consumeIf(Token::colon)) 1050 return emitError("expected colon after `strides` keyword"); 1051 if (failed(parseStrideList(strides))) 1052 return emitError("invalid braces-enclosed stride list"); 1053 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1054 return emitError("invalid memref stride"); 1055 1056 return success(); 1057 } 1058 1059 /// Parse a memref type. 
1060 /// 1061 /// memref-type ::= ranked-memref-type | unranked-memref-type 1062 /// 1063 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1064 /// (`,` semi-affine-map-composition)? (`,` 1065 /// memory-space)? `>` 1066 /// 1067 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1068 /// 1069 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1070 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1071 /// 1072 Type Parser::parseMemRefType() { 1073 consumeToken(Token::kw_memref); 1074 1075 if (parseToken(Token::less, "expected '<' in memref type")) 1076 return nullptr; 1077 1078 bool isUnranked; 1079 SmallVector<int64_t, 4> dimensions; 1080 1081 if (consumeIf(Token::star)) { 1082 // This is an unranked memref type. 1083 isUnranked = true; 1084 if (parseXInDimensionList()) 1085 return nullptr; 1086 1087 } else { 1088 isUnranked = false; 1089 if (parseDimensionListRanked(dimensions)) 1090 return nullptr; 1091 } 1092 1093 // Parse the element type. 1094 auto typeLoc = getToken().getLoc(); 1095 auto elementType = parseType(); 1096 if (!elementType) 1097 return nullptr; 1098 1099 // Parse semi-affine-map-composition. 1100 SmallVector<AffineMap, 2> affineMapComposition; 1101 unsigned memorySpace = 0; 1102 bool parsedMemorySpace = false; 1103 1104 auto parseElt = [&]() -> ParseResult { 1105 if (getToken().is(Token::integer)) { 1106 // Parse memory space. 
1107 if (parsedMemorySpace) 1108 return emitError("multiple memory spaces specified in memref type"); 1109 auto v = getToken().getUnsignedIntegerValue(); 1110 if (!v.hasValue()) 1111 return emitError("invalid memory space in memref type"); 1112 memorySpace = v.getValue(); 1113 consumeToken(Token::integer); 1114 parsedMemorySpace = true; 1115 } else { 1116 if (isUnranked) 1117 return emitError("cannot have affine map for unranked memref type"); 1118 if (parsedMemorySpace) 1119 return emitError("expected memory space to be last in memref type"); 1120 if (getToken().is(Token::kw_offset)) { 1121 int64_t offset; 1122 SmallVector<int64_t, 4> strides; 1123 if (failed(parseStridedLayout(offset, strides))) 1124 return failure(); 1125 // Construct strided affine map. 1126 auto map = makeStridedLinearLayoutMap(strides, offset, 1127 elementType.getContext()); 1128 affineMapComposition.push_back(map); 1129 } else { 1130 // Parse affine map. 1131 auto affineMap = parseAttribute(); 1132 if (!affineMap) 1133 return failure(); 1134 // Verify that the parsed attribute is an affine map. 1135 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 1136 affineMapComposition.push_back(affineMapAttr.getValue()); 1137 else 1138 return emitError("expected affine map in memref type"); 1139 } 1140 } 1141 return success(); 1142 }; 1143 1144 // Parse a list of mappings and address space if present. 1145 if (consumeIf(Token::comma)) { 1146 // Parse comma separated list of affine maps, followed by memory space. 
1147 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 1148 /*allowEmptyList=*/false)) { 1149 return nullptr; 1150 } 1151 } else { 1152 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 1153 return nullptr; 1154 } 1155 1156 if (isUnranked) 1157 return UnrankedMemRefType::getChecked(elementType, memorySpace, 1158 getEncodedSourceLocation(typeLoc)); 1159 1160 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 1161 memorySpace, getEncodedSourceLocation(typeLoc)); 1162 } 1163 1164 /// Parse any type except the function type. 1165 /// 1166 /// non-function-type ::= integer-type 1167 /// | index-type 1168 /// | float-type 1169 /// | extended-type 1170 /// | vector-type 1171 /// | tensor-type 1172 /// | memref-type 1173 /// | complex-type 1174 /// | tuple-type 1175 /// | none-type 1176 /// 1177 /// index-type ::= `index` 1178 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1179 /// none-type ::= `none` 1180 /// 1181 Type Parser::parseNonFunctionType() { 1182 switch (getToken().getKind()) { 1183 default: 1184 return (emitError("expected non-function type"), nullptr); 1185 case Token::kw_memref: 1186 return parseMemRefType(); 1187 case Token::kw_tensor: 1188 return parseTensorType(); 1189 case Token::kw_complex: 1190 return parseComplexType(); 1191 case Token::kw_tuple: 1192 return parseTupleType(); 1193 case Token::kw_vector: 1194 return parseVectorType(); 1195 // integer-type 1196 case Token::inttype: { 1197 auto width = getToken().getIntTypeBitwidth(); 1198 if (!width.hasValue()) 1199 return (emitError("invalid integer width"), nullptr); 1200 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1201 consumeToken(Token::inttype); 1202 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 1203 } 1204 1205 // float-type 1206 case Token::kw_bf16: 1207 consumeToken(Token::kw_bf16); 1208 return builder.getBF16Type(); 1209 case Token::kw_f16: 1210 consumeToken(Token::kw_f16); 1211 return 
builder.getF16Type(); 1212 case Token::kw_f32: 1213 consumeToken(Token::kw_f32); 1214 return builder.getF32Type(); 1215 case Token::kw_f64: 1216 consumeToken(Token::kw_f64); 1217 return builder.getF64Type(); 1218 1219 // index-type 1220 case Token::kw_index: 1221 consumeToken(Token::kw_index); 1222 return builder.getIndexType(); 1223 1224 // none-type 1225 case Token::kw_none: 1226 consumeToken(Token::kw_none); 1227 return builder.getNoneType(); 1228 1229 // extended type 1230 case Token::exclamation_identifier: 1231 return parseExtendedType(); 1232 } 1233 } 1234 1235 /// Parse a tensor type. 1236 /// 1237 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1238 /// dimension-list ::= dimension-list-ranked | `*x` 1239 /// 1240 Type Parser::parseTensorType() { 1241 consumeToken(Token::kw_tensor); 1242 1243 if (parseToken(Token::less, "expected '<' in tensor type")) 1244 return nullptr; 1245 1246 bool isUnranked; 1247 SmallVector<int64_t, 4> dimensions; 1248 1249 if (consumeIf(Token::star)) { 1250 // This is an unranked tensor type. 1251 isUnranked = true; 1252 1253 if (parseXInDimensionList()) 1254 return nullptr; 1255 1256 } else { 1257 isUnranked = false; 1258 if (parseDimensionListRanked(dimensions)) 1259 return nullptr; 1260 } 1261 1262 // Parse the element type. 1263 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 1264 auto elementType = parseType(); 1265 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1266 return nullptr; 1267 1268 if (isUnranked) 1269 return UnrankedTensorType::getChecked(elementType, typeLocation); 1270 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 1271 } 1272 1273 /// Parse a tuple type. 1274 /// 1275 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1276 /// 1277 Type Parser::parseTupleType() { 1278 consumeToken(Token::kw_tuple); 1279 1280 // Parse the '<'. 
1281 if (parseToken(Token::less, "expected '<' in tuple type")) 1282 return nullptr; 1283 1284 // Check for an empty tuple by directly parsing '>'. 1285 if (consumeIf(Token::greater)) 1286 return TupleType::get(getContext()); 1287 1288 // Parse the element types and the '>'. 1289 SmallVector<Type, 4> types; 1290 if (parseTypeListNoParens(types) || 1291 parseToken(Token::greater, "expected '>' in tuple type")) 1292 return nullptr; 1293 1294 return TupleType::get(types, getContext()); 1295 } 1296 1297 /// Parse a vector type. 1298 /// 1299 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1300 /// non-empty-static-dimension-list ::= decimal-literal `x` 1301 /// static-dimension-list 1302 /// static-dimension-list ::= (decimal-literal `x`)* 1303 /// 1304 VectorType Parser::parseVectorType() { 1305 consumeToken(Token::kw_vector); 1306 1307 if (parseToken(Token::less, "expected '<' in vector type")) 1308 return nullptr; 1309 1310 SmallVector<int64_t, 4> dimensions; 1311 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1312 return nullptr; 1313 if (dimensions.empty()) 1314 return (emitError("expected dimension size in vector type"), nullptr); 1315 1316 // Parse the element type. 1317 auto typeLoc = getToken().getLoc(); 1318 auto elementType = parseType(); 1319 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1320 return nullptr; 1321 1322 return VectorType::getChecked(dimensions, elementType, 1323 getEncodedSourceLocation(typeLoc)); 1324 } 1325 1326 /// Parse a dimension list of a tensor or memref type. This populates the 1327 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1328 /// errors out on `?` otherwise. 
1329 /// 1330 /// dimension-list-ranked ::= (dimension `x`)* 1331 /// dimension ::= `?` | decimal-literal 1332 /// 1333 /// When `allowDynamic` is not set, this is used to parse: 1334 /// 1335 /// static-dimension-list ::= (decimal-literal `x`)* 1336 ParseResult 1337 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1338 bool allowDynamic) { 1339 while (getToken().isAny(Token::integer, Token::question)) { 1340 if (consumeIf(Token::question)) { 1341 if (!allowDynamic) 1342 return emitError("expected static shape"); 1343 dimensions.push_back(-1); 1344 } else { 1345 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1346 // aggregate type declarations. Therefore, `0xf32` should be processed as 1347 // a sequence of separate elements `0`, `x`, `f32`. 1348 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1349 // We can get here only if the token is an integer literal. Hexadecimal 1350 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1351 // literal, just `1` would, at which point we don't get into this 1352 // branch). 1353 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1354 dimensions.push_back(0); 1355 state.lex.resetPointer(getTokenSpelling().data() + 1); 1356 consumeToken(); 1357 } else { 1358 // Make sure this integer value is in bound and valid. 1359 auto dimension = getToken().getUnsignedIntegerValue(); 1360 if (!dimension.hasValue()) 1361 return emitError("invalid dimension"); 1362 dimensions.push_back((int64_t)dimension.getValue()); 1363 consumeToken(Token::integer); 1364 } 1365 } 1366 1367 // Make sure we have an 'x' or something like 'xbf32'. 1368 if (parseXInDimensionList()) 1369 return failure(); 1370 } 1371 1372 return success(); 1373 } 1374 1375 /// Parse an 'x' token in a dimension list, handling the case where the x is 1376 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1377 /// token. 
1378 ParseResult Parser::parseXInDimensionList() { 1379 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1380 return emitError("expected 'x' in dimension list"); 1381 1382 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1383 if (getTokenSpelling().size() != 1) 1384 state.lex.resetPointer(getTokenSpelling().data() + 1); 1385 1386 // Consume the 'x'. 1387 consumeToken(Token::bare_identifier); 1388 1389 return success(); 1390 } 1391 1392 // Parse a comma-separated list of dimensions, possibly empty: 1393 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1394 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1395 if (!consumeIf(Token::l_square)) 1396 return failure(); 1397 // Empty list early exit. 1398 if (consumeIf(Token::r_square)) 1399 return success(); 1400 while (true) { 1401 if (consumeIf(Token::question)) { 1402 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1403 } else { 1404 // This must be an integer value. 1405 int64_t val; 1406 if (getToken().getSpelling().getAsInteger(10, val)) 1407 return emitError("invalid integer value: ") << getToken().getSpelling(); 1408 // Make sure it is not the one value for `?`. 1409 if (ShapedType::isDynamic(val)) 1410 return emitError("invalid integer value: ") 1411 << getToken().getSpelling() 1412 << ", use `?` to specify a dynamic dimension"; 1413 dimensions.push_back(val); 1414 consumeToken(Token::integer); 1415 } 1416 if (!consumeIf(Token::comma)) 1417 break; 1418 } 1419 if (!consumeIf(Token::r_square)) 1420 return failure(); 1421 return success(); 1422 } 1423 1424 //===----------------------------------------------------------------------===// 1425 // Attribute parsing. 1426 //===----------------------------------------------------------------------===// 1427 1428 /// Return the symbol reference referred to by the given token, that is known to 1429 /// be an @-identifier. 
1430 static std::string extractSymbolReference(Token tok) { 1431 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1432 StringRef nameStr = tok.getSpelling().drop_front(); 1433 1434 // Check to see if the reference is a string literal, or a bare identifier. 1435 if (nameStr.front() == '"') 1436 return tok.getStringValue(); 1437 return std::string(nameStr); 1438 } 1439 1440 /// Parse an arbitrary attribute. 1441 /// 1442 /// attribute-value ::= `unit` 1443 /// | bool-literal 1444 /// | integer-literal (`:` (index-type | integer-type))? 1445 /// | float-literal (`:` float-type)? 1446 /// | string-literal (`:` type)? 1447 /// | type 1448 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1449 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1450 /// | symbol-ref-id (`::` symbol-ref-id)* 1451 /// | `dense` `<` attribute-value `>` `:` 1452 /// (tensor-type | vector-type) 1453 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1454 /// `:` (tensor-type | vector-type) 1455 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1456 /// `>` `:` (tensor-type | vector-type) 1457 /// | extended-attribute 1458 /// 1459 Attribute Parser::parseAttribute(Type type) { 1460 switch (getToken().getKind()) { 1461 // Parse an AffineMap or IntegerSet attribute. 
1462 case Token::kw_affine_map: { 1463 consumeToken(Token::kw_affine_map); 1464 1465 AffineMap map; 1466 if (parseToken(Token::less, "expected '<' in affine map") || 1467 parseAffineMapReference(map) || 1468 parseToken(Token::greater, "expected '>' in affine map")) 1469 return Attribute(); 1470 return AffineMapAttr::get(map); 1471 } 1472 case Token::kw_affine_set: { 1473 consumeToken(Token::kw_affine_set); 1474 1475 IntegerSet set; 1476 if (parseToken(Token::less, "expected '<' in integer set") || 1477 parseIntegerSetReference(set) || 1478 parseToken(Token::greater, "expected '>' in integer set")) 1479 return Attribute(); 1480 return IntegerSetAttr::get(set); 1481 } 1482 1483 // Parse an array attribute. 1484 case Token::l_square: { 1485 consumeToken(Token::l_square); 1486 1487 SmallVector<Attribute, 4> elements; 1488 auto parseElt = [&]() -> ParseResult { 1489 elements.push_back(parseAttribute()); 1490 return elements.back() ? success() : failure(); 1491 }; 1492 1493 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1494 return nullptr; 1495 return builder.getArrayAttr(elements); 1496 } 1497 1498 // Parse a boolean attribute. 1499 case Token::kw_false: 1500 consumeToken(Token::kw_false); 1501 return builder.getBoolAttr(false); 1502 case Token::kw_true: 1503 consumeToken(Token::kw_true); 1504 return builder.getBoolAttr(true); 1505 1506 // Parse a dense elements attribute. 1507 case Token::kw_dense: 1508 return parseDenseElementsAttr(type); 1509 1510 // Parse a dictionary attribute. 1511 case Token::l_brace: { 1512 SmallVector<NamedAttribute, 4> elements; 1513 if (parseAttributeDict(elements)) 1514 return nullptr; 1515 return builder.getDictionaryAttr(elements); 1516 } 1517 1518 // Parse an extended attribute, i.e. alias or dialect attribute. 1519 case Token::hash_identifier: 1520 return parseExtendedAttr(type); 1521 1522 // Parse floating point and integer attributes. 
1523 case Token::floatliteral: 1524 return parseFloatAttr(type, /*isNegative=*/false); 1525 case Token::integer: 1526 return parseDecOrHexAttr(type, /*isNegative=*/false); 1527 case Token::minus: { 1528 consumeToken(Token::minus); 1529 if (getToken().is(Token::integer)) 1530 return parseDecOrHexAttr(type, /*isNegative=*/true); 1531 if (getToken().is(Token::floatliteral)) 1532 return parseFloatAttr(type, /*isNegative=*/true); 1533 1534 return (emitError("expected constant integer or floating point value"), 1535 nullptr); 1536 } 1537 1538 // Parse a location attribute. 1539 case Token::kw_loc: { 1540 LocationAttr attr; 1541 return failed(parseLocation(attr)) ? Attribute() : attr; 1542 } 1543 1544 // Parse an opaque elements attribute. 1545 case Token::kw_opaque: 1546 return parseOpaqueElementsAttr(type); 1547 1548 // Parse a sparse elements attribute. 1549 case Token::kw_sparse: 1550 return parseSparseElementsAttr(type); 1551 1552 // Parse a string attribute. 1553 case Token::string: { 1554 auto val = getToken().getStringValue(); 1555 consumeToken(Token::string); 1556 // Parse the optional trailing colon type if one wasn't explicitly provided. 1557 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1558 return Attribute(); 1559 1560 return type ? StringAttr::get(val, type) 1561 : StringAttr::get(val, getContext()); 1562 } 1563 1564 // Parse a symbol reference attribute. 1565 case Token::at_identifier: { 1566 std::string nameStr = extractSymbolReference(getToken()); 1567 consumeToken(Token::at_identifier); 1568 1569 // Parse any nested references. 1570 std::vector<FlatSymbolRefAttr> nestedRefs; 1571 while (getToken().is(Token::colon)) { 1572 // Check for the '::' prefix. 1573 const char *curPointer = getToken().getLoc().getPointer(); 1574 consumeToken(Token::colon); 1575 if (!consumeIf(Token::colon)) { 1576 state.lex.resetPointer(curPointer); 1577 consumeToken(); 1578 break; 1579 } 1580 // Parse the reference itself. 
1581 auto curLoc = getToken().getLoc(); 1582 if (getToken().isNot(Token::at_identifier)) { 1583 emitError(curLoc, "expected nested symbol reference identifier"); 1584 return Attribute(); 1585 } 1586 1587 std::string nameStr = extractSymbolReference(getToken()); 1588 consumeToken(Token::at_identifier); 1589 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1590 } 1591 1592 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1593 } 1594 1595 // Parse a 'unit' attribute. 1596 case Token::kw_unit: 1597 consumeToken(Token::kw_unit); 1598 return builder.getUnitAttr(); 1599 1600 default: 1601 // Parse a type attribute. 1602 if (Type type = parseType()) 1603 return TypeAttr::get(type); 1604 return nullptr; 1605 } 1606 } 1607 1608 /// Attribute dictionary. 1609 /// 1610 /// attribute-dict ::= `{` `}` 1611 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1612 /// attribute-entry ::= bare-id `=` attribute-value 1613 /// 1614 ParseResult 1615 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1616 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1617 return failure(); 1618 1619 auto parseElt = [&]() -> ParseResult { 1620 // We allow keywords as attribute names. 1621 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1622 !getToken().isKeyword()) 1623 return emitError("expected attribute name"); 1624 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1625 consumeToken(); 1626 1627 // Try to parse the '=' for the attribute value. 1628 if (!consumeIf(Token::equal)) { 1629 // If there is no '=', we treat this as a unit attribute. 
1630 attributes.push_back({nameId, builder.getUnitAttr()}); 1631 return success(); 1632 } 1633 1634 auto attr = parseAttribute(); 1635 if (!attr) 1636 return failure(); 1637 1638 attributes.push_back({nameId, attr}); 1639 return success(); 1640 }; 1641 1642 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1643 return failure(); 1644 1645 return success(); 1646 } 1647 1648 /// Parse an extended attribute. 1649 /// 1650 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1651 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1652 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1653 /// attribute-alias ::= `#` alias-name 1654 /// 1655 Attribute Parser::parseExtendedAttr(Type type) { 1656 Attribute attr = parseExtendedSymbol<Attribute>( 1657 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1658 [&](StringRef dialectName, StringRef symbolData, 1659 llvm::SMLoc loc) -> Attribute { 1660 // Parse an optional trailing colon type. 1661 Type attrType = type; 1662 if (consumeIf(Token::colon) && !(attrType = parseType())) 1663 return Attribute(); 1664 1665 // If we found a registered dialect, then ask it to parse the attribute. 1666 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1667 return parseSymbol<Attribute>( 1668 symbolData, state.context, state.symbols, [&](Parser &parser) { 1669 CustomDialectAsmParser customParser(symbolData, parser); 1670 return dialect->parseAttribute(customParser, attrType); 1671 }); 1672 } 1673 1674 // Otherwise, form a new opaque attribute. 1675 return OpaqueAttr::getChecked( 1676 Identifier::get(dialectName, state.context), symbolData, 1677 attrType ? attrType : NoneType::get(state.context), 1678 getEncodedSourceLocation(loc)); 1679 }); 1680 1681 // Ensure that the attribute has the same type as requested. 
1682 if (attr && type && attr.getType() != type) { 1683 emitError("attribute type different than expected: expected ") 1684 << type << ", but got " << attr.getType(); 1685 return nullptr; 1686 } 1687 return attr; 1688 } 1689 1690 /// Parse a float attribute. 1691 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1692 auto val = getToken().getFloatingPointValue(); 1693 if (!val.hasValue()) 1694 return (emitError("floating point value too large for attribute"), nullptr); 1695 consumeToken(Token::floatliteral); 1696 if (!type) { 1697 // Default to F64 when no type is specified. 1698 if (!consumeIf(Token::colon)) 1699 type = builder.getF64Type(); 1700 else if (!(type = parseType())) 1701 return nullptr; 1702 } 1703 if (!type.isa<FloatType>()) 1704 return (emitError("floating point value not valid for specified type"), 1705 nullptr); 1706 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1707 } 1708 1709 /// Construct a float attribute bitwise equivalent to the integer literal. 1710 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1711 uint64_t value) { 1712 // FIXME: bfloat is currently stored as a double internally because it doesn't 1713 // have valid APFloat semantics. 1714 if (type.isF64() || type.isBF16()) { 1715 APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1716 return p->builder.getFloatAttr(type, apFloat); 1717 } 1718 1719 APInt apInt(type.getWidth(), value); 1720 if (apInt != value) { 1721 p->emitError("hexadecimal float constant out of range for type"); 1722 return nullptr; 1723 } 1724 APFloat apFloat(type.getFloatSemantics(), apInt); 1725 return p->builder.getFloatAttr(type, apFloat); 1726 } 1727 1728 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1729 /// or a float attribute. 
1730 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1731 auto val = getToken().getUInt64IntegerValue(); 1732 if (!val.hasValue()) 1733 return (emitError("integer constant out of range for attribute"), nullptr); 1734 1735 // Remember if the literal is hexadecimal. 1736 StringRef spelling = getToken().getSpelling(); 1737 auto loc = state.curToken.getLoc(); 1738 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1739 1740 consumeToken(Token::integer); 1741 if (!type) { 1742 // Default to i64 if not type is specified. 1743 if (!consumeIf(Token::colon)) 1744 type = builder.getIntegerType(64); 1745 else if (!(type = parseType())) 1746 return nullptr; 1747 } 1748 1749 if (auto floatType = type.dyn_cast<FloatType>()) { 1750 if (isNegative) 1751 return emitError( 1752 loc, 1753 "hexadecimal float literal should not have a leading minus"), 1754 nullptr; 1755 if (!isHex) { 1756 emitError(loc, "unexpected decimal integer literal for a float attribute") 1757 .attachNote() 1758 << "add a trailing dot to make the literal a float"; 1759 return nullptr; 1760 } 1761 1762 // Construct a float attribute bitwise equivalent to the integer literal. 1763 return buildHexadecimalFloatLiteral(this, floatType, *val); 1764 } 1765 1766 if (!type.isIntOrIndex()) 1767 return emitError(loc, "integer literal not valid for specified type"), 1768 nullptr; 1769 1770 // Parse the integer literal. 1771 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1772 APInt apInt(width, *val, isNegative); 1773 if (apInt != *val) 1774 return emitError(loc, "integer constant out of range for attribute"), 1775 nullptr; 1776 1777 // Otherwise construct an integer attribute. 1778 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1779 return emitError(loc, "integer constant out of range for attribute"), 1780 nullptr; 1781 1782 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1783 } 1784 1785 /// Parse an opaque elements attribute. 
1786 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1787 consumeToken(Token::kw_opaque); 1788 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1789 return nullptr; 1790 1791 if (getToken().isNot(Token::string)) 1792 return (emitError("expected dialect namespace"), nullptr); 1793 1794 auto name = getToken().getStringValue(); 1795 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1796 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1797 // attribute. Otherwise, it can't be roundtripped without having the dialect 1798 // registered. 1799 if (!dialect) 1800 return (emitError("no registered dialect with namespace '" + name + "'"), 1801 nullptr); 1802 1803 consumeToken(Token::string); 1804 if (parseToken(Token::comma, "expected ','")) 1805 return nullptr; 1806 1807 if (getToken().getKind() != Token::string) 1808 return (emitError("opaque string should start with '0x'"), nullptr); 1809 1810 auto val = getToken().getStringValue(); 1811 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1812 return (emitError("opaque string should start with '0x'"), nullptr); 1813 1814 val = val.substr(2); 1815 if (!llvm::all_of(val, llvm::isHexDigit)) 1816 return (emitError("opaque string only contains hex digits"), nullptr); 1817 1818 consumeToken(Token::string); 1819 if (parseToken(Token::greater, "expected '>'")) 1820 return nullptr; 1821 1822 auto type = parseElementsLiteralType(attrType); 1823 if (!type) 1824 return nullptr; 1825 1826 return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); 1827 } 1828 1829 namespace { 1830 class TensorLiteralParser { 1831 public: 1832 TensorLiteralParser(Parser &p) : p(p) {} 1833 1834 ParseResult parse() { 1835 if (p.getToken().is(Token::l_square)) 1836 return parseList(shape); 1837 return parseElement(); 1838 } 1839 1840 /// Build a dense attribute instance with the parsed elements and the given 1841 /// shaped type. 
1842 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1843 1844 ArrayRef<int64_t> getShape() const { return shape; } 1845 1846 private: 1847 enum class ElementKind { Boolean, Integer, Float }; 1848 1849 /// Return a string to represent the given element kind. 1850 const char *getElementKindStr(ElementKind kind) { 1851 switch (kind) { 1852 case ElementKind::Boolean: 1853 return "'boolean'"; 1854 case ElementKind::Integer: 1855 return "'integer'"; 1856 case ElementKind::Float: 1857 return "'float'"; 1858 } 1859 llvm_unreachable("unknown element kind"); 1860 } 1861 1862 /// Build a Dense Integer attribute for the given type. 1863 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1864 IntegerType eltTy); 1865 1866 /// Build a Dense Float attribute for the given type. 1867 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1868 FloatType eltTy); 1869 1870 /// Parse a single element, returning failure if it isn't a valid element 1871 /// literal. For example: 1872 /// parseElement(1) -> Success, 1 1873 /// parseElement([1]) -> Failure 1874 ParseResult parseElement(); 1875 1876 /// Parse a list of either lists or elements, returning the dimensions of the 1877 /// parsed sub-tensors in dims. For example: 1878 /// parseList([1, 2, 3]) -> Success, [3] 1879 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1880 /// parseList([[1, 2], 3]) -> Failure 1881 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1882 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1883 1884 Parser &p; 1885 1886 /// The shape inferred from the parsed elements. 1887 SmallVector<int64_t, 4> shape; 1888 1889 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1890 std::vector<std::pair<bool, Token>> storage; 1891 1892 /// A flag that indicates the type of elements that have been parsed. 
1893 Optional<ElementKind> knownEltKind; 1894 }; 1895 } // namespace 1896 1897 /// Build a dense attribute instance with the parsed elements and the given 1898 /// shaped type. 1899 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1900 ShapedType type) { 1901 // Check that the parsed storage size has the same number of elements to the 1902 // type, or is a known splat. 1903 if (!shape.empty() && getShape() != type.getShape()) { 1904 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1905 << "]) does not match type ([" << type.getShape() << "])"; 1906 return nullptr; 1907 } 1908 1909 // If the type is an integer, build a set of APInt values from the storage 1910 // with the correct bitwidth. 1911 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1912 return getIntAttr(loc, type, intTy); 1913 1914 // Otherwise, this must be a floating point type. 1915 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1916 if (!floatTy) { 1917 p.emitError(loc) << "expected floating-point or integer element type, got " 1918 << type.getElementType(); 1919 return nullptr; 1920 } 1921 return getFloatAttr(loc, type, floatTy); 1922 } 1923 1924 /// Build a Dense Integer attribute for the given type. 1925 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1926 ShapedType type, 1927 IntegerType eltTy) { 1928 std::vector<APInt> intElements; 1929 intElements.reserve(storage.size()); 1930 for (const auto &signAndToken : storage) { 1931 bool isNegative = signAndToken.first; 1932 const Token &token = signAndToken.second; 1933 1934 // Check to see if floating point values were parsed. 
1935 if (token.is(Token::floatliteral)) { 1936 p.emitError() << "expected integer elements, but parsed floating-point"; 1937 return nullptr; 1938 } 1939 1940 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1941 "unexpected token type"); 1942 if (token.isAny(Token::kw_true, Token::kw_false)) { 1943 if (!eltTy.isInteger(1)) 1944 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1945 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1946 /*isSigned=*/false); 1947 intElements.push_back(apInt); 1948 continue; 1949 } 1950 1951 // Create APInt values for each element with the correct bitwidth. 1952 auto val = token.getUInt64IntegerValue(); 1953 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1954 : (int64_t)val.getValue() < 0)) { 1955 p.emitError(token.getLoc(), 1956 "integer constant out of range for attribute"); 1957 return nullptr; 1958 } 1959 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1960 if (apInt != val.getValue()) 1961 return (p.emitError("integer constant out of range for type"), nullptr); 1962 intElements.push_back(isNegative ? -apInt : apInt); 1963 } 1964 1965 return DenseElementsAttr::get(type, intElements); 1966 } 1967 1968 /// Build a Dense Float attribute for the given type. 1969 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1970 ShapedType type, 1971 FloatType eltTy) { 1972 std::vector<Attribute> floatValues; 1973 floatValues.reserve(storage.size()); 1974 for (const auto &signAndToken : storage) { 1975 bool isNegative = signAndToken.first; 1976 const Token &token = signAndToken.second; 1977 1978 // Handle hexadecimal float literals. 
1979 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1980 if (isNegative) { 1981 p.emitError(token.getLoc()) 1982 << "hexadecimal float literal should not have a leading minus"; 1983 return nullptr; 1984 } 1985 auto val = token.getUInt64IntegerValue(); 1986 if (!val.hasValue()) { 1987 p.emitError("hexadecimal float constant out of range for attribute"); 1988 return nullptr; 1989 } 1990 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1991 if (!attr) 1992 return nullptr; 1993 floatValues.push_back(attr); 1994 continue; 1995 } 1996 1997 // Check to see if any decimal integers or booleans were parsed. 1998 if (!token.is(Token::floatliteral)) { 1999 p.emitError() << "expected floating-point elements, but parsed integer"; 2000 return nullptr; 2001 } 2002 2003 // Build the float values from tokens. 2004 auto val = token.getFloatingPointValue(); 2005 if (!val.hasValue()) { 2006 p.emitError("floating point value too large for attribute"); 2007 return nullptr; 2008 } 2009 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 2010 } 2011 2012 return DenseElementsAttr::get(type, floatValues); 2013 } 2014 2015 ParseResult TensorLiteralParser::parseElement() { 2016 switch (p.getToken().getKind()) { 2017 // Parse a boolean element. 2018 case Token::kw_true: 2019 case Token::kw_false: 2020 case Token::floatliteral: 2021 case Token::integer: 2022 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2023 p.consumeToken(); 2024 break; 2025 2026 // Parse a signed integer or a negative floating-point element. 
2027 case Token::minus: 2028 p.consumeToken(Token::minus); 2029 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2030 return p.emitError("expected integer or floating point literal"); 2031 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2032 p.consumeToken(); 2033 break; 2034 2035 default: 2036 return p.emitError("expected element literal of primitive type"); 2037 } 2038 2039 return success(); 2040 } 2041 2042 /// Parse a list of either lists or elements, returning the dimensions of the 2043 /// parsed sub-tensors in dims. For example: 2044 /// parseList([1, 2, 3]) -> Success, [3] 2045 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2046 /// parseList([[1, 2], 3]) -> Failure 2047 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2048 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2049 p.consumeToken(Token::l_square); 2050 2051 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2052 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2053 if (prevDims == newDims) 2054 return success(); 2055 return p.emitError("tensor literal is invalid; ranks are not consistent " 2056 "between elements"); 2057 }; 2058 2059 bool first = true; 2060 SmallVector<int64_t, 4> newDims; 2061 unsigned size = 0; 2062 auto parseCommaSeparatedList = [&]() -> ParseResult { 2063 SmallVector<int64_t, 4> thisDims; 2064 if (p.getToken().getKind() == Token::l_square) { 2065 if (parseList(thisDims)) 2066 return failure(); 2067 } else if (parseElement()) { 2068 return failure(); 2069 } 2070 ++size; 2071 if (!first) 2072 return checkDims(newDims, thisDims); 2073 newDims = thisDims; 2074 first = false; 2075 return success(); 2076 }; 2077 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2078 return failure(); 2079 2080 // Return the sublists' dimensions with 'size' prepended. 
2081 dims.clear(); 2082 dims.push_back(size); 2083 dims.append(newDims.begin(), newDims.end()); 2084 return success(); 2085 } 2086 2087 /// Parse a dense elements attribute. 2088 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2089 consumeToken(Token::kw_dense); 2090 if (parseToken(Token::less, "expected '<' after 'dense'")) 2091 return nullptr; 2092 2093 // Parse the literal data. 2094 TensorLiteralParser literalParser(*this); 2095 if (literalParser.parse()) 2096 return nullptr; 2097 2098 if (parseToken(Token::greater, "expected '>'")) 2099 return nullptr; 2100 2101 auto typeLoc = getToken().getLoc(); 2102 auto type = parseElementsLiteralType(attrType); 2103 if (!type) 2104 return nullptr; 2105 return literalParser.getAttr(typeLoc, type); 2106 } 2107 2108 /// Shaped type for elements attribute. 2109 /// 2110 /// elements-literal-type ::= vector-type | ranked-tensor-type 2111 /// 2112 /// This method also checks the type has static shape. 2113 ShapedType Parser::parseElementsLiteralType(Type type) { 2114 // If the user didn't provide a type, parse the colon type for the literal. 2115 if (!type) { 2116 if (parseToken(Token::colon, "expected ':'")) 2117 return nullptr; 2118 if (!(type = parseType())) 2119 return nullptr; 2120 } 2121 2122 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2123 emitError("elements literal must be a ranked tensor or vector type"); 2124 return nullptr; 2125 } 2126 2127 auto sType = type.cast<ShapedType>(); 2128 if (!sType.hasStaticShape()) 2129 return (emitError("elements literal type must have static shape"), nullptr); 2130 2131 return sType; 2132 } 2133 2134 /// Parse a sparse elements attribute. 
2135 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2136 consumeToken(Token::kw_sparse); 2137 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2138 return nullptr; 2139 2140 /// Parse indices 2141 auto indicesLoc = getToken().getLoc(); 2142 TensorLiteralParser indiceParser(*this); 2143 if (indiceParser.parse()) 2144 return nullptr; 2145 2146 if (parseToken(Token::comma, "expected ','")) 2147 return nullptr; 2148 2149 /// Parse values. 2150 auto valuesLoc = getToken().getLoc(); 2151 TensorLiteralParser valuesParser(*this); 2152 if (valuesParser.parse()) 2153 return nullptr; 2154 2155 if (parseToken(Token::greater, "expected '>'")) 2156 return nullptr; 2157 2158 auto type = parseElementsLiteralType(attrType); 2159 if (!type) 2160 return nullptr; 2161 2162 // If the indices are a splat, i.e. the literal parser parsed an element and 2163 // not a list, we set the shape explicitly. The indices are represented by a 2164 // 2-dimensional shape where the second dimension is the rank of the type. 2165 // Given that the parsed indices is a splat, we know that we only have one 2166 // indice and thus one for the first dimension. 2167 auto indiceEltType = builder.getIntegerType(64); 2168 ShapedType indicesType; 2169 if (indiceParser.getShape().empty()) { 2170 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2171 } else { 2172 // Otherwise, set the shape to the one parsed by the literal parser. 2173 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2174 } 2175 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2176 2177 // If the values are a splat, set the shape explicitly based on the number of 2178 // indices. The number of indices is encoded in the first dimension of the 2179 // indice shape type. 2180 auto valuesEltType = type.getElementType(); 2181 ShapedType valuesType = 2182 valuesParser.getShape().empty() 2183 ? 
RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2184 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2185 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2186 2187 /// Sanity check. 2188 if (valuesType.getRank() != 1) 2189 return (emitError("expected 1-d tensor for values"), nullptr); 2190 2191 auto sameShape = (indicesType.getRank() == 1) || 2192 (type.getRank() == indicesType.getDimSize(1)); 2193 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2194 if (!sameShape || !sameElementNum) { 2195 emitError() << "expected shape ([" << type.getShape() 2196 << "]); inferred shape of indices literal ([" 2197 << indicesType.getShape() 2198 << "]); inferred shape of values literal ([" 2199 << valuesType.getShape() << "])"; 2200 return nullptr; 2201 } 2202 2203 // Build the sparse elements attribute by the indices and values. 2204 return SparseElementsAttr::get(type, indices, values); 2205 } 2206 2207 //===----------------------------------------------------------------------===// 2208 // Location parsing. 2209 //===----------------------------------------------------------------------===// 2210 2211 /// Parse a location. 2212 /// 2213 /// location ::= `loc` inline-location 2214 /// inline-location ::= '(' location-inst ')' 2215 /// 2216 ParseResult Parser::parseLocation(LocationAttr &loc) { 2217 // Check for 'loc' identifier. 2218 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2219 return emitError(); 2220 2221 // Parse the inline-location. 2222 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2223 parseLocationInstance(loc) || 2224 parseToken(Token::r_paren, "expected ')' in inline location")) 2225 return failure(); 2226 return success(); 2227 } 2228 2229 /// Specific location instances. 
2230 /// 2231 /// location-inst ::= filelinecol-location | 2232 /// name-location | 2233 /// callsite-location | 2234 /// fused-location | 2235 /// unknown-location 2236 /// filelinecol-location ::= string-literal ':' integer-literal 2237 /// ':' integer-literal 2238 /// name-location ::= string-literal 2239 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2240 /// fused-location ::= fused ('<' attribute-value '>')? 2241 /// '[' location-inst (location-inst ',')* ']' 2242 /// unknown-location ::= 'unknown' 2243 /// 2244 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2245 consumeToken(Token::bare_identifier); 2246 2247 // Parse the '('. 2248 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2249 return failure(); 2250 2251 // Parse the callee location. 2252 LocationAttr calleeLoc; 2253 if (parseLocationInstance(calleeLoc)) 2254 return failure(); 2255 2256 // Parse the 'at'. 2257 if (getToken().isNot(Token::bare_identifier) || 2258 getToken().getSpelling() != "at") 2259 return emitError("expected 'at' in callsite location"); 2260 consumeToken(Token::bare_identifier); 2261 2262 // Parse the caller location. 2263 LocationAttr callerLoc; 2264 if (parseLocationInstance(callerLoc)) 2265 return failure(); 2266 2267 // Parse the ')'. 2268 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2269 return failure(); 2270 2271 // Return the callsite location. 2272 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2273 return success(); 2274 } 2275 2276 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2277 consumeToken(Token::bare_identifier); 2278 2279 // Try to parse the optional metadata. 2280 Attribute metadata; 2281 if (consumeIf(Token::less)) { 2282 metadata = parseAttribute(); 2283 if (!metadata) 2284 return emitError("expected valid attribute metadata"); 2285 // Parse the '>' token. 
2286 if (parseToken(Token::greater, 2287 "expected '>' after fused location metadata")) 2288 return failure(); 2289 } 2290 2291 SmallVector<Location, 4> locations; 2292 auto parseElt = [&] { 2293 LocationAttr newLoc; 2294 if (parseLocationInstance(newLoc)) 2295 return failure(); 2296 locations.push_back(newLoc); 2297 return success(); 2298 }; 2299 2300 if (parseToken(Token::l_square, "expected '[' in fused location") || 2301 parseCommaSeparatedList(parseElt) || 2302 parseToken(Token::r_square, "expected ']' in fused location")) 2303 return failure(); 2304 2305 // Return the fused location. 2306 loc = FusedLoc::get(locations, metadata, getContext()); 2307 return success(); 2308 } 2309 2310 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2311 auto *ctx = getContext(); 2312 auto str = getToken().getStringValue(); 2313 consumeToken(Token::string); 2314 2315 // If the next token is ':' this is a filelinecol location. 2316 if (consumeIf(Token::colon)) { 2317 // Parse the line number. 2318 if (getToken().isNot(Token::integer)) 2319 return emitError("expected integer line number in FileLineColLoc"); 2320 auto line = getToken().getUnsignedIntegerValue(); 2321 if (!line.hasValue()) 2322 return emitError("expected integer line number in FileLineColLoc"); 2323 consumeToken(Token::integer); 2324 2325 // Parse the ':'. 2326 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2327 return failure(); 2328 2329 // Parse the column number. 2330 if (getToken().isNot(Token::integer)) 2331 return emitError("expected integer column number in FileLineColLoc"); 2332 auto column = getToken().getUnsignedIntegerValue(); 2333 if (!column.hasValue()) 2334 return emitError("expected integer column number in FileLineColLoc"); 2335 consumeToken(Token::integer); 2336 2337 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2338 return success(); 2339 } 2340 2341 // Otherwise, this is a NameLoc. 2342 2343 // Check for a child location. 
2344 if (consumeIf(Token::l_paren)) { 2345 auto childSourceLoc = getToken().getLoc(); 2346 2347 // Parse the child location. 2348 LocationAttr childLoc; 2349 if (parseLocationInstance(childLoc)) 2350 return failure(); 2351 2352 // The child must not be another NameLoc. 2353 if (childLoc.isa<NameLoc>()) 2354 return emitError(childSourceLoc, 2355 "child of NameLoc cannot be another NameLoc"); 2356 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2357 2358 // Parse the closing ')'. 2359 if (parseToken(Token::r_paren, 2360 "expected ')' after child location of NameLoc")) 2361 return failure(); 2362 } else { 2363 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2364 } 2365 2366 return success(); 2367 } 2368 2369 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2370 // Handle either name or filelinecol locations. 2371 if (getToken().is(Token::string)) 2372 return parseNameOrFileLineColLocation(loc); 2373 2374 // Bare tokens required for other cases. 2375 if (!getToken().is(Token::bare_identifier)) 2376 return emitError("expected location instance"); 2377 2378 // Check for the 'callsite' signifying a callsite location. 2379 if (getToken().getSpelling() == "callsite") 2380 return parseCallSiteLocation(loc); 2381 2382 // If the token is 'fused', then this is a fused location. 2383 if (getToken().getSpelling() == "fused") 2384 return parseFusedLocation(loc); 2385 2386 // Check for a 'unknown' for an unknown location. 2387 if (getToken().getSpelling() == "unknown") { 2388 consumeToken(Token::bare_identifier); 2389 loc = UnknownLoc::get(getContext()); 2390 return success(); 2391 } 2392 2393 return emitError("expected location instance"); 2394 } 2395 2396 //===----------------------------------------------------------------------===// 2397 // Affine parsing. 2398 //===----------------------------------------------------------------------===// 2399 2400 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 2401 /// the boolean sense. 2402 enum AffineLowPrecOp { 2403 /// Null value. 2404 LNoOp, 2405 Add, 2406 Sub 2407 }; 2408 2409 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2410 /// in the boolean sense. 2411 enum AffineHighPrecOp { 2412 /// Null value. 2413 HNoOp, 2414 Mul, 2415 FloorDiv, 2416 CeilDiv, 2417 Mod 2418 }; 2419 2420 namespace { 2421 /// This is a specialized parser for affine structures (affine maps, affine 2422 /// expressions, and integer sets), maintaining the state transient to their 2423 /// bodies. 2424 class AffineParser : public Parser { 2425 public: 2426 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2427 function_ref<ParseResult(bool)> parseElement = nullptr) 2428 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2429 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2430 2431 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2432 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2433 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2434 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 2435 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2436 unsigned &numDims); 2437 2438 private: 2439 // Binary affine op parsing. 2440 AffineLowPrecOp consumeIfLowPrecOp(); 2441 AffineHighPrecOp consumeIfHighPrecOp(); 2442 2443 // Identifier lists for polyhedral structures. 
2444 ParseResult parseDimIdList(unsigned &numDims); 2445 ParseResult parseSymbolIdList(unsigned &numSymbols); 2446 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2447 unsigned &numSymbols); 2448 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2449 2450 AffineExpr parseAffineExpr(); 2451 AffineExpr parseParentheticalExpr(); 2452 AffineExpr parseNegateExpression(AffineExpr lhs); 2453 AffineExpr parseIntegerExpr(); 2454 AffineExpr parseBareIdExpr(); 2455 AffineExpr parseSSAIdExpr(bool isSymbol); 2456 AffineExpr parseSymbolSSAIdExpr(); 2457 2458 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2459 AffineExpr rhs, SMLoc opLoc); 2460 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2461 AffineExpr rhs); 2462 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2463 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2464 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2465 SMLoc llhsOpLoc); 2466 AffineExpr parseAffineConstraint(bool *isEq); 2467 2468 private: 2469 bool allowParsingSSAIds; 2470 function_ref<ParseResult(bool)> parseElement; 2471 unsigned numDimOperands; 2472 unsigned numSymbolOperands; 2473 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2474 }; 2475 } // end anonymous namespace 2476 2477 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2478 /// opLoc is the location of the op token to be used to report errors 2479 /// for non-conforming expressions. 2480 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2481 AffineExpr lhs, AffineExpr rhs, 2482 SMLoc opLoc) { 2483 // TODO: make the error location info accurate. 
2484 switch (op) { 2485 case Mul: 2486 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2487 emitError(opLoc, "non-affine expression: at least one of the multiply " 2488 "operands has to be either a constant or symbolic"); 2489 return nullptr; 2490 } 2491 return lhs * rhs; 2492 case FloorDiv: 2493 if (!rhs.isSymbolicOrConstant()) { 2494 emitError(opLoc, "non-affine expression: right operand of floordiv " 2495 "has to be either a constant or symbolic"); 2496 return nullptr; 2497 } 2498 return lhs.floorDiv(rhs); 2499 case CeilDiv: 2500 if (!rhs.isSymbolicOrConstant()) { 2501 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2502 "has to be either a constant or symbolic"); 2503 return nullptr; 2504 } 2505 return lhs.ceilDiv(rhs); 2506 case Mod: 2507 if (!rhs.isSymbolicOrConstant()) { 2508 emitError(opLoc, "non-affine expression: right operand of mod " 2509 "has to be either a constant or symbolic"); 2510 return nullptr; 2511 } 2512 return lhs % rhs; 2513 case HNoOp: 2514 llvm_unreachable("can't create affine expression for null high prec op"); 2515 return nullptr; 2516 } 2517 llvm_unreachable("Unknown AffineHighPrecOp"); 2518 } 2519 2520 /// Create an affine binary low precedence op expression (add, sub). 2521 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2522 AffineExpr lhs, AffineExpr rhs) { 2523 switch (op) { 2524 case AffineLowPrecOp::Add: 2525 return lhs + rhs; 2526 case AffineLowPrecOp::Sub: 2527 return lhs - rhs; 2528 case AffineLowPrecOp::LNoOp: 2529 llvm_unreachable("can't create affine expression for null low prec op"); 2530 return nullptr; 2531 } 2532 llvm_unreachable("Unknown AffineLowPrecOp"); 2533 } 2534 2535 /// Consume this token if it is a lower precedence affine op (there are only 2536 /// two precedence levels). 
2537 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2538 switch (getToken().getKind()) { 2539 case Token::plus: 2540 consumeToken(Token::plus); 2541 return AffineLowPrecOp::Add; 2542 case Token::minus: 2543 consumeToken(Token::minus); 2544 return AffineLowPrecOp::Sub; 2545 default: 2546 return AffineLowPrecOp::LNoOp; 2547 } 2548 } 2549 2550 /// Consume this token if it is a higher precedence affine op (there are only 2551 /// two precedence levels) 2552 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2553 switch (getToken().getKind()) { 2554 case Token::star: 2555 consumeToken(Token::star); 2556 return Mul; 2557 case Token::kw_floordiv: 2558 consumeToken(Token::kw_floordiv); 2559 return FloorDiv; 2560 case Token::kw_ceildiv: 2561 consumeToken(Token::kw_ceildiv); 2562 return CeilDiv; 2563 case Token::kw_mod: 2564 consumeToken(Token::kw_mod); 2565 return Mod; 2566 default: 2567 return HNoOp; 2568 } 2569 } 2570 2571 /// Parse a high precedence op expression list: mul, div, and mod are high 2572 /// precedence binary ops, i.e., parse a 2573 /// expr_1 op_1 expr_2 op_2 ... expr_n 2574 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2575 /// All affine binary ops are left associative. 2576 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2577 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2578 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2579 /// report an error for non-conforming expressions. 2580 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2581 AffineHighPrecOp llhsOp, 2582 SMLoc llhsOpLoc) { 2583 AffineExpr lhs = parseAffineOperandExpr(llhs); 2584 if (!lhs) 2585 return nullptr; 2586 2587 // Found an LHS. Parse the remaining expression. 
2588 auto opLoc = getToken().getLoc(); 2589 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2590 if (llhs) { 2591 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2592 if (!expr) 2593 return nullptr; 2594 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2595 } 2596 // No LLHS, get RHS 2597 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2598 } 2599 2600 // This is the last operand in this expression. 2601 if (llhs) 2602 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2603 2604 // No llhs, 'lhs' itself is the expression. 2605 return lhs; 2606 } 2607 2608 /// Parse an affine expression inside parentheses. 2609 /// 2610 /// affine-expr ::= `(` affine-expr `)` 2611 AffineExpr AffineParser::parseParentheticalExpr() { 2612 if (parseToken(Token::l_paren, "expected '('")) 2613 return nullptr; 2614 if (getToken().is(Token::r_paren)) 2615 return (emitError("no expression inside parentheses"), nullptr); 2616 2617 auto expr = parseAffineExpr(); 2618 if (!expr) 2619 return nullptr; 2620 if (parseToken(Token::r_paren, "expected ')'")) 2621 return nullptr; 2622 2623 return expr; 2624 } 2625 2626 /// Parse the negation expression. 2627 /// 2628 /// affine-expr ::= `-` affine-expr 2629 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2630 if (parseToken(Token::minus, "expected '-'")) 2631 return nullptr; 2632 2633 AffineExpr operand = parseAffineOperandExpr(lhs); 2634 // Since negation has the highest precedence of all ops (including high 2635 // precedence ops) but lower than parentheses, we are only going to use 2636 // parseAffineOperandExpr instead of parseAffineExpr here. 2637 if (!operand) 2638 // Extra error message although parseAffineOperandExpr would have 2639 // complained. Leads to a better diagnostic. 2640 return (emitError("missing operand of negation"), nullptr); 2641 return (-1) * operand; 2642 } 2643 2644 /// Parse a bare id that may appear in an affine expression. 
2645 /// 2646 /// affine-expr ::= bare-id 2647 AffineExpr AffineParser::parseBareIdExpr() { 2648 if (getToken().isNot(Token::bare_identifier)) 2649 return (emitError("expected bare identifier"), nullptr); 2650 2651 StringRef sRef = getTokenSpelling(); 2652 for (auto entry : dimsAndSymbols) { 2653 if (entry.first == sRef) { 2654 consumeToken(Token::bare_identifier); 2655 return entry.second; 2656 } 2657 } 2658 2659 return (emitError("use of undeclared identifier"), nullptr); 2660 } 2661 2662 /// Parse an SSA id which may appear in an affine expression. 2663 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2664 if (!allowParsingSSAIds) 2665 return (emitError("unexpected ssa identifier"), nullptr); 2666 if (getToken().isNot(Token::percent_identifier)) 2667 return (emitError("expected ssa identifier"), nullptr); 2668 auto name = getTokenSpelling(); 2669 // Check if we already parsed this SSA id. 2670 for (auto entry : dimsAndSymbols) { 2671 if (entry.first == name) { 2672 consumeToken(Token::percent_identifier); 2673 return entry.second; 2674 } 2675 } 2676 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2677 if (parseElement(isSymbol)) 2678 return (emitError("failed to parse ssa identifier"), nullptr); 2679 auto idExpr = isSymbol 2680 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2681 : getAffineDimExpr(numDimOperands++, getContext()); 2682 dimsAndSymbols.push_back({name, idExpr}); 2683 return idExpr; 2684 } 2685 2686 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2687 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2688 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2689 return nullptr; 2690 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2691 if (!symbolExpr) 2692 return nullptr; 2693 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2694 return nullptr; 2695 return symbolExpr; 2696 } 2697 2698 /// Parse a positive integral constant appearing in an affine expression. 2699 /// 2700 /// affine-expr ::= integer-literal 2701 AffineExpr AffineParser::parseIntegerExpr() { 2702 auto val = getToken().getUInt64IntegerValue(); 2703 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2704 return (emitError("constant too large for index"), nullptr); 2705 2706 consumeToken(Token::integer); 2707 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2708 } 2709 2710 /// Parses an expression that can be a valid operand of an affine expression. 2711 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2712 /// operator, the rhs of which is being parsed. This is used to determine 2713 /// whether an error should be emitted for a missing right operand. 2714 // Eg: for an expression without parentheses (like i + j + k + l), each 2715 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2716 // operand expression, it's an op expression and will be parsed via 2717 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2718 // -l are valid operands that will be parsed by this function. 
2719 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2720 switch (getToken().getKind()) { 2721 case Token::bare_identifier: 2722 return parseBareIdExpr(); 2723 case Token::kw_symbol: 2724 return parseSymbolSSAIdExpr(); 2725 case Token::percent_identifier: 2726 return parseSSAIdExpr(/*isSymbol=*/false); 2727 case Token::integer: 2728 return parseIntegerExpr(); 2729 case Token::l_paren: 2730 return parseParentheticalExpr(); 2731 case Token::minus: 2732 return parseNegateExpression(lhs); 2733 case Token::kw_ceildiv: 2734 case Token::kw_floordiv: 2735 case Token::kw_mod: 2736 case Token::plus: 2737 case Token::star: 2738 if (lhs) 2739 emitError("missing right operand of binary operator"); 2740 else 2741 emitError("missing left operand of binary operator"); 2742 return nullptr; 2743 default: 2744 if (lhs) 2745 emitError("missing right operand of binary operator"); 2746 else 2747 emitError("expected affine expression"); 2748 return nullptr; 2749 } 2750 } 2751 2752 /// Parse affine expressions that are bare-id's, integer constants, 2753 /// parenthetical affine expressions, and affine op expressions that are a 2754 /// composition of those. 2755 /// 2756 /// All binary op's associate from left to right. 2757 /// 2758 /// {add, sub} have lower precedence than {mul, div, and mod}. 2759 /// 2760 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2761 /// ceildiv, and mod are at the same higher precedence level. Negation has 2762 /// higher precedence than any binary op. 2763 /// 2764 /// llhs: the affine expression appearing on the left of the one being parsed. 2765 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2766 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2767 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2768 /// associativity. 
2769 /// 2770 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2771 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2772 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2773 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2774 AffineLowPrecOp llhsOp) { 2775 AffineExpr lhs; 2776 if (!(lhs = parseAffineOperandExpr(llhs))) 2777 return nullptr; 2778 2779 // Found an LHS. Deal with the ops. 2780 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2781 if (llhs) { 2782 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2783 return parseAffineLowPrecOpExpr(sum, lOp); 2784 } 2785 // No LLHS, get RHS and form the expression. 2786 return parseAffineLowPrecOpExpr(lhs, lOp); 2787 } 2788 auto opLoc = getToken().getLoc(); 2789 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2790 // We have a higher precedence op here. Get the rhs operand for the llhs 2791 // through parseAffineHighPrecOpExpr. 2792 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2793 if (!highRes) 2794 return nullptr; 2795 2796 // If llhs is null, the product forms the first operand of the yet to be 2797 // found expression. If non-null, the op to associate with llhs is llhsOp. 2798 AffineExpr expr = 2799 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2800 2801 // Recurse for subsequent low prec op's after the affine high prec op 2802 // expression. 2803 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2804 return parseAffineLowPrecOpExpr(expr, nextOp); 2805 return expr; 2806 } 2807 // Last operand in the expression list. 2808 if (llhs) 2809 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2810 // No llhs, 'lhs' itself is the expression. 2811 return lhs; 2812 } 2813 2814 /// Parse an affine expression. 
2815 /// affine-expr ::= `(` affine-expr `)` 2816 /// | `-` affine-expr 2817 /// | affine-expr `+` affine-expr 2818 /// | affine-expr `-` affine-expr 2819 /// | affine-expr `*` affine-expr 2820 /// | affine-expr `floordiv` affine-expr 2821 /// | affine-expr `ceildiv` affine-expr 2822 /// | affine-expr `mod` affine-expr 2823 /// | bare-id 2824 /// | integer-literal 2825 /// 2826 /// Additional conditions are checked depending on the production. For eg., 2827 /// one of the operands for `*` has to be either constant/symbolic; the second 2828 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2829 AffineExpr AffineParser::parseAffineExpr() { 2830 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2831 } 2832 2833 /// Parse a dim or symbol from the lists appearing before the actual 2834 /// expressions of the affine map. Update our state to store the 2835 /// dimensional/symbolic identifier. 2836 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2837 if (getToken().isNot(Token::bare_identifier)) 2838 return emitError("expected bare identifier"); 2839 2840 auto name = getTokenSpelling(); 2841 for (auto entry : dimsAndSymbols) { 2842 if (entry.first == name) 2843 return emitError("redefinition of identifier '" + name + "'"); 2844 } 2845 consumeToken(Token::bare_identifier); 2846 2847 dimsAndSymbols.push_back({name, idExpr}); 2848 return success(); 2849 } 2850 2851 /// Parse the list of dimensional identifiers to an affine map. 
2852 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2853 if (parseToken(Token::l_paren, 2854 "expected '(' at start of dimensional identifiers list")) { 2855 return failure(); 2856 } 2857 2858 auto parseElt = [&]() -> ParseResult { 2859 auto dimension = getAffineDimExpr(numDims++, getContext()); 2860 return parseIdentifierDefinition(dimension); 2861 }; 2862 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2863 } 2864 2865 /// Parse the list of symbolic identifiers to an affine map. 2866 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2867 consumeToken(Token::l_square); 2868 auto parseElt = [&]() -> ParseResult { 2869 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2870 return parseIdentifierDefinition(symbol); 2871 }; 2872 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2873 } 2874 2875 /// Parse the list of symbolic identifiers to an affine map. 2876 ParseResult 2877 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2878 unsigned &numSymbols) { 2879 if (parseDimIdList(numDims)) { 2880 return failure(); 2881 } 2882 if (!getToken().is(Token::l_square)) { 2883 numSymbols = 0; 2884 return success(); 2885 } 2886 return parseSymbolIdList(numSymbols); 2887 } 2888 2889 /// Parses an ambiguous affine map or integer set definition inline. 2890 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2891 IntegerSet &set) { 2892 unsigned numDims = 0, numSymbols = 0; 2893 2894 // List of dimensional and optional symbol identifiers. 2895 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2896 return failure(); 2897 } 2898 2899 // This is needed for parsing attributes as we wouldn't know whether we would 2900 // be parsing an integer set attribute or an affine map attribute. 
2901 bool isArrow = getToken().is(Token::arrow); 2902 bool isColon = getToken().is(Token::colon); 2903 if (!isArrow && !isColon) { 2904 return emitError("expected '->' or ':'"); 2905 } else if (isArrow) { 2906 parseToken(Token::arrow, "expected '->' or '['"); 2907 map = parseAffineMapRange(numDims, numSymbols); 2908 return map ? success() : failure(); 2909 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2910 return failure(); 2911 } 2912 2913 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2914 return success(); 2915 2916 return failure(); 2917 } 2918 2919 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2920 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2921 if (parseToken(Token::l_square, "expected '['")) 2922 return failure(); 2923 2924 SmallVector<AffineExpr, 4> exprs; 2925 auto parseElt = [&]() -> ParseResult { 2926 auto elt = parseAffineExpr(); 2927 exprs.push_back(elt); 2928 return elt ? success() : failure(); 2929 }; 2930 2931 // Parse a multi-dimensional affine expression (a comma-separated list of 2932 // 1-d affine expressions); the list cannot be empty. Grammar: 2933 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2934 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2935 /*allowEmptyList=*/true)) 2936 return failure(); 2937 // Parsed a valid affine map. 2938 if (exprs.empty()) 2939 map = AffineMap::get(getContext()); 2940 else 2941 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 2942 exprs); 2943 return success(); 2944 } 2945 2946 /// Parse the range and sizes affine map definition inline. 
2947 /// 2948 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2949 /// 2950 /// multi-dim-affine-expr ::= `(` `)` 2951 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 2952 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2953 unsigned numSymbols) { 2954 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2955 2956 SmallVector<AffineExpr, 4> exprs; 2957 auto parseElt = [&]() -> ParseResult { 2958 auto elt = parseAffineExpr(); 2959 ParseResult res = elt ? success() : failure(); 2960 exprs.push_back(elt); 2961 return res; 2962 }; 2963 2964 // Parse a multi-dimensional affine expression (a comma-separated list of 2965 // 1-d affine expressions); the list cannot be empty. Grammar: 2966 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2967 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2968 return AffineMap(); 2969 2970 if (exprs.empty()) 2971 return AffineMap::get(getContext()); 2972 2973 // Parsed a valid affine map. 2974 return AffineMap::get(numDims, numSymbols, exprs); 2975 } 2976 2977 /// Parse an affine constraint. 2978 /// affine-constraint ::= affine-expr `>=` `0` 2979 /// | affine-expr `==` `0` 2980 /// 2981 /// isEq is set to true if the parsed constraint is an equality, false if it 2982 /// is an inequality (greater than or equal). 
2983 /// 2984 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2985 AffineExpr expr = parseAffineExpr(); 2986 if (!expr) 2987 return nullptr; 2988 2989 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2990 getToken().is(Token::integer)) { 2991 auto dim = getToken().getUnsignedIntegerValue(); 2992 if (dim.hasValue() && dim.getValue() == 0) { 2993 consumeToken(Token::integer); 2994 *isEq = false; 2995 return expr; 2996 } 2997 return (emitError("expected '0' after '>='"), nullptr); 2998 } 2999 3000 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3001 getToken().is(Token::integer)) { 3002 auto dim = getToken().getUnsignedIntegerValue(); 3003 if (dim.hasValue() && dim.getValue() == 0) { 3004 consumeToken(Token::integer); 3005 *isEq = true; 3006 return expr; 3007 } 3008 return (emitError("expected '0' after '=='"), nullptr); 3009 } 3010 3011 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3012 nullptr); 3013 } 3014 3015 /// Parse the constraints that are part of an integer set definition. 3016 /// integer-set-inline 3017 /// ::= dim-and-symbol-id-lists `:` 3018 /// '(' affine-constraint-conjunction? ')' 3019 /// affine-constraint-conjunction ::= affine-constraint (`,` 3020 /// affine-constraint)* 3021 /// 3022 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3023 unsigned numSymbols) { 3024 if (parseToken(Token::l_paren, 3025 "expected '(' at start of integer set constraint list")) 3026 return IntegerSet(); 3027 3028 SmallVector<AffineExpr, 4> constraints; 3029 SmallVector<bool, 4> isEqs; 3030 auto parseElt = [&]() -> ParseResult { 3031 bool isEq; 3032 auto elt = parseAffineConstraint(&isEq); 3033 ParseResult res = elt ? success() : failure(); 3034 if (elt) { 3035 constraints.push_back(elt); 3036 isEqs.push_back(isEq); 3037 } 3038 return res; 3039 }; 3040 3041 // Parse a list of affine constraints (comma-separated). 
3042 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3043 return IntegerSet(); 3044 3045 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3046 if (constraints.empty()) { 3047 /* 0 == 0 */ 3048 auto zero = getAffineConstantExpr(0, getContext()); 3049 return IntegerSet::get(numDims, numSymbols, zero, true); 3050 } 3051 3052 // Parsed a valid integer set. 3053 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3054 } 3055 3056 /// Parse an ambiguous reference to either and affine map or an integer set. 3057 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3058 IntegerSet &set) { 3059 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3060 } 3061 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3062 llvm::SMLoc curLoc = getToken().getLoc(); 3063 IntegerSet set; 3064 if (parseAffineMapOrIntegerSetReference(map, set)) 3065 return failure(); 3066 if (set) 3067 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3068 return success(); 3069 } 3070 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3071 llvm::SMLoc curLoc = getToken().getLoc(); 3072 AffineMap map; 3073 if (parseAffineMapOrIntegerSetReference(map, set)) 3074 return failure(); 3075 if (map) 3076 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3077 return success(); 3078 } 3079 3080 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3081 /// parse SSA value uses encountered while parsing affine expressions. 
3082 ParseResult 3083 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3084 function_ref<ParseResult(bool)> parseElement) { 3085 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3086 .parseAffineMapOfSSAIds(map); 3087 } 3088 3089 //===----------------------------------------------------------------------===// 3090 // OperationParser 3091 //===----------------------------------------------------------------------===// 3092 3093 namespace { 3094 /// This class provides support for parsing operations and regions of 3095 /// operations. 3096 class OperationParser : public Parser { 3097 public: 3098 OperationParser(ParserState &state, ModuleOp moduleOp) 3099 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3100 } 3101 3102 ~OperationParser(); 3103 3104 /// After parsing is finished, this function must be called to see if there 3105 /// are any remaining issues. 3106 ParseResult finalize(); 3107 3108 //===--------------------------------------------------------------------===// 3109 // SSA Value Handling 3110 //===--------------------------------------------------------------------===// 3111 3112 /// This represents a use of an SSA value in the program. The first two 3113 /// entries in the tuple are the name and result number of a reference. The 3114 /// third is the location of the reference, which is used in case this ends 3115 /// up being a use of an undefined value. 3116 struct SSAUseInfo { 3117 StringRef name; // Value name, e.g. %42 or %abc 3118 unsigned number; // Number, specified with #12 3119 SMLoc loc; // Location of first definition or use. 3120 }; 3121 3122 /// Push a new SSA name scope to the parser. 3123 void pushSSANameScope(bool isIsolated); 3124 3125 /// Pop the last SSA name scope from the parser. 3126 ParseResult popSSANameScope(); 3127 3128 /// Register a definition of a value with the symbol table. 
3129 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3130 3131 /// Parse an optional list of SSA uses into 'results'. 3132 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3133 3134 /// Parse a single SSA use into 'result'. 3135 ParseResult parseSSAUse(SSAUseInfo &result); 3136 3137 /// Given a reference to an SSA value and its type, return a reference. This 3138 /// returns null on failure. 3139 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3140 3141 ParseResult parseSSADefOrUseAndType( 3142 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3143 3144 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3145 3146 /// Return the location of the value identified by its name and number if it 3147 /// has been already reference. 3148 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3149 auto &values = isolatedNameScopes.back().values; 3150 if (!values.count(name) || number >= values[name].size()) 3151 return {}; 3152 if (values[name][number].first) 3153 return values[name][number].second; 3154 return {}; 3155 } 3156 3157 //===--------------------------------------------------------------------===// 3158 // Operation Parsing 3159 //===--------------------------------------------------------------------===// 3160 3161 /// Parse an operation instance. 3162 ParseResult parseOperation(); 3163 3164 /// Parse a single operation successor and its operand list. 3165 ParseResult parseSuccessorAndUseList(Block *&dest, 3166 SmallVectorImpl<Value> &operands); 3167 3168 /// Parse a comma-separated list of operation successors in brackets. 3169 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations, 3170 SmallVectorImpl<SmallVector<Value, 4>> &operands); 3171 3172 /// Parse an operation instance that is in the generic form. 
3173 Operation *parseGenericOperation(); 3174 3175 /// Parse an operation instance that is in the generic form and insert it at 3176 /// the provided insertion point. 3177 Operation *parseGenericOperation(Block *insertBlock, 3178 Block::iterator insertPt); 3179 3180 /// Parse an operation instance that is in the op-defined custom form. 3181 Operation *parseCustomOperation(); 3182 3183 //===--------------------------------------------------------------------===// 3184 // Region Parsing 3185 //===--------------------------------------------------------------------===// 3186 3187 /// Parse a region into 'region' with the provided entry block arguments. 3188 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3189 /// isolated from those above. 3190 ParseResult parseRegion(Region ®ion, 3191 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3192 bool isIsolatedNameScope = false); 3193 3194 /// Parse a region body into 'region'. 3195 ParseResult parseRegionBody(Region ®ion); 3196 3197 //===--------------------------------------------------------------------===// 3198 // Block Parsing 3199 //===--------------------------------------------------------------------===// 3200 3201 /// Parse a new block into 'block'. 3202 ParseResult parseBlock(Block *&block); 3203 3204 /// Parse a list of operations into 'block'. 3205 ParseResult parseBlockBody(Block *block); 3206 3207 /// Parse a (possibly empty) list of block arguments. 3208 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3209 Block *owner); 3210 3211 /// Get the block with the specified name, creating it if it doesn't 3212 /// already exist. The location specified is the point of use, which allows 3213 /// us to diagnose references to blocks that are not defined precisely. 3214 Block *getBlockNamed(StringRef name, SMLoc loc); 3215 3216 /// Define the block with the specified name. Returns the Block* or nullptr in 3217 /// the case of redefinition. 
3218 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3219 3220 private: 3221 /// Returns the info for a block at the current scope for the given name. 3222 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3223 return blocksByName.back()[name]; 3224 } 3225 3226 /// Insert a new forward reference to the given block. 3227 void insertForwardRef(Block *block, SMLoc loc) { 3228 forwardRef.back().try_emplace(block, loc); 3229 } 3230 3231 /// Erase any forward reference to the given block. 3232 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3233 3234 /// Record that a definition was added at the current scope. 3235 void recordDefinition(StringRef def); 3236 3237 /// Get the value entry for the given SSA name. 3238 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3239 3240 /// Create a forward reference placeholder value with the given location and 3241 /// result type. 3242 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3243 3244 /// Return true if this is a forward reference. 3245 bool isForwardRefPlaceholder(Value value) { 3246 return forwardRefPlaceholders.count(value); 3247 } 3248 3249 /// This struct represents an isolated SSA name scope. This scope may contain 3250 /// other nested non-isolated scopes. These scopes are used for operations 3251 /// that are known to be isolated to allow for reusing names within their 3252 /// regions, even if those names are used above. 3253 struct IsolatedSSANameScope { 3254 /// Record that a definition was added at the current scope. 3255 void recordDefinition(StringRef def) { 3256 definitionsPerScope.back().insert(def); 3257 } 3258 3259 /// Push a nested name scope. 3260 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3261 3262 /// Pop a nested name scope. 
3263 void popSSANameScope() { 3264 for (auto &def : definitionsPerScope.pop_back_val()) 3265 values.erase(def.getKey()); 3266 } 3267 3268 /// This keeps track of all of the SSA values we are tracking for each name 3269 /// scope, indexed by their name. This has one entry per result number. 3270 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3271 3272 /// This keeps track of all of the values defined by a specific name scope. 3273 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3274 }; 3275 3276 /// A list of isolated name scopes. 3277 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3278 3279 /// This keeps track of the block names as well as the location of the first 3280 /// reference for each nested name scope. This is used to diagnose invalid 3281 /// block references and memorize them. 3282 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3283 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3284 3285 /// These are all of the placeholders we've made along with the location of 3286 /// their first reference, to allow checking for use of undefined values. 3287 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3288 3289 /// The builder used when creating parsed operation instances. 3290 OpBuilder opBuilder; 3291 3292 /// The top level module operation. 3293 ModuleOp moduleOp; 3294 }; 3295 } // end anonymous namespace 3296 3297 OperationParser::~OperationParser() { 3298 for (auto &fwd : forwardRefPlaceholders) { 3299 // Drop all uses of undefined forward declared reference and destroy 3300 // defining operation. 3301 fwd.first.dropAllUses(); 3302 fwd.first.getDefiningOp()->destroy(); 3303 } 3304 } 3305 3306 /// After parsing is finished, this function must be called to see if there are 3307 /// any remaining issues. 3308 ParseResult OperationParser::finalize() { 3309 // Check for any forward references that are left. If we find any, error 3310 // out. 
3311 if (!forwardRefPlaceholders.empty()) { 3312 SmallVector<std::pair<const char *, Value>, 4> errors; 3313 // Iteration over the map isn't deterministic, so sort by source location. 3314 for (auto entry : forwardRefPlaceholders) 3315 errors.push_back({entry.second.getPointer(), entry.first}); 3316 llvm::array_pod_sort(errors.begin(), errors.end()); 3317 3318 for (auto entry : errors) { 3319 auto loc = SMLoc::getFromPointer(entry.first); 3320 emitError(loc, "use of undeclared SSA value name"); 3321 } 3322 return failure(); 3323 } 3324 3325 return success(); 3326 } 3327 3328 //===----------------------------------------------------------------------===// 3329 // SSA Value Handling 3330 //===----------------------------------------------------------------------===// 3331 3332 void OperationParser::pushSSANameScope(bool isIsolated) { 3333 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3334 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3335 3336 // Push back a new name definition scope. 3337 if (isIsolated) 3338 isolatedNameScopes.push_back({}); 3339 isolatedNameScopes.back().pushSSANameScope(); 3340 } 3341 3342 ParseResult OperationParser::popSSANameScope() { 3343 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3344 3345 // Verify that all referenced blocks were defined. 3346 if (!forwardRefInCurrentScope.empty()) { 3347 SmallVector<std::pair<const char *, Block *>, 4> errors; 3348 // Iteration over the map isn't deterministic, so sort by source location. 3349 for (auto entry : forwardRefInCurrentScope) { 3350 errors.push_back({entry.second.getPointer(), entry.first}); 3351 // Add this block to the top-level region to allow for automatic cleanup. 
3352 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3353 } 3354 llvm::array_pod_sort(errors.begin(), errors.end()); 3355 3356 for (auto entry : errors) { 3357 auto loc = SMLoc::getFromPointer(entry.first); 3358 emitError(loc, "reference to an undefined block"); 3359 } 3360 return failure(); 3361 } 3362 3363 // Pop the next nested namescope. If there is only one internal namescope, 3364 // just pop the isolated scope. 3365 auto ¤tNameScope = isolatedNameScopes.back(); 3366 if (currentNameScope.definitionsPerScope.size() == 1) 3367 isolatedNameScopes.pop_back(); 3368 else 3369 currentNameScope.popSSANameScope(); 3370 3371 blocksByName.pop_back(); 3372 return success(); 3373 } 3374 3375 /// Register a definition of a value with the symbol table. 3376 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3377 auto &entries = getSSAValueEntry(useInfo.name); 3378 3379 // Make sure there is a slot for this value. 3380 if (entries.size() <= useInfo.number) 3381 entries.resize(useInfo.number + 1); 3382 3383 // If we already have an entry for this, check to see if it was a definition 3384 // or a forward reference. 3385 if (auto existing = entries[useInfo.number].first) { 3386 if (!isForwardRefPlaceholder(existing)) { 3387 return emitError(useInfo.loc) 3388 .append("redefinition of SSA value '", useInfo.name, "'") 3389 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3390 .append("previously defined here"); 3391 } 3392 3393 // If it was a forward reference, update everything that used it to use 3394 // the actual definition instead, delete the forward ref, and remove it 3395 // from our set of forward references we track. 3396 existing.replaceAllUsesWith(value); 3397 existing.getDefiningOp()->destroy(); 3398 forwardRefPlaceholders.erase(existing); 3399 } 3400 3401 /// Record this definition for the current scope. 
3402 entries[useInfo.number] = {value, useInfo.loc}; 3403 recordDefinition(useInfo.name); 3404 return success(); 3405 } 3406 3407 /// Parse a (possibly empty) list of SSA operands. 3408 /// 3409 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3410 /// ssa-use-list-opt ::= ssa-use-list? 3411 /// 3412 ParseResult 3413 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3414 if (getToken().isNot(Token::percent_identifier)) 3415 return success(); 3416 return parseCommaSeparatedList([&]() -> ParseResult { 3417 SSAUseInfo result; 3418 if (parseSSAUse(result)) 3419 return failure(); 3420 results.push_back(result); 3421 return success(); 3422 }); 3423 } 3424 3425 /// Parse a SSA operand for an operation. 3426 /// 3427 /// ssa-use ::= ssa-id 3428 /// 3429 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3430 result.name = getTokenSpelling(); 3431 result.number = 0; 3432 result.loc = getToken().getLoc(); 3433 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3434 return failure(); 3435 3436 // If we have an attribute ID, it is a result number. 3437 if (getToken().is(Token::hash_identifier)) { 3438 if (auto value = getToken().getHashIdentifierNumber()) 3439 result.number = value.getValue(); 3440 else 3441 return emitError("invalid SSA value result number"); 3442 consumeToken(Token::hash_identifier); 3443 } 3444 3445 return success(); 3446 } 3447 3448 /// Given an unbound reference to an SSA value and its type, return the value 3449 /// it specifies. This returns null on failure. 3450 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3451 auto &entries = getSSAValueEntry(useInfo.name); 3452 3453 // If we have already seen a value of this name, return it. 3454 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3455 auto result = entries[useInfo.number].first; 3456 // Check that the type matches the other uses. 
3457 if (result.getType() == type) 3458 return result; 3459 3460 emitError(useInfo.loc, "use of value '") 3461 .append(useInfo.name, 3462 "' expects different type than prior uses: ", type, " vs ", 3463 result.getType()) 3464 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3465 .append("prior use here"); 3466 return nullptr; 3467 } 3468 3469 // Make sure we have enough slots for this. 3470 if (entries.size() <= useInfo.number) 3471 entries.resize(useInfo.number + 1); 3472 3473 // If the value has already been defined and this is an overly large result 3474 // number, diagnose that. 3475 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3476 return (emitError(useInfo.loc, "reference to invalid result number"), 3477 nullptr); 3478 3479 // Otherwise, this is a forward reference. Create a placeholder and remember 3480 // that we did so. 3481 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3482 entries[useInfo.number].first = result; 3483 entries[useInfo.number].second = useInfo.loc; 3484 return result; 3485 } 3486 3487 /// Parse an SSA use with an associated type. 3488 /// 3489 /// ssa-use-and-type ::= ssa-use `:` type 3490 ParseResult OperationParser::parseSSADefOrUseAndType( 3491 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3492 SSAUseInfo useInfo; 3493 if (parseSSAUse(useInfo) || 3494 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3495 return failure(); 3496 3497 auto type = parseType(); 3498 if (!type) 3499 return failure(); 3500 3501 return action(useInfo, type); 3502 } 3503 3504 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3505 /// followed by a type list. 
3506 /// 3507 /// ssa-use-and-type-list 3508 /// ::= ssa-use-list ':' type-list-no-parens 3509 /// 3510 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3511 SmallVectorImpl<Value> &results) { 3512 SmallVector<SSAUseInfo, 4> valueIDs; 3513 if (parseOptionalSSAUseList(valueIDs)) 3514 return failure(); 3515 3516 // If there were no operands, then there is no colon or type lists. 3517 if (valueIDs.empty()) 3518 return success(); 3519 3520 SmallVector<Type, 4> types; 3521 if (parseToken(Token::colon, "expected ':' in operand list") || 3522 parseTypeListNoParens(types)) 3523 return failure(); 3524 3525 if (valueIDs.size() != types.size()) 3526 return emitError("expected ") 3527 << valueIDs.size() << " types to match operand list"; 3528 3529 results.reserve(valueIDs.size()); 3530 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3531 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3532 results.push_back(value); 3533 else 3534 return failure(); 3535 } 3536 3537 return success(); 3538 } 3539 3540 /// Record that a definition was added at the current scope. 3541 void OperationParser::recordDefinition(StringRef def) { 3542 isolatedNameScopes.back().recordDefinition(def); 3543 } 3544 3545 /// Get the value entry for the given SSA name. 3546 SmallVectorImpl<std::pair<Value, SMLoc>> & 3547 OperationParser::getSSAValueEntry(StringRef name) { 3548 return isolatedNameScopes.back().values[name]; 3549 } 3550 3551 /// Create and remember a new placeholder for a forward reference. 3552 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3553 // Forward references are always created as operations, because we just need 3554 // something with a def/use chain. 3555 // 3556 // We create these placeholders as having an empty name, which we know 3557 // cannot be created through normal user input, allowing us to distinguish 3558 // them. 
3559 auto name = OperationName("placeholder", getContext()); 3560 auto *op = Operation::create( 3561 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3562 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3563 /*resizableOperandList=*/false); 3564 forwardRefPlaceholders[op->getResult(0)] = loc; 3565 return op->getResult(0); 3566 } 3567 3568 //===----------------------------------------------------------------------===// 3569 // Operation Parsing 3570 //===----------------------------------------------------------------------===// 3571 3572 /// Parse an operation. 3573 /// 3574 /// operation ::= op-result-list? 3575 /// (generic-operation | custom-operation) 3576 /// trailing-location? 3577 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3578 /// successor-list? (`(` region-list `)`)? 3579 /// attribute-dict? `:` function-type 3580 /// custom-operation ::= bare-id custom-operation-format 3581 /// op-result-list ::= op-result (`,` op-result)* `=` 3582 /// op-result ::= ssa-id (`:` integer-literal) 3583 /// 3584 ParseResult OperationParser::parseOperation() { 3585 auto loc = getToken().getLoc(); 3586 SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; 3587 size_t numExpectedResults = 0; 3588 if (getToken().is(Token::percent_identifier)) { 3589 // Parse the group of result ids. 3590 auto parseNextResult = [&]() -> ParseResult { 3591 // Parse the next result id. 3592 if (!getToken().is(Token::percent_identifier)) 3593 return emitError("expected valid ssa identifier"); 3594 3595 Token nameTok = getToken(); 3596 consumeToken(Token::percent_identifier); 3597 3598 // If the next token is a ':', we parse the expected result count. 3599 size_t expectedSubResults = 1; 3600 if (consumeIf(Token::colon)) { 3601 // Check that the next token is an integer. 3602 if (!getToken().is(Token::integer)) 3603 return emitError("expected integer number of results"); 3604 3605 // Check that number of results is > 0. 
3606 auto val = getToken().getUInt64IntegerValue(); 3607 if (!val.hasValue() || val.getValue() < 1) 3608 return emitError("expected named operation to have atleast 1 result"); 3609 consumeToken(Token::integer); 3610 expectedSubResults = *val; 3611 } 3612 3613 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3614 nameTok.getLoc()); 3615 numExpectedResults += expectedSubResults; 3616 return success(); 3617 }; 3618 if (parseCommaSeparatedList(parseNextResult)) 3619 return failure(); 3620 3621 if (parseToken(Token::equal, "expected '=' after SSA name")) 3622 return failure(); 3623 } 3624 3625 Operation *op; 3626 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3627 op = parseCustomOperation(); 3628 else if (getToken().is(Token::string)) 3629 op = parseGenericOperation(); 3630 else 3631 return emitError("expected operation name in quotes"); 3632 3633 // If parsing of the basic operation failed, then this whole thing fails. 3634 if (!op) 3635 return failure(); 3636 3637 // If the operation had a name, register it. 3638 if (!resultIDs.empty()) { 3639 if (op->getNumResults() == 0) 3640 return emitError(loc, "cannot name an operation with no results"); 3641 if (numExpectedResults != op->getNumResults()) 3642 return emitError(loc, "operation defines ") 3643 << op->getNumResults() << " results but was provided " 3644 << numExpectedResults << " to bind"; 3645 3646 // Add definitions for each of the result groups. 3647 unsigned opResI = 0; 3648 for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { 3649 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3650 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3651 op->getResult(opResI++))) 3652 return failure(); 3653 } 3654 } 3655 } 3656 3657 return success(); 3658 } 3659 3660 /// Parse a single operation successor and its operand list. 3661 /// 3662 /// successor ::= block-id branch-use-list? 
3663 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3664 /// 3665 ParseResult 3666 OperationParser::parseSuccessorAndUseList(Block *&dest, 3667 SmallVectorImpl<Value> &operands) { 3668 // Verify branch is identifier and get the matching block. 3669 if (!getToken().is(Token::caret_identifier)) 3670 return emitError("expected block name"); 3671 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3672 consumeToken(); 3673 3674 // Handle optional arguments. 3675 if (consumeIf(Token::l_paren) && 3676 (parseOptionalSSAUseAndTypeList(operands) || 3677 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3678 return failure(); 3679 } 3680 3681 return success(); 3682 } 3683 3684 /// Parse a comma-separated list of operation successors in brackets. 3685 /// 3686 /// successor-list ::= `[` successor (`,` successor )* `]` 3687 /// 3688 ParseResult OperationParser::parseSuccessors( 3689 SmallVectorImpl<Block *> &destinations, 3690 SmallVectorImpl<SmallVector<Value, 4>> &operands) { 3691 if (parseToken(Token::l_square, "expected '['")) 3692 return failure(); 3693 3694 auto parseElt = [this, &destinations, &operands]() { 3695 Block *dest; 3696 SmallVector<Value, 4> destOperands; 3697 auto res = parseSuccessorAndUseList(dest, destOperands); 3698 destinations.push_back(dest); 3699 operands.push_back(destOperands); 3700 return res; 3701 }; 3702 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3703 /*allowEmptyList=*/false); 3704 } 3705 3706 namespace { 3707 // RAII-style guard for cleaning up the regions in the operation state before 3708 // deleting them. Within the parser, regions may get deleted if parsing failed, 3709 // and other errors may be present, in particular undominated uses. This makes 3710 // sure such uses are deleted. 
3711 struct CleanupOpStateRegions { 3712 ~CleanupOpStateRegions() { 3713 SmallVector<Region *, 4> regionsToClean; 3714 regionsToClean.reserve(state.regions.size()); 3715 for (auto ®ion : state.regions) 3716 if (region) 3717 for (auto &block : *region) 3718 block.dropAllDefinedValueUses(); 3719 } 3720 OperationState &state; 3721 }; 3722 } // namespace 3723 3724 Operation *OperationParser::parseGenericOperation() { 3725 // Get location information for the operation. 3726 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3727 3728 auto name = getToken().getStringValue(); 3729 if (name.empty()) 3730 return (emitError("empty operation name is invalid"), nullptr); 3731 if (name.find('\0') != StringRef::npos) 3732 return (emitError("null character not allowed in operation name"), nullptr); 3733 3734 consumeToken(Token::string); 3735 3736 OperationState result(srcLocation, name); 3737 3738 // Generic operations have a resizable operation list. 3739 result.setOperandListToResizable(); 3740 3741 // Parse the operand list. 3742 SmallVector<SSAUseInfo, 8> operandInfos; 3743 3744 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3745 parseOptionalSSAUseList(operandInfos) || 3746 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3747 return nullptr; 3748 } 3749 3750 // Parse the successor list but don't add successors to the result yet to 3751 // avoid messing up with the argument order. 3752 SmallVector<Block *, 2> successors; 3753 SmallVector<SmallVector<Value, 4>, 2> successorOperands; 3754 if (getToken().is(Token::l_square)) { 3755 // Check if the operation is a known terminator. 3756 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3757 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3758 return emitError("successors in non-terminator"), nullptr; 3759 if (parseSuccessors(successors, successorOperands)) 3760 return nullptr; 3761 } 3762 3763 // Parse the region list. 
3764 CleanupOpStateRegions guard{result}; 3765 if (consumeIf(Token::l_paren)) { 3766 do { 3767 // Create temporary regions with the top level region as parent. 3768 result.regions.emplace_back(new Region(moduleOp)); 3769 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3770 return nullptr; 3771 } while (consumeIf(Token::comma)); 3772 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3773 return nullptr; 3774 } 3775 3776 if (getToken().is(Token::l_brace)) { 3777 if (parseAttributeDict(result.attributes)) 3778 return nullptr; 3779 } 3780 3781 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3782 return nullptr; 3783 3784 auto typeLoc = getToken().getLoc(); 3785 auto type = parseType(); 3786 if (!type) 3787 return nullptr; 3788 auto fnType = type.dyn_cast<FunctionType>(); 3789 if (!fnType) 3790 return (emitError(typeLoc, "expected function type"), nullptr); 3791 3792 result.addTypes(fnType.getResults()); 3793 3794 // Check that we have the right number of types for the operands. 3795 auto operandTypes = fnType.getInputs(); 3796 if (operandTypes.size() != operandInfos.size()) { 3797 auto plural = "s"[operandInfos.size() == 1]; 3798 return (emitError(typeLoc, "expected ") 3799 << operandInfos.size() << " operand type" << plural 3800 << " but had " << operandTypes.size(), 3801 nullptr); 3802 } 3803 3804 // Resolve all of the operands. 3805 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3806 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3807 if (!result.operands.back()) 3808 return nullptr; 3809 } 3810 3811 // Add the successors, and their operands after the proper operands. 3812 for (auto succ : llvm::zip(successors, successorOperands)) { 3813 Block *successor = std::get<0>(succ); 3814 const SmallVector<Value, 4> &operands = std::get<1>(succ); 3815 result.addSuccessor(successor, operands); 3816 } 3817 3818 // Parse a location if one is present. 
3819 if (parseOptionalTrailingLocation(result.location)) 3820 return nullptr; 3821 3822 return opBuilder.createOperation(result); 3823 } 3824 3825 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3826 Block::iterator insertPt) { 3827 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3828 opBuilder.setInsertionPoint(insertBlock, insertPt); 3829 return parseGenericOperation(); 3830 } 3831 3832 namespace { 3833 class CustomOpAsmParser : public OpAsmParser { 3834 public: 3835 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3836 OperationParser &parser) 3837 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3838 3839 /// Parse an instance of the operation described by 'opDefinition' into the 3840 /// provided operation state. 3841 ParseResult parseOperation(OperationState &opState) { 3842 if (opDefinition->parseAssembly(*this, opState)) 3843 return failure(); 3844 return success(); 3845 } 3846 3847 Operation *parseGenericOperation(Block *insertBlock, 3848 Block::iterator insertPt) final { 3849 return parser.parseGenericOperation(insertBlock, insertPt); 3850 } 3851 3852 //===--------------------------------------------------------------------===// 3853 // Utilities 3854 //===--------------------------------------------------------------------===// 3855 3856 /// Return if any errors were emitted during parsing. 3857 bool didEmitError() const { return emittedError; } 3858 3859 /// Emit a diagnostic at the specified location and return failure. 
3860 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3861 emittedError = true; 3862 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3863 message); 3864 } 3865 3866 llvm::SMLoc getCurrentLocation() override { 3867 return parser.getToken().getLoc(); 3868 } 3869 3870 Builder &getBuilder() const override { return parser.builder; } 3871 3872 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3873 3874 //===--------------------------------------------------------------------===// 3875 // Token Parsing 3876 //===--------------------------------------------------------------------===// 3877 3878 /// Parse a `->` token. 3879 ParseResult parseArrow() override { 3880 return parser.parseToken(Token::arrow, "expected '->'"); 3881 } 3882 3883 /// Parses a `->` if present. 3884 ParseResult parseOptionalArrow() override { 3885 return success(parser.consumeIf(Token::arrow)); 3886 } 3887 3888 /// Parse a `:` token. 3889 ParseResult parseColon() override { 3890 return parser.parseToken(Token::colon, "expected ':'"); 3891 } 3892 3893 /// Parse a `:` token if present. 3894 ParseResult parseOptionalColon() override { 3895 return success(parser.consumeIf(Token::colon)); 3896 } 3897 3898 /// Parse a `,` token. 3899 ParseResult parseComma() override { 3900 return parser.parseToken(Token::comma, "expected ','"); 3901 } 3902 3903 /// Parse a `,` token if present. 3904 ParseResult parseOptionalComma() override { 3905 return success(parser.consumeIf(Token::comma)); 3906 } 3907 3908 /// Parses a `...` if present. 3909 ParseResult parseOptionalEllipsis() override { 3910 return success(parser.consumeIf(Token::ellipsis)); 3911 } 3912 3913 /// Parse a `=` token. 3914 ParseResult parseEqual() override { 3915 return parser.parseToken(Token::equal, "expected '='"); 3916 } 3917 3918 /// Parse a '<' token. 
3919 ParseResult parseLess() override { 3920 return parser.parseToken(Token::less, "expected '<'"); 3921 } 3922 3923 /// Parse a '>' token. 3924 ParseResult parseGreater() override { 3925 return parser.parseToken(Token::greater, "expected '>'"); 3926 } 3927 3928 /// Parse a `(` token. 3929 ParseResult parseLParen() override { 3930 return parser.parseToken(Token::l_paren, "expected '('"); 3931 } 3932 3933 /// Parses a '(' if present. 3934 ParseResult parseOptionalLParen() override { 3935 return success(parser.consumeIf(Token::l_paren)); 3936 } 3937 3938 /// Parse a `)` token. 3939 ParseResult parseRParen() override { 3940 return parser.parseToken(Token::r_paren, "expected ')'"); 3941 } 3942 3943 /// Parses a ')' if present. 3944 ParseResult parseOptionalRParen() override { 3945 return success(parser.consumeIf(Token::r_paren)); 3946 } 3947 3948 /// Parse a `[` token. 3949 ParseResult parseLSquare() override { 3950 return parser.parseToken(Token::l_square, "expected '['"); 3951 } 3952 3953 /// Parses a '[' if present. 3954 ParseResult parseOptionalLSquare() override { 3955 return success(parser.consumeIf(Token::l_square)); 3956 } 3957 3958 /// Parse a `]` token. 3959 ParseResult parseRSquare() override { 3960 return parser.parseToken(Token::r_square, "expected ']'"); 3961 } 3962 3963 /// Parses a ']' if present. 3964 ParseResult parseOptionalRSquare() override { 3965 return success(parser.consumeIf(Token::r_square)); 3966 } 3967 3968 //===--------------------------------------------------------------------===// 3969 // Attribute Parsing 3970 //===--------------------------------------------------------------------===// 3971 3972 /// Parse an arbitrary attribute of a given type and return it in result. This 3973 /// also adds the attribute to the specified attribute list with the specified 3974 /// name. 
3975 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3976 SmallVectorImpl<NamedAttribute> &attrs) override { 3977 result = parser.parseAttribute(type); 3978 if (!result) 3979 return failure(); 3980 3981 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3982 return success(); 3983 } 3984 3985 /// Parse a named dictionary into 'result' if it is present. 3986 ParseResult 3987 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 3988 if (parser.getToken().isNot(Token::l_brace)) 3989 return success(); 3990 return parser.parseAttributeDict(result); 3991 } 3992 3993 /// Parse a named dictionary into 'result' if the `attributes` keyword is 3994 /// present. 3995 ParseResult parseOptionalAttrDictWithKeyword( 3996 SmallVectorImpl<NamedAttribute> &result) override { 3997 if (failed(parseOptionalKeyword("attributes"))) 3998 return success(); 3999 return parser.parseAttributeDict(result); 4000 } 4001 4002 /// Parse an affine map instance into 'map'. 4003 ParseResult parseAffineMap(AffineMap &map) override { 4004 return parser.parseAffineMapReference(map); 4005 } 4006 4007 /// Parse an integer set instance into 'set'. 4008 ParseResult printIntegerSet(IntegerSet &set) override { 4009 return parser.parseIntegerSetReference(set); 4010 } 4011 4012 //===--------------------------------------------------------------------===// 4013 // Identifier Parsing 4014 //===--------------------------------------------------------------------===// 4015 4016 /// Returns if the current token corresponds to a keyword. 4017 bool isCurrentTokenAKeyword() const { 4018 return parser.getToken().is(Token::bare_identifier) || 4019 parser.getToken().isKeyword(); 4020 } 4021 4022 /// Parse the given keyword if present. 4023 ParseResult parseOptionalKeyword(StringRef keyword) override { 4024 // Check that the current token has the same spelling. 
4025 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4026 return failure(); 4027 parser.consumeToken(); 4028 return success(); 4029 } 4030 4031 /// Parse a keyword, if present, into 'keyword'. 4032 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4033 // Check that the current token is a keyword. 4034 if (!isCurrentTokenAKeyword()) 4035 return failure(); 4036 4037 *keyword = parser.getTokenSpelling(); 4038 parser.consumeToken(); 4039 return success(); 4040 } 4041 4042 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4043 /// string attribute named 'attrName'. 4044 ParseResult 4045 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4046 SmallVectorImpl<NamedAttribute> &attrs) override { 4047 Token atToken = parser.getToken(); 4048 if (atToken.isNot(Token::at_identifier)) 4049 return failure(); 4050 4051 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4052 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4053 parser.consumeToken(); 4054 return success(); 4055 } 4056 4057 //===--------------------------------------------------------------------===// 4058 // Operand Parsing 4059 //===--------------------------------------------------------------------===// 4060 4061 /// Parse a single operand. 4062 ParseResult parseOperand(OperandType &result) override { 4063 OperationParser::SSAUseInfo useInfo; 4064 if (parser.parseSSAUse(useInfo)) 4065 return failure(); 4066 4067 result = {useInfo.loc, useInfo.name, useInfo.number}; 4068 return success(); 4069 } 4070 4071 /// Parse zero or more SSA comma-separated operand references with a specified 4072 /// surrounding delimiter, and an optional required operand count. 
4073 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4074 int requiredOperandCount = -1, 4075 Delimiter delimiter = Delimiter::None) override { 4076 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4077 requiredOperandCount, delimiter); 4078 } 4079 4080 /// Parse zero or more SSA comma-separated operand or region arguments with 4081 /// optional surrounding delimiter and required operand count. 4082 ParseResult 4083 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4084 bool isOperandList, int requiredOperandCount = -1, 4085 Delimiter delimiter = Delimiter::None) { 4086 auto startLoc = parser.getToken().getLoc(); 4087 4088 // Handle delimiters. 4089 switch (delimiter) { 4090 case Delimiter::None: 4091 // Don't check for the absence of a delimiter if the number of operands 4092 // is unknown (and hence the operand list could be empty). 4093 if (requiredOperandCount == -1) 4094 break; 4095 // Token already matches an identifier and so can't be a delimiter. 4096 if (parser.getToken().is(Token::percent_identifier)) 4097 break; 4098 // Test against known delimiters. 4099 if (parser.getToken().is(Token::l_paren) || 4100 parser.getToken().is(Token::l_square)) 4101 return emitError(startLoc, "unexpected delimiter"); 4102 return emitError(startLoc, "invalid operand"); 4103 case Delimiter::OptionalParen: 4104 if (parser.getToken().isNot(Token::l_paren)) 4105 return success(); 4106 LLVM_FALLTHROUGH; 4107 case Delimiter::Paren: 4108 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4109 return failure(); 4110 break; 4111 case Delimiter::OptionalSquare: 4112 if (parser.getToken().isNot(Token::l_square)) 4113 return success(); 4114 LLVM_FALLTHROUGH; 4115 case Delimiter::Square: 4116 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4117 return failure(); 4118 break; 4119 } 4120 4121 // Check for zero operands. 
4122 if (parser.getToken().is(Token::percent_identifier)) { 4123 do { 4124 OperandType operandOrArg; 4125 if (isOperandList ? parseOperand(operandOrArg) 4126 : parseRegionArgument(operandOrArg)) 4127 return failure(); 4128 result.push_back(operandOrArg); 4129 } while (parser.consumeIf(Token::comma)); 4130 } 4131 4132 // Handle delimiters. If we reach here, the optional delimiters were 4133 // present, so we need to parse their closing one. 4134 switch (delimiter) { 4135 case Delimiter::None: 4136 break; 4137 case Delimiter::OptionalParen: 4138 case Delimiter::Paren: 4139 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4140 return failure(); 4141 break; 4142 case Delimiter::OptionalSquare: 4143 case Delimiter::Square: 4144 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4145 return failure(); 4146 break; 4147 } 4148 4149 if (requiredOperandCount != -1 && 4150 result.size() != static_cast<size_t>(requiredOperandCount)) 4151 return emitError(startLoc, "expected ") 4152 << requiredOperandCount << " operands"; 4153 return success(); 4154 } 4155 4156 /// Parse zero or more trailing SSA comma-separated trailing operand 4157 /// references with a specified surrounding delimiter, and an optional 4158 /// required operand count. A leading comma is expected before the operands. 4159 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4160 int requiredOperandCount, 4161 Delimiter delimiter) override { 4162 if (parser.getToken().is(Token::comma)) { 4163 parseComma(); 4164 return parseOperandList(result, requiredOperandCount, delimiter); 4165 } 4166 if (requiredOperandCount != -1) 4167 return emitError(parser.getToken().getLoc(), "expected ") 4168 << requiredOperandCount << " operands"; 4169 return success(); 4170 } 4171 4172 /// Resolve an operand to an SSA value, emitting an error on failure. 
4173 ParseResult resolveOperand(const OperandType &operand, Type type, 4174 SmallVectorImpl<Value> &result) override { 4175 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4176 operand.location}; 4177 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4178 result.push_back(value); 4179 return success(); 4180 } 4181 return failure(); 4182 } 4183 4184 /// Parse an AffineMap of SSA ids. 4185 ParseResult 4186 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4187 Attribute &mapAttr, StringRef attrName, 4188 SmallVectorImpl<NamedAttribute> &attrs) override { 4189 SmallVector<OperandType, 2> dimOperands; 4190 SmallVector<OperandType, 1> symOperands; 4191 4192 auto parseElement = [&](bool isSymbol) -> ParseResult { 4193 OperandType operand; 4194 if (parseOperand(operand)) 4195 return failure(); 4196 if (isSymbol) 4197 symOperands.push_back(operand); 4198 else 4199 dimOperands.push_back(operand); 4200 return success(); 4201 }; 4202 4203 AffineMap map; 4204 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 4205 return failure(); 4206 // Add AffineMap attribute. 4207 if (map) { 4208 mapAttr = AffineMapAttr::get(map); 4209 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4210 } 4211 4212 // Add dim operands before symbol operands in 'operands'. 4213 operands.assign(dimOperands.begin(), dimOperands.end()); 4214 operands.append(symOperands.begin(), symOperands.end()); 4215 return success(); 4216 } 4217 4218 //===--------------------------------------------------------------------===// 4219 // Region Parsing 4220 //===--------------------------------------------------------------------===// 4221 4222 /// Parse a region that takes `arguments` of `argTypes` types. This 4223 /// effectively defines the SSA values of `arguments` and assigns their type. 
4224 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4225 ArrayRef<Type> argTypes, 4226 bool enableNameShadowing) override { 4227 assert(arguments.size() == argTypes.size() && 4228 "mismatching number of arguments and types"); 4229 4230 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4231 regionArguments; 4232 for (auto pair : llvm::zip(arguments, argTypes)) { 4233 const OperandType &operand = std::get<0>(pair); 4234 Type type = std::get<1>(pair); 4235 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4236 operand.location}; 4237 regionArguments.emplace_back(operandInfo, type); 4238 } 4239 4240 // Try to parse the region. 4241 assert((!enableNameShadowing || 4242 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4243 "name shadowing is only allowed on isolated regions"); 4244 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4245 return failure(); 4246 return success(); 4247 } 4248 4249 /// Parses a region if present. 4250 ParseResult parseOptionalRegion(Region ®ion, 4251 ArrayRef<OperandType> arguments, 4252 ArrayRef<Type> argTypes, 4253 bool enableNameShadowing) override { 4254 if (parser.getToken().isNot(Token::l_brace)) 4255 return success(); 4256 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4257 } 4258 4259 /// Parse a region argument. The type of the argument will be resolved later 4260 /// by a call to `parseRegion`. 4261 ParseResult parseRegionArgument(OperandType &argument) override { 4262 return parseOperand(argument); 4263 } 4264 4265 /// Parse a region argument if present. 
4266 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4267 if (parser.getToken().isNot(Token::percent_identifier)) 4268 return success(); 4269 return parseRegionArgument(argument); 4270 } 4271 4272 ParseResult 4273 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4274 int requiredOperandCount = -1, 4275 Delimiter delimiter = Delimiter::None) override { 4276 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4277 requiredOperandCount, delimiter); 4278 } 4279 4280 //===--------------------------------------------------------------------===// 4281 // Successor Parsing 4282 //===--------------------------------------------------------------------===// 4283 4284 /// Parse a single operation successor and its operand list. 4285 ParseResult 4286 parseSuccessorAndUseList(Block *&dest, 4287 SmallVectorImpl<Value> &operands) override { 4288 return parser.parseSuccessorAndUseList(dest, operands); 4289 } 4290 4291 //===--------------------------------------------------------------------===// 4292 // Type Parsing 4293 //===--------------------------------------------------------------------===// 4294 4295 /// Parse a type. 4296 ParseResult parseType(Type &result) override { 4297 return failure(!(result = parser.parseType())); 4298 } 4299 4300 /// Parse an optional arrow followed by a type list. 4301 ParseResult 4302 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4303 if (!parser.consumeIf(Token::arrow)) 4304 return success(); 4305 return parser.parseFunctionResultTypes(result); 4306 } 4307 4308 /// Parse a colon followed by a type. 4309 ParseResult parseColonType(Type &result) override { 4310 return failure(parser.parseToken(Token::colon, "expected ':'") || 4311 !(result = parser.parseType())); 4312 } 4313 4314 /// Parse a colon followed by a type list, which must have at least one type. 
4315 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4316 if (parser.parseToken(Token::colon, "expected ':'")) 4317 return failure(); 4318 return parser.parseTypeListNoParens(result); 4319 } 4320 4321 /// Parse an optional colon followed by a type list, which if present must 4322 /// have at least one type. 4323 ParseResult 4324 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4325 if (!parser.consumeIf(Token::colon)) 4326 return success(); 4327 return parser.parseTypeListNoParens(result); 4328 } 4329 4330 private: 4331 /// The source location of the operation name. 4332 SMLoc nameLoc; 4333 4334 /// The abstract information of the operation. 4335 const AbstractOperation *opDefinition; 4336 4337 /// The main operation parser. 4338 OperationParser &parser; 4339 4340 /// A flag that indicates if any errors were emitted during parsing. 4341 bool emittedError = false; 4342 }; 4343 } // end anonymous namespace. 4344 4345 Operation *OperationParser::parseCustomOperation() { 4346 auto opLoc = getToken().getLoc(); 4347 auto opName = getTokenSpelling(); 4348 4349 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4350 if (!opDefinition && !opName.contains('.')) { 4351 // If the operation name has no namespace prefix we treat it as a standard 4352 // operation and prefix it with "std". 4353 // TODO: Would it be better to just build a mapping of the registered 4354 // operations in the standard dialect? 4355 opDefinition = 4356 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4357 } 4358 4359 if (!opDefinition) { 4360 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4361 return nullptr; 4362 } 4363 4364 consumeToken(); 4365 4366 // If the custom op parser crashes, produce some indication to help 4367 // debugging. 
4368 std::string opNameStr = opName.str(); 4369 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4370 opNameStr.c_str()); 4371 4372 // Get location information for the operation. 4373 auto srcLocation = getEncodedSourceLocation(opLoc); 4374 4375 // Have the op implementation take a crack and parsing this. 4376 OperationState opState(srcLocation, opDefinition->name); 4377 CleanupOpStateRegions guard{opState}; 4378 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 4379 if (opAsmParser.parseOperation(opState)) 4380 return nullptr; 4381 4382 // If it emitted an error, we failed. 4383 if (opAsmParser.didEmitError()) 4384 return nullptr; 4385 4386 // Parse a location if one is present. 4387 if (parseOptionalTrailingLocation(opState.location)) 4388 return nullptr; 4389 4390 // Otherwise, we succeeded. Use the state it parsed as our op information. 4391 return opBuilder.createOperation(opState); 4392 } 4393 4394 //===----------------------------------------------------------------------===// 4395 // Region Parsing 4396 //===----------------------------------------------------------------------===// 4397 4398 /// Region. 4399 /// 4400 /// region ::= '{' region-body 4401 /// 4402 ParseResult OperationParser::parseRegion( 4403 Region ®ion, 4404 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4405 bool isIsolatedNameScope) { 4406 // Parse the '{'. 4407 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4408 return failure(); 4409 4410 // Check for an empty region. 4411 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4412 return success(); 4413 auto currentPt = opBuilder.saveInsertionPoint(); 4414 4415 // Push a new named value scope. 4416 pushSSANameScope(isIsolatedNameScope); 4417 4418 // Parse the first block directly to allow for it to be unnamed. 4419 Block *block = new Block(); 4420 4421 // Add arguments to the entry block. 
4422 if (!entryArguments.empty()) { 4423 for (auto &placeholderArgPair : entryArguments) { 4424 auto &argInfo = placeholderArgPair.first; 4425 // Ensure that the argument was not already defined. 4426 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4427 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4428 "' is already in use") 4429 .attachNote(getEncodedSourceLocation(*defLoc)) 4430 << "previously referenced here"; 4431 } 4432 if (addDefinition(placeholderArgPair.first, 4433 block->addArgument(placeholderArgPair.second))) { 4434 delete block; 4435 return failure(); 4436 } 4437 } 4438 4439 // If we had named arguments, then don't allow a block name. 4440 if (getToken().is(Token::caret_identifier)) 4441 return emitError("invalid block name in region with named arguments"); 4442 } 4443 4444 if (parseBlock(block)) { 4445 delete block; 4446 return failure(); 4447 } 4448 4449 // Verify that no other arguments were parsed. 4450 if (!entryArguments.empty() && 4451 block->getNumArguments() > entryArguments.size()) { 4452 delete block; 4453 return emitError("entry block arguments were already defined"); 4454 } 4455 4456 // Parse the rest of the region. 4457 region.push_back(block); 4458 if (parseRegionBody(region)) 4459 return failure(); 4460 4461 // Pop the SSA value scope for this region. 4462 if (popSSANameScope()) 4463 return failure(); 4464 4465 // Reset the original insertion point. 4466 opBuilder.restoreInsertionPoint(currentPt); 4467 return success(); 4468 } 4469 4470 /// Region. 4471 /// 4472 /// region-body ::= block* '}' 4473 /// 4474 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4475 // Parse the list of blocks. 
4476 while (!consumeIf(Token::r_brace)) { 4477 Block *newBlock = nullptr; 4478 if (parseBlock(newBlock)) 4479 return failure(); 4480 region.push_back(newBlock); 4481 } 4482 return success(); 4483 } 4484 4485 //===----------------------------------------------------------------------===// 4486 // Block Parsing 4487 //===----------------------------------------------------------------------===// 4488 4489 /// Block declaration. 4490 /// 4491 /// block ::= block-label? operation* 4492 /// block-label ::= block-id block-arg-list? `:` 4493 /// block-id ::= caret-id 4494 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4495 /// 4496 ParseResult OperationParser::parseBlock(Block *&block) { 4497 // The first block of a region may already exist, if it does the caret 4498 // identifier is optional. 4499 if (block && getToken().isNot(Token::caret_identifier)) 4500 return parseBlockBody(block); 4501 4502 SMLoc nameLoc = getToken().getLoc(); 4503 auto name = getTokenSpelling(); 4504 if (parseToken(Token::caret_identifier, "expected block name")) 4505 return failure(); 4506 4507 block = defineBlockNamed(name, nameLoc, block); 4508 4509 // Fail if the block was already defined. 4510 if (!block) 4511 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4512 4513 // If an argument list is present, parse it. 4514 if (consumeIf(Token::l_paren)) { 4515 SmallVector<BlockArgument, 8> bbArgs; 4516 if (parseOptionalBlockArgList(bbArgs, block) || 4517 parseToken(Token::r_paren, "expected ')' to end argument list")) 4518 return failure(); 4519 } 4520 4521 if (parseToken(Token::colon, "expected ':' after block name")) 4522 return failure(); 4523 4524 return parseBlockBody(block); 4525 } 4526 4527 ParseResult OperationParser::parseBlockBody(Block *block) { 4528 // Set the insertion point to the end of the block to parse. 4529 opBuilder.setInsertionPointToEnd(block); 4530 4531 // Parse the list of operations that make up the body of the block. 
4532 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4533 if (parseOperation()) 4534 return failure(); 4535 4536 return success(); 4537 } 4538 4539 /// Get the block with the specified name, creating it if it doesn't already 4540 /// exist. The location specified is the point of use, which allows 4541 /// us to diagnose references to blocks that are not defined precisely. 4542 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4543 auto &blockAndLoc = getBlockInfoByName(name); 4544 if (!blockAndLoc.first) { 4545 blockAndLoc = {new Block(), loc}; 4546 insertForwardRef(blockAndLoc.first, loc); 4547 } 4548 4549 return blockAndLoc.first; 4550 } 4551 4552 /// Define the block with the specified name. Returns the Block* or nullptr in 4553 /// the case of redefinition. 4554 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4555 Block *existing) { 4556 auto &blockAndLoc = getBlockInfoByName(name); 4557 if (!blockAndLoc.first) { 4558 // If the caller provided a block, use it. Otherwise create a new one. 4559 if (!existing) 4560 existing = new Block(); 4561 blockAndLoc.first = existing; 4562 blockAndLoc.second = loc; 4563 return blockAndLoc.first; 4564 } 4565 4566 // Forward declarations are removed once defined, so if we are defining a 4567 // existing block and it is not a forward declaration, then it is a 4568 // redeclaration. 4569 if (!eraseForwardRef(blockAndLoc.first)) 4570 return nullptr; 4571 return blockAndLoc.first; 4572 } 4573 4574 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4575 /// 4576 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4577 /// 4578 ParseResult OperationParser::parseOptionalBlockArgList( 4579 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4580 if (getToken().is(Token::r_brace)) 4581 return success(); 4582 4583 // If the block already has arguments, then we're handling the entry block. 
4584 // Parse and register the names for the arguments, but do not add them. 4585 bool definingExistingArgs = owner->getNumArguments() != 0; 4586 unsigned nextArgument = 0; 4587 4588 return parseCommaSeparatedList([&]() -> ParseResult { 4589 return parseSSADefOrUseAndType( 4590 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4591 // If this block did not have existing arguments, define a new one. 4592 if (!definingExistingArgs) 4593 return addDefinition(useInfo, owner->addArgument(type)); 4594 4595 // Otherwise, ensure that this argument has already been created. 4596 if (nextArgument >= owner->getNumArguments()) 4597 return emitError("too many arguments specified in argument list"); 4598 4599 // Finally, make sure the existing argument has the correct type. 4600 auto arg = owner->getArgument(nextArgument++); 4601 if (arg.getType() != type) 4602 return emitError("argument and block argument type mismatch"); 4603 return addDefinition(useInfo, arg); 4604 }); 4605 }); 4606 } 4607 4608 //===----------------------------------------------------------------------===// 4609 // Top-level entity parsing. 4610 //===----------------------------------------------------------------------===// 4611 4612 namespace { 4613 /// This parser handles entities that are only valid at the top level of the 4614 /// file. 4615 class ModuleParser : public Parser { 4616 public: 4617 explicit ModuleParser(ParserState &state) : Parser(state) {} 4618 4619 ParseResult parseModule(ModuleOp module); 4620 4621 private: 4622 /// Parse an attribute alias declaration. 4623 ParseResult parseAttributeAliasDef(); 4624 4625 /// Parse an attribute alias declaration. 4626 ParseResult parseTypeAliasDef(); 4627 }; 4628 } // end anonymous namespace 4629 4630 /// Parses an attribute alias declaration. 
4631 /// 4632 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4633 /// 4634 ParseResult ModuleParser::parseAttributeAliasDef() { 4635 assert(getToken().is(Token::hash_identifier)); 4636 StringRef aliasName = getTokenSpelling().drop_front(); 4637 4638 // Check for redefinitions. 4639 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4640 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4641 4642 // Make sure this isn't invading the dialect attribute namespace. 4643 if (aliasName.contains('.')) 4644 return emitError("attribute names with a '.' are reserved for " 4645 "dialect-defined names"); 4646 4647 consumeToken(Token::hash_identifier); 4648 4649 // Parse the '='. 4650 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4651 return failure(); 4652 4653 // Parse the attribute value. 4654 Attribute attr = parseAttribute(); 4655 if (!attr) 4656 return failure(); 4657 4658 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4659 return success(); 4660 } 4661 4662 /// Parse a type alias declaration. 4663 /// 4664 /// type-alias-def ::= '!' alias-name `=` 'type' type 4665 /// 4666 ParseResult ModuleParser::parseTypeAliasDef() { 4667 assert(getToken().is(Token::exclamation_identifier)); 4668 StringRef aliasName = getTokenSpelling().drop_front(); 4669 4670 // Check for redefinitions. 4671 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4672 return emitError("redefinition of type alias id '" + aliasName + "'"); 4673 4674 // Make sure this isn't invading the dialect type namespace. 4675 if (aliasName.contains('.')) 4676 return emitError("type names with a '.' are reserved for " 4677 "dialect-defined names"); 4678 4679 consumeToken(Token::exclamation_identifier); 4680 4681 // Parse the '=' and 'type'. 
4682 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4683 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4684 return failure(); 4685 4686 // Parse the type. 4687 Type aliasedType = parseType(); 4688 if (!aliasedType) 4689 return failure(); 4690 4691 // Register this alias with the parser state. 4692 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4693 return success(); 4694 } 4695 4696 /// This is the top-level module parser. 4697 ParseResult ModuleParser::parseModule(ModuleOp module) { 4698 OperationParser opParser(getState(), module); 4699 4700 // Module itself is a name scope. 4701 opParser.pushSSANameScope(/*isIsolated=*/true); 4702 4703 while (true) { 4704 switch (getToken().getKind()) { 4705 default: 4706 // Parse a top-level operation. 4707 if (opParser.parseOperation()) 4708 return failure(); 4709 break; 4710 4711 // If we got to the end of the file, then we're done. 4712 case Token::eof: { 4713 if (opParser.finalize()) 4714 return failure(); 4715 4716 // Handle the case where the top level module was explicitly defined. 4717 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4718 auto &operations = bodyBlocks.front().getOperations(); 4719 assert(!operations.empty() && "expected a valid module terminator"); 4720 4721 // Check that the first operation is a module, and it is the only 4722 // non-terminator operation. 4723 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4724 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4725 // Merge the data of the nested module operation into 'module'. 4726 module.setLoc(nested.getLoc()); 4727 module.setAttrs(nested.getOperation()->getAttrList()); 4728 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4729 4730 // Erase the original module body. 
4731 bodyBlocks.pop_front(); 4732 } 4733 4734 return opParser.popSSANameScope(); 4735 } 4736 4737 // If we got an error token, then the lexer already emitted an error, just 4738 // stop. Someday we could introduce error recovery if there was demand 4739 // for it. 4740 case Token::error: 4741 return failure(); 4742 4743 // Parse an attribute alias. 4744 case Token::hash_identifier: 4745 if (parseAttributeAliasDef()) 4746 return failure(); 4747 break; 4748 4749 // Parse a type alias. 4750 case Token::exclamation_identifier: 4751 if (parseTypeAliasDef()) 4752 return failure(); 4753 break; 4754 } 4755 } 4756 } 4757 4758 //===----------------------------------------------------------------------===// 4759 4760 /// This parses the file specified by the indicated SourceMgr and returns an 4761 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4762 /// null. 4763 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4764 MLIRContext *context) { 4765 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4766 4767 // This is the result module we are parsing into. 4768 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4769 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4770 4771 SymbolState aliasState; 4772 ParserState state(sourceMgr, context, aliasState); 4773 if (ModuleParser(state).parseModule(*module)) 4774 return nullptr; 4775 4776 // Make sure the parse module has no other structural problems detected by 4777 // the verifier. 4778 if (failed(verify(*module))) 4779 return nullptr; 4780 4781 return module; 4782 } 4783 4784 /// This parses the file specified by the indicated filename and returns an 4785 /// MLIR module if it was valid. If not, the error message is emitted through 4786 /// the error handler registered in the context, and a null pointer is returned. 
4787 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4788 MLIRContext *context) { 4789 llvm::SourceMgr sourceMgr; 4790 return parseSourceFile(filename, sourceMgr, context); 4791 } 4792 4793 /// This parses the file specified by the indicated filename using the provided 4794 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4795 /// message is emitted through the error handler registered in the context, and 4796 /// a null pointer is returned. 4797 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4798 llvm::SourceMgr &sourceMgr, 4799 MLIRContext *context) { 4800 if (sourceMgr.getNumBuffers() != 0) { 4801 // TODO(b/136086478): Extend to support multiple buffers. 4802 emitError(mlir::UnknownLoc::get(context), 4803 "only main buffer parsed at the moment"); 4804 return nullptr; 4805 } 4806 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4807 if (std::error_code error = file_or_err.getError()) { 4808 emitError(mlir::UnknownLoc::get(context), 4809 "could not open input file " + filename); 4810 return nullptr; 4811 } 4812 4813 // Load the MLIR module. 4814 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4815 return parseSourceFile(sourceMgr, context); 4816 } 4817 4818 /// This parses the program string to a MLIR module if it was valid. If not, 4819 /// it emits diagnostics and returns null. 4820 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4821 MLIRContext *context) { 4822 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4823 if (!memBuffer) 4824 return nullptr; 4825 4826 SourceMgr sourceMgr; 4827 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4828 return parseSourceFile(sourceMgr, context); 4829 } 4830 4831 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 4832 /// parsing failed, nullptr is returned. The number of bytes read from the input 4833 /// string is returned in 'numRead'. 
4834 template <typename T, typename ParserFn> 4835 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 4836 ParserFn &&parserFn) { 4837 SymbolState aliasState; 4838 return parseSymbol<T>( 4839 inputStr, context, aliasState, 4840 [&](Parser &parser) { 4841 SourceMgrDiagnosticHandler handler( 4842 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 4843 parser.getContext()); 4844 return parserFn(parser); 4845 }, 4846 &numRead); 4847 } 4848 4849 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 4850 size_t numRead = 0; 4851 return parseAttribute(attrStr, context, numRead); 4852 } 4853 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 4854 size_t numRead = 0; 4855 return parseAttribute(attrStr, type, numRead); 4856 } 4857 4858 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 4859 size_t &numRead) { 4860 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 4861 return parser.parseAttribute(); 4862 }); 4863 } 4864 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 4865 return parseSymbol<Attribute>( 4866 attrStr, type.getContext(), numRead, 4867 [type](Parser &parser) { return parser.parseAttribute(type); }); 4868 } 4869 4870 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 4871 size_t numRead = 0; 4872 return parseType(typeStr, context, numRead); 4873 } 4874 4875 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 4876 return parseSymbol<Type>(typeStr, context, numRead, 4877 [](Parser &parser) { return parser.parseType(); }); 4878 } 4879