1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the parser for the MLIR textual form. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Parser.h" 23 #include "Lexer.h" 24 #include "mlir/Analysis/Verifier.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/Dialect.h" 30 #include "mlir/IR/DialectImplementation.h" 31 #include "mlir/IR/IntegerSet.h" 32 #include "mlir/IR/Location.h" 33 #include "mlir/IR/MLIRContext.h" 34 #include "mlir/IR/Module.h" 35 #include "mlir/IR/OpImplementation.h" 36 #include "mlir/IR/StandardTypes.h" 37 #include "mlir/Support/STLExtras.h" 38 #include "llvm/ADT/APInt.h" 39 #include "llvm/ADT/DenseMap.h" 40 #include "llvm/ADT/StringSet.h" 41 #include "llvm/ADT/bit.h" 42 #include "llvm/Support/PrettyStackTrace.h" 43 #include "llvm/Support/SMLoc.h" 44 #include "llvm/Support/SourceMgr.h" 45 #include <algorithm> 46 using namespace mlir; 47 using llvm::MemoryBuffer; 48 using llvm::SMLoc; 49 using llvm::SourceMgr; 50 51 namespace { 52 class Parser; 53 54 //===----------------------------------------------------------------------===// 55 // SymbolState 56 //===----------------------------------------------------------------------===// 57 58 /// This class contains record of any parsed top-level symbols. 59 struct SymbolState { 60 // A map from attribute alias identifier to Attribute. 61 llvm::StringMap<Attribute> attributeAliasDefinitions; 62 63 // A map from type alias identifier to Type. 64 llvm::StringMap<Type> typeAliasDefinitions; 65 66 /// A set of locations into the main parser memory buffer for each of the 67 /// active nested parsers. Given that some nested parsers, i.e. custom dialect 68 /// parsers, operate on a temporary memory buffer, this provides an anchor 69 /// point for emitting diagnostics. 70 SmallVector<llvm::SMLoc, 1> nestedParserLocs; 71 72 /// The top-level lexer that contains the original memory buffer provided by 73 /// the user. This is used by nested parsers to get a properly encoded source 74 /// location. 75 Lexer *topLevelLexer = nullptr; 76 }; 77 78 //===----------------------------------------------------------------------===// 79 // ParserState 80 //===----------------------------------------------------------------------===// 81 82 /// This class refers to all of the state maintained globally by the parser, 83 /// such as the current lexer position etc. 84 struct ParserState { 85 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, 86 SymbolState &symbols) 87 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), 88 symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { 89 // Set the top level lexer for the symbol state if one doesn't exist. 90 if (!symbols.topLevelLexer) 91 symbols.topLevelLexer = &lex; 92 } 93 ~ParserState() { 94 // Reset the top level lexer if it refers the lexer in our state. 95 if (symbols.topLevelLexer == &lex) 96 symbols.topLevelLexer = nullptr; 97 } 98 ParserState(const ParserState &) = delete; 99 void operator=(const ParserState &) = delete; 100 101 /// The context we're parsing into. 102 MLIRContext *const context; 103 104 /// The lexer for the source file we're parsing. 105 Lexer lex; 106 107 /// This is the next token that hasn't been consumed yet. 108 Token curToken; 109 110 /// The current state for symbol parsing. 111 SymbolState &symbols; 112 113 /// The depth of this parser in the nested parsing stack. 114 size_t parserDepth; 115 }; 116 117 //===----------------------------------------------------------------------===// 118 // Parser 119 //===----------------------------------------------------------------------===// 120 121 /// This class implement support for parsing global entities like types and 122 /// shared entities like SSA names. It is intended to be subclassed by 123 /// specialized subparsers that include state, e.g. when a local symbol table. 124 class Parser { 125 public: 126 Builder builder; 127 128 Parser(ParserState &state) : builder(state.context), state(state) {} 129 130 // Helper methods to get stuff from the parser-global state. 131 ParserState &getState() const { return state; } 132 MLIRContext *getContext() const { return state.context; } 133 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 134 135 /// Parse a comma-separated list of elements up until the specified end token. 136 ParseResult 137 parseCommaSeparatedListUntil(Token::Kind rightToken, 138 const std::function<ParseResult()> &parseElement, 139 bool allowEmptyList = true); 140 141 /// Parse a comma separated list of elements that must have at least one entry 142 /// in it. 143 ParseResult 144 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 145 146 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 147 148 // We have two forms of parsing methods - those that return a non-null 149 // pointer on success, and those that return a ParseResult to indicate whether 150 // they returned a failure. The second class fills in by-reference arguments 151 // as the results of their action. 152 153 //===--------------------------------------------------------------------===// 154 // Error Handling 155 //===--------------------------------------------------------------------===// 156 157 /// Emit an error and return failure. 158 InFlightDiagnostic emitError(const Twine &message = {}) { 159 return emitError(state.curToken.getLoc(), message); 160 } 161 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 162 163 /// Encode the specified source location information into an attribute for 164 /// attachment to the IR. 165 Location getEncodedSourceLocation(llvm::SMLoc loc) { 166 // If there are no active nested parsers, we can get the encoded source 167 // location directly. 168 if (state.parserDepth == 0) 169 return state.lex.getEncodedSourceLocation(loc); 170 // Otherwise, we need to re-encode it to point to the top level buffer. 171 return state.symbols.topLevelLexer->getEncodedSourceLocation( 172 remapLocationToTopLevelBuffer(loc)); 173 } 174 175 /// Remaps the given SMLoc to the top level lexer of the parser. This is used 176 /// to adjust locations of potentially nested parsers to ensure that they can 177 /// be emitted properly as diagnostics. 178 llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { 179 // If there are no active nested parsers, we can return location directly. 180 SymbolState &symbols = state.symbols; 181 if (state.parserDepth == 0) 182 return loc; 183 assert(symbols.topLevelLexer && "expected valid top-level lexer"); 184 185 // Otherwise, we need to remap the location to the main parser. This is 186 // simply offseting the location onto the location of the last nested 187 // parser. 188 size_t offset = loc.getPointer() - state.lex.getBufferBegin(); 189 auto *rawLoc = 190 symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; 191 return llvm::SMLoc::getFromPointer(rawLoc); 192 } 193 194 //===--------------------------------------------------------------------===// 195 // Token Parsing 196 //===--------------------------------------------------------------------===// 197 198 /// Return the current token the parser is inspecting. 199 const Token &getToken() const { return state.curToken; } 200 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 201 202 /// If the current token has the specified kind, consume it and return true. 203 /// If not, return false. 204 bool consumeIf(Token::Kind kind) { 205 if (state.curToken.isNot(kind)) 206 return false; 207 consumeToken(kind); 208 return true; 209 } 210 211 /// Advance the current lexer onto the next token. 212 void consumeToken() { 213 assert(state.curToken.isNot(Token::eof, Token::error) && 214 "shouldn't advance past EOF or errors"); 215 state.curToken = state.lex.lexToken(); 216 } 217 218 /// Advance the current lexer onto the next token, asserting what the expected 219 /// current token is. This is preferred to the above method because it leads 220 /// to more self-documenting code with better checking. 221 void consumeToken(Token::Kind kind) { 222 assert(state.curToken.is(kind) && "consumed an unexpected token"); 223 consumeToken(); 224 } 225 226 /// Consume the specified token if present and return success. On failure, 227 /// output a diagnostic and return failure. 228 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 229 230 //===--------------------------------------------------------------------===// 231 // Type Parsing 232 //===--------------------------------------------------------------------===// 233 234 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 235 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 236 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 237 238 /// Parse an arbitrary type. 239 Type parseType(); 240 241 /// Parse a complex type. 242 Type parseComplexType(); 243 244 /// Parse an extended type. 245 Type parseExtendedType(); 246 247 /// Parse a function type. 248 Type parseFunctionType(); 249 250 /// Parse a memref type. 251 Type parseMemRefType(); 252 253 /// Parse a non function type. 254 Type parseNonFunctionType(); 255 256 /// Parse a tensor type. 257 Type parseTensorType(); 258 259 /// Parse a tuple type. 260 Type parseTupleType(); 261 262 /// Parse a vector type. 263 VectorType parseVectorType(); 264 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 265 bool allowDynamic = true); 266 ParseResult parseXInDimensionList(); 267 268 /// Parse strided layout specification. 269 ParseResult parseStridedLayout(int64_t &offset, 270 SmallVectorImpl<int64_t> &strides); 271 272 // Parse a brace-delimiter list of comma-separated integers with `?` as an 273 // unknown marker. 274 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 275 276 //===--------------------------------------------------------------------===// 277 // Attribute Parsing 278 //===--------------------------------------------------------------------===// 279 280 /// Parse an arbitrary attribute with an optional type. 281 Attribute parseAttribute(Type type = {}); 282 283 /// Parse an attribute dictionary. 284 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 285 286 /// Parse an extended attribute. 287 Attribute parseExtendedAttr(Type type); 288 289 /// Parse a float attribute. 290 Attribute parseFloatAttr(Type type, bool isNegative); 291 292 /// Parse a decimal or a hexadecimal literal, which can be either an integer 293 /// or a float attribute. 294 Attribute parseDecOrHexAttr(Type type, bool isNegative); 295 296 /// Parse an opaque elements attribute. 297 Attribute parseOpaqueElementsAttr(); 298 299 /// Parse a dense elements attribute. 300 Attribute parseDenseElementsAttr(); 301 ShapedType parseElementsLiteralType(); 302 303 /// Parse a sparse elements attribute. 304 Attribute parseSparseElementsAttr(); 305 306 //===--------------------------------------------------------------------===// 307 // Location Parsing 308 //===--------------------------------------------------------------------===// 309 310 /// Parse an inline location. 311 ParseResult parseLocation(LocationAttr &loc); 312 313 /// Parse a raw location instance. 314 ParseResult parseLocationInstance(LocationAttr &loc); 315 316 /// Parse a callsite location instance. 317 ParseResult parseCallSiteLocation(LocationAttr &loc); 318 319 /// Parse a fused location instance. 320 ParseResult parseFusedLocation(LocationAttr &loc); 321 322 /// Parse a name or FileLineCol location instance. 323 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 324 325 /// Parse an optional trailing location. 326 /// 327 /// trailing-location ::= (`loc` `(` location `)`)? 328 /// 329 ParseResult parseOptionalTrailingLocation(Location &loc) { 330 // If there is a 'loc' we parse a trailing location. 331 if (!getToken().is(Token::kw_loc)) 332 return success(); 333 334 // Parse the location. 335 LocationAttr directLoc; 336 if (parseLocation(directLoc)) 337 return failure(); 338 loc = directLoc; 339 return success(); 340 } 341 342 //===--------------------------------------------------------------------===// 343 // Affine Parsing 344 //===--------------------------------------------------------------------===// 345 346 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 347 IntegerSet &set); 348 349 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 350 ParseResult 351 parseAffineMapOfSSAIds(AffineMap &map, 352 function_ref<ParseResult(bool)> parseElement); 353 354 private: 355 /// The Parser is subclassed and reinstantiated. Do not add additional 356 /// non-trivial state here, add it to the ParserState class. 357 ParserState &state; 358 }; 359 } // end anonymous namespace 360 361 //===----------------------------------------------------------------------===// 362 // Helper methods. 363 //===----------------------------------------------------------------------===// 364 365 /// Parse a comma separated list of elements that must have at least one entry 366 /// in it. 367 ParseResult Parser::parseCommaSeparatedList( 368 const std::function<ParseResult()> &parseElement) { 369 // Non-empty case starts with an element. 370 if (parseElement()) 371 return failure(); 372 373 // Otherwise we have a list of comma separated elements. 374 while (consumeIf(Token::comma)) { 375 if (parseElement()) 376 return failure(); 377 } 378 return success(); 379 } 380 381 /// Parse a comma-separated list of elements, terminated with an arbitrary 382 /// token. This allows empty lists if allowEmptyList is true. 383 /// 384 /// abstract-list ::= rightToken // if allowEmptyList == true 385 /// abstract-list ::= element (',' element)* rightToken 386 /// 387 ParseResult Parser::parseCommaSeparatedListUntil( 388 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 389 bool allowEmptyList) { 390 // Handle the empty case. 391 if (getToken().is(rightToken)) { 392 if (!allowEmptyList) 393 return emitError("expected list element"); 394 consumeToken(rightToken); 395 return success(); 396 } 397 398 if (parseCommaSeparatedList(parseElement) || 399 parseToken(rightToken, "expected ',' or '" + 400 Token::getTokenSpelling(rightToken) + "'")) 401 return failure(); 402 403 return success(); 404 } 405 406 //===----------------------------------------------------------------------===// 407 // DialectAsmParser 408 //===----------------------------------------------------------------------===// 409 410 namespace { 411 /// This class provides the main implementation of the DialectAsmParser that 412 /// allows for dialects to parse attributes and types. This allows for dialect 413 /// hooking into the main MLIR parsing logic. 414 class CustomDialectAsmParser : public DialectAsmParser { 415 public: 416 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 417 : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), 418 parser(parser) {} 419 ~CustomDialectAsmParser() override {} 420 421 /// Emit a diagnostic at the specified location and return failure. 422 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 423 return parser.emitError(loc, message); 424 } 425 426 /// Return a builder which provides useful access to MLIRContext, global 427 /// objects like types and attributes. 428 Builder &getBuilder() const override { return parser.builder; } 429 430 /// Get the location of the next token and store it into the argument. This 431 /// always succeeds. 432 llvm::SMLoc getCurrentLocation() override { 433 return parser.getToken().getLoc(); 434 } 435 436 /// Return the location of the original name token. 437 llvm::SMLoc getNameLoc() const override { return nameLoc; } 438 439 /// Re-encode the given source location as an MLIR location and return it. 440 Location getEncodedSourceLoc(llvm::SMLoc loc) override { 441 return parser.getEncodedSourceLocation(loc); 442 } 443 444 /// Returns the full specification of the symbol being parsed. This allows 445 /// for using a separate parser if necessary. 446 StringRef getFullSymbolSpec() const override { return fullSpec; } 447 448 /// Parse a floating point value from the stream. 449 ParseResult parseFloat(double &result) override { 450 bool negative = parser.consumeIf(Token::minus); 451 Token curTok = parser.getToken(); 452 453 // Check for a floating point value. 454 if (curTok.is(Token::floatliteral)) { 455 auto val = curTok.getFloatingPointValue(); 456 if (!val.hasValue()) 457 return emitError(curTok.getLoc(), "floating point value too large"); 458 parser.consumeToken(Token::floatliteral); 459 result = negative ? -*val : *val; 460 return success(); 461 } 462 463 // TODO(riverriddle) support hex floating point values. 464 return emitError(getCurrentLocation(), "expected floating point literal"); 465 } 466 467 /// Parse an optional integer value from the stream. 468 OptionalParseResult parseOptionalInteger(uint64_t &result) override { 469 Token curToken = parser.getToken(); 470 if (curToken.isNot(Token::integer, Token::minus)) 471 return llvm::None; 472 473 bool negative = parser.consumeIf(Token::minus); 474 Token curTok = parser.getToken(); 475 if (parser.parseToken(Token::integer, "expected integer value")) 476 return failure(); 477 478 auto val = curTok.getUInt64IntegerValue(); 479 if (!val) 480 return emitError(curTok.getLoc(), "integer value too large"); 481 result = negative ? -*val : *val; 482 return success(); 483 } 484 485 //===--------------------------------------------------------------------===// 486 // Token Parsing 487 //===--------------------------------------------------------------------===// 488 489 /// Parse a `->` token. 490 ParseResult parseArrow() override { 491 return parser.parseToken(Token::arrow, "expected '->'"); 492 } 493 494 /// Parses a `->` if present. 495 ParseResult parseOptionalArrow() override { 496 return success(parser.consumeIf(Token::arrow)); 497 } 498 499 /// Parse a '{' token. 500 ParseResult parseLBrace() override { 501 return parser.parseToken(Token::l_brace, "expected '{'"); 502 } 503 504 /// Parse a '{' token if present 505 ParseResult parseOptionalLBrace() override { 506 return success(parser.consumeIf(Token::l_brace)); 507 } 508 509 /// Parse a `}` token. 510 ParseResult parseRBrace() override { 511 return parser.parseToken(Token::r_brace, "expected '}'"); 512 } 513 514 /// Parse a `}` token if present 515 ParseResult parseOptionalRBrace() override { 516 return success(parser.consumeIf(Token::r_brace)); 517 } 518 519 /// Parse a `:` token. 520 ParseResult parseColon() override { 521 return parser.parseToken(Token::colon, "expected ':'"); 522 } 523 524 /// Parse a `:` token if present. 525 ParseResult parseOptionalColon() override { 526 return success(parser.consumeIf(Token::colon)); 527 } 528 529 /// Parse a `,` token. 530 ParseResult parseComma() override { 531 return parser.parseToken(Token::comma, "expected ','"); 532 } 533 534 /// Parse a `,` token if present. 535 ParseResult parseOptionalComma() override { 536 return success(parser.consumeIf(Token::comma)); 537 } 538 539 /// Parses a `...` if present. 540 ParseResult parseOptionalEllipsis() override { 541 return success(parser.consumeIf(Token::ellipsis)); 542 } 543 544 /// Parse a `=` token. 545 ParseResult parseEqual() override { 546 return parser.parseToken(Token::equal, "expected '='"); 547 } 548 549 /// Parse a '<' token. 550 ParseResult parseLess() override { 551 return parser.parseToken(Token::less, "expected '<'"); 552 } 553 554 /// Parse a `<` token if present. 555 ParseResult parseOptionalLess() override { 556 return success(parser.consumeIf(Token::less)); 557 } 558 559 /// Parse a '>' token. 560 ParseResult parseGreater() override { 561 return parser.parseToken(Token::greater, "expected '>'"); 562 } 563 564 /// Parse a `>` token if present. 565 ParseResult parseOptionalGreater() override { 566 return success(parser.consumeIf(Token::greater)); 567 } 568 569 /// Parse a `(` token. 570 ParseResult parseLParen() override { 571 return parser.parseToken(Token::l_paren, "expected '('"); 572 } 573 574 /// Parses a '(' if present. 575 ParseResult parseOptionalLParen() override { 576 return success(parser.consumeIf(Token::l_paren)); 577 } 578 579 /// Parse a `)` token. 580 ParseResult parseRParen() override { 581 return parser.parseToken(Token::r_paren, "expected ')'"); 582 } 583 584 /// Parses a ')' if present. 585 ParseResult parseOptionalRParen() override { 586 return success(parser.consumeIf(Token::r_paren)); 587 } 588 589 /// Parse a `[` token. 590 ParseResult parseLSquare() override { 591 return parser.parseToken(Token::l_square, "expected '['"); 592 } 593 594 /// Parses a '[' if present. 595 ParseResult parseOptionalLSquare() override { 596 return success(parser.consumeIf(Token::l_square)); 597 } 598 599 /// Parse a `]` token. 600 ParseResult parseRSquare() override { 601 return parser.parseToken(Token::r_square, "expected ']'"); 602 } 603 604 /// Parses a ']' if present. 605 ParseResult parseOptionalRSquare() override { 606 return success(parser.consumeIf(Token::r_square)); 607 } 608 609 /// Parses a '?' if present. 610 ParseResult parseOptionalQuestion() override { 611 return success(parser.consumeIf(Token::question)); 612 } 613 614 /// Parses a '*' if present. 615 ParseResult parseOptionalStar() override { 616 return success(parser.consumeIf(Token::star)); 617 } 618 619 /// Returns if the current token corresponds to a keyword. 620 bool isCurrentTokenAKeyword() const { 621 return parser.getToken().is(Token::bare_identifier) || 622 parser.getToken().isKeyword(); 623 } 624 625 /// Parse the given keyword if present. 626 ParseResult parseOptionalKeyword(StringRef keyword) override { 627 // Check that the current token has the same spelling. 628 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 629 return failure(); 630 parser.consumeToken(); 631 return success(); 632 } 633 634 /// Parse a keyword, if present, into 'keyword'. 635 ParseResult parseOptionalKeyword(StringRef *keyword) override { 636 // Check that the current token is a keyword. 637 if (!isCurrentTokenAKeyword()) 638 return failure(); 639 640 *keyword = parser.getTokenSpelling(); 641 parser.consumeToken(); 642 return success(); 643 } 644 645 //===--------------------------------------------------------------------===// 646 // Attribute Parsing 647 //===--------------------------------------------------------------------===// 648 649 /// Parse an arbitrary attribute and return it in result. 650 ParseResult parseAttribute(Attribute &result, Type type) override { 651 result = parser.parseAttribute(type); 652 return success(static_cast<bool>(result)); 653 } 654 655 //===--------------------------------------------------------------------===// 656 // Type Parsing 657 //===--------------------------------------------------------------------===// 658 659 ParseResult parseType(Type &result) override { 660 result = parser.parseType(); 661 return success(static_cast<bool>(result)); 662 } 663 664 ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 665 bool allowDynamic) override { 666 return parser.parseDimensionListRanked(dimensions, allowDynamic); 667 } 668 669 private: 670 /// The full symbol specification. 671 StringRef fullSpec; 672 673 /// The source location of the dialect symbol. 674 SMLoc nameLoc; 675 676 /// The main parser. 677 Parser &parser; 678 }; 679 } // namespace 680 681 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 682 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 683 /// the entire pretty name. 684 /// 685 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 686 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 687 /// | '(' pretty-dialect-sym-contents+ ')' 688 /// | '[' pretty-dialect-sym-contents+ ']' 689 /// | '{' pretty-dialect-sym-contents+ '}' 690 /// | '[^[<({>\])}\0]+' 691 /// 692 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 693 // Pretty symbol names are a relatively unstructured format that contains a 694 // series of properly nested punctuation, with anything else in the middle. 695 // Scan ahead to find it and consume it if successful, otherwise emit an 696 // error. 697 auto *curPtr = getTokenSpelling().data(); 698 699 SmallVector<char, 8> nestedPunctuation; 700 701 // Scan over the nested punctuation, bailing out on error and consuming until 702 // we find the end. We know that we're currently looking at the '<', so we 703 // can go until we find the matching '>' character. 704 assert(*curPtr == '<'); 705 do { 706 char c = *curPtr++; 707 switch (c) { 708 case '\0': 709 // This also handles the EOF case. 710 return emitError("unexpected nul or EOF in pretty dialect name"); 711 case '<': 712 case '[': 713 case '(': 714 case '{': 715 nestedPunctuation.push_back(c); 716 continue; 717 718 case '-': 719 // The sequence `->` is treated as special token. 720 if (*curPtr == '>') 721 ++curPtr; 722 continue; 723 724 case '>': 725 if (nestedPunctuation.pop_back_val() != '<') 726 return emitError("unbalanced '>' character in pretty dialect name"); 727 break; 728 case ']': 729 if (nestedPunctuation.pop_back_val() != '[') 730 return emitError("unbalanced ']' character in pretty dialect name"); 731 break; 732 case ')': 733 if (nestedPunctuation.pop_back_val() != '(') 734 return emitError("unbalanced ')' character in pretty dialect name"); 735 break; 736 case '}': 737 if (nestedPunctuation.pop_back_val() != '{') 738 return emitError("unbalanced '}' character in pretty dialect name"); 739 break; 740 741 default: 742 continue; 743 } 744 } while (!nestedPunctuation.empty()); 745 746 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 747 // consuming all this stuff, and return. 748 state.lex.resetPointer(curPtr); 749 750 unsigned length = curPtr - prettyName.begin(); 751 prettyName = StringRef(prettyName.begin(), length); 752 consumeToken(); 753 return success(); 754 } 755 756 /// Parse an extended dialect symbol. 757 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 758 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 759 SymbolAliasMap &aliases, 760 CreateFn &&createSymbol) { 761 // Parse the dialect namespace. 762 StringRef identifier = p.getTokenSpelling().drop_front(); 763 auto loc = p.getToken().getLoc(); 764 p.consumeToken(identifierTok); 765 766 // If there is no '<' token following this, and if the typename contains no 767 // dot, then we are parsing a symbol alias. 768 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 769 // Check for an alias for this type. 770 auto aliasIt = aliases.find(identifier); 771 if (aliasIt == aliases.end()) 772 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 773 nullptr); 774 return aliasIt->second; 775 } 776 777 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 778 // a dot, then this is the "pretty" form. If not, it is the verbose form that 779 // looks like <"...">. 780 std::string symbolData; 781 auto dialectName = identifier; 782 783 // Handle the verbose form, where "identifier" is a simple dialect name. 784 if (!identifier.contains('.')) { 785 // Consume the '<'. 786 if (p.parseToken(Token::less, "expected '<' in dialect type")) 787 return nullptr; 788 789 // Parse the symbol specific data. 790 if (p.getToken().isNot(Token::string)) 791 return (p.emitError("expected string literal data in dialect symbol"), 792 nullptr); 793 symbolData = p.getToken().getStringValue(); 794 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); 795 p.consumeToken(Token::string); 796 797 // Consume the '>'. 798 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 799 return nullptr; 800 } else { 801 // Ok, the dialect name is the part of the identifier before the dot, the 802 // part after the dot is the dialect's symbol, or the start thereof. 803 auto dotHalves = identifier.split('.'); 804 dialectName = dotHalves.first; 805 auto prettyName = dotHalves.second; 806 loc = llvm::SMLoc::getFromPointer(prettyName.data()); 807 808 // If the dialect's symbol is followed immediately by a <, then lex the body 809 // of it into prettyName. 810 if (p.getToken().is(Token::less) && 811 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 812 if (p.parsePrettyDialectSymbolName(prettyName)) 813 return nullptr; 814 } 815 816 symbolData = prettyName.str(); 817 } 818 819 // Record the name location of the type remapped to the top level buffer. 820 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); 821 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); 822 823 // Call into the provided symbol construction function. 824 Symbol sym = createSymbol(dialectName, symbolData, loc); 825 826 // Pop the last parser location. 827 p.getState().symbols.nestedParserLocs.pop_back(); 828 return sym; 829 } 830 831 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 832 /// parsing failed, nullptr is returned. The number of bytes read from the input 833 /// string is returned in 'numRead'. 834 template <typename T, typename ParserFn> 835 static T parseSymbol(StringRef inputStr, MLIRContext *context, 836 SymbolState &symbolState, ParserFn &&parserFn, 837 size_t *numRead = nullptr) { 838 SourceMgr sourceMgr; 839 auto memBuffer = MemoryBuffer::getMemBuffer( 840 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 841 /*RequiresNullTerminator=*/false); 842 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 843 ParserState state(sourceMgr, context, symbolState); 844 Parser parser(state); 845 846 Token startTok = parser.getToken(); 847 T symbol = parserFn(parser); 848 if (!symbol) 849 return T(); 850 851 // If 'numRead' is valid, then provide the number of bytes that were read. 852 Token endTok = parser.getToken(); 853 if (numRead) { 854 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 855 startTok.getLoc().getPointer()); 856 857 // Otherwise, ensure that all of the tokens were parsed. 858 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { 859 parser.emitError(endTok.getLoc(), "encountered unexpected token"); 860 return T(); 861 } 862 return symbol; 863 } 864 865 //===----------------------------------------------------------------------===// 866 // Error Handling 867 //===----------------------------------------------------------------------===// 868 869 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 870 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 871 872 // If we hit a parse error in response to a lexer error, then the lexer 873 // already reported the error. 874 if (getToken().is(Token::error)) 875 diag.abandon(); 876 return diag; 877 } 878 879 //===----------------------------------------------------------------------===// 880 // Token Parsing 881 //===----------------------------------------------------------------------===// 882 883 /// Consume the specified token if present and return success. On failure, 884 /// output a diagnostic and return failure. 885 ParseResult Parser::parseToken(Token::Kind expectedToken, 886 const Twine &message) { 887 if (consumeIf(expectedToken)) 888 return success(); 889 return emitError(message); 890 } 891 892 //===----------------------------------------------------------------------===// 893 // Type Parsing 894 //===----------------------------------------------------------------------===// 895 896 /// Parse an arbitrary type. 897 /// 898 /// type ::= function-type 899 /// | non-function-type 900 /// 901 Type Parser::parseType() { 902 if (getToken().is(Token::l_paren)) 903 return parseFunctionType(); 904 return parseNonFunctionType(); 905 } 906 907 /// Parse a function result type. 908 /// 909 /// function-result-type ::= type-list-parens 910 /// | non-function-type 911 /// 912 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 913 if (getToken().is(Token::l_paren)) 914 return parseTypeListParens(elements); 915 916 Type t = parseNonFunctionType(); 917 if (!t) 918 return failure(); 919 elements.push_back(t); 920 return success(); 921 } 922 923 /// Parse a list of types without an enclosing parenthesis. The list must have 924 /// at least one member. 925 /// 926 /// type-list-no-parens ::= type (`,` type)* 927 /// 928 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 929 auto parseElt = [&]() -> ParseResult { 930 auto elt = parseType(); 931 elements.push_back(elt); 932 return elt ? success() : failure(); 933 }; 934 935 return parseCommaSeparatedList(parseElt); 936 } 937 938 /// Parse a parenthesized list of types. 939 /// 940 /// type-list-parens ::= `(` `)` 941 /// | `(` type-list-no-parens `)` 942 /// 943 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 944 if (parseToken(Token::l_paren, "expected '('")) 945 return failure(); 946 947 // Handle empty lists. 948 if (getToken().is(Token::r_paren)) 949 return consumeToken(), success(); 950 951 if (parseTypeListNoParens(elements) || 952 parseToken(Token::r_paren, "expected ')'")) 953 return failure(); 954 return success(); 955 } 956 957 /// Parse a complex type. 958 /// 959 /// complex-type ::= `complex` `<` type `>` 960 /// 961 Type Parser::parseComplexType() { 962 consumeToken(Token::kw_complex); 963 964 // Parse the '<'. 965 if (parseToken(Token::less, "expected '<' in complex type")) 966 return nullptr; 967 968 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 969 auto elementType = parseType(); 970 if (!elementType || 971 parseToken(Token::greater, "expected '>' in complex type")) 972 return nullptr; 973 974 return ComplexType::getChecked(elementType, typeLocation); 975 } 976 977 /// Parse an extended type. 978 /// 979 /// extended-type ::= (dialect-type | type-alias) 980 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 981 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 982 /// type-alias ::= `!` alias-name 983 /// 984 Type Parser::parseExtendedType() { 985 return parseExtendedSymbol<Type>( 986 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 987 [&](StringRef dialectName, StringRef symbolData, 988 llvm::SMLoc loc) -> Type { 989 // If we found a registered dialect, then ask it to parse the type. 990 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 991 return parseSymbol<Type>( 992 symbolData, state.context, state.symbols, [&](Parser &parser) { 993 CustomDialectAsmParser customParser(symbolData, parser); 994 return dialect->parseType(customParser); 995 }); 996 } 997 998 // Otherwise, form a new opaque type. 999 return OpaqueType::getChecked( 1000 Identifier::get(dialectName, state.context), symbolData, 1001 state.context, getEncodedSourceLocation(loc)); 1002 }); 1003 } 1004 1005 /// Parse a function type. 1006 /// 1007 /// function-type ::= type-list-parens `->` function-result-type 1008 /// 1009 Type Parser::parseFunctionType() { 1010 assert(getToken().is(Token::l_paren)); 1011 1012 SmallVector<Type, 4> arguments, results; 1013 if (parseTypeListParens(arguments) || 1014 parseToken(Token::arrow, "expected '->' in function type") || 1015 parseFunctionResultTypes(results)) 1016 return nullptr; 1017 1018 return builder.getFunctionType(arguments, results); 1019 } 1020 1021 /// Parse the offset and strides from a strided layout specification. 1022 /// 1023 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1024 /// 1025 ParseResult Parser::parseStridedLayout(int64_t &offset, 1026 SmallVectorImpl<int64_t> &strides) { 1027 // Parse offset. 1028 consumeToken(Token::kw_offset); 1029 if (!consumeIf(Token::colon)) 1030 return emitError("expected colon after `offset` keyword"); 1031 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1032 bool question = getToken().is(Token::question); 1033 if (!maybeOffset && !question) 1034 return emitError("invalid offset"); 1035 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1036 : MemRefType::getDynamicStrideOrOffset(); 1037 consumeToken(); 1038 1039 if (!consumeIf(Token::comma)) 1040 return emitError("expected comma after offset value"); 1041 1042 // Parse stride list. 1043 if (!consumeIf(Token::kw_strides)) 1044 return emitError("expected `strides` keyword after offset specification"); 1045 if (!consumeIf(Token::colon)) 1046 return emitError("expected colon after `strides` keyword"); 1047 if (failed(parseStrideList(strides))) 1048 return emitError("invalid braces-enclosed stride list"); 1049 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1050 return emitError("invalid memref stride"); 1051 1052 return success(); 1053 } 1054 1055 /// Parse a memref type. 1056 /// 1057 /// memref-type ::= ranked-memref-type | unranked-memref-type 1058 /// 1059 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1060 /// (`,` semi-affine-map-composition)? (`,` 1061 /// memory-space)? `>` 1062 /// 1063 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1064 /// 1065 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1066 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1067 /// 1068 Type Parser::parseMemRefType() { 1069 consumeToken(Token::kw_memref); 1070 1071 if (parseToken(Token::less, "expected '<' in memref type")) 1072 return nullptr; 1073 1074 bool isUnranked; 1075 SmallVector<int64_t, 4> dimensions; 1076 1077 if (consumeIf(Token::star)) { 1078 // This is an unranked memref type. 1079 isUnranked = true; 1080 if (parseXInDimensionList()) 1081 return nullptr; 1082 1083 } else { 1084 isUnranked = false; 1085 if (parseDimensionListRanked(dimensions)) 1086 return nullptr; 1087 } 1088 1089 // Parse the element type. 1090 auto typeLoc = getToken().getLoc(); 1091 auto elementType = parseType(); 1092 if (!elementType) 1093 return nullptr; 1094 1095 // Parse semi-affine-map-composition. 1096 SmallVector<AffineMap, 2> affineMapComposition; 1097 unsigned memorySpace = 0; 1098 bool parsedMemorySpace = false; 1099 1100 auto parseElt = [&]() -> ParseResult { 1101 if (getToken().is(Token::integer)) { 1102 // Parse memory space. 1103 if (parsedMemorySpace) 1104 return emitError("multiple memory spaces specified in memref type"); 1105 auto v = getToken().getUnsignedIntegerValue(); 1106 if (!v.hasValue()) 1107 return emitError("invalid memory space in memref type"); 1108 memorySpace = v.getValue(); 1109 consumeToken(Token::integer); 1110 parsedMemorySpace = true; 1111 } else { 1112 if (isUnranked) 1113 return emitError("cannot have affine map for unranked memref type"); 1114 if (parsedMemorySpace) 1115 return emitError("expected memory space to be last in memref type"); 1116 if (getToken().is(Token::kw_offset)) { 1117 int64_t offset; 1118 SmallVector<int64_t, 4> strides; 1119 if (failed(parseStridedLayout(offset, strides))) 1120 return failure(); 1121 // Construct strided affine map. 1122 auto map = makeStridedLinearLayoutMap(strides, offset, 1123 elementType.getContext()); 1124 affineMapComposition.push_back(map); 1125 } else { 1126 // Parse affine map. 1127 auto affineMap = parseAttribute(); 1128 if (!affineMap) 1129 return failure(); 1130 // Verify that the parsed attribute is an affine map. 1131 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 1132 affineMapComposition.push_back(affineMapAttr.getValue()); 1133 else 1134 return emitError("expected affine map in memref type"); 1135 } 1136 } 1137 return success(); 1138 }; 1139 1140 // Parse a list of mappings and address space if present. 1141 if (consumeIf(Token::comma)) { 1142 // Parse comma separated list of affine maps, followed by memory space. 1143 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 1144 /*allowEmptyList=*/false)) { 1145 return nullptr; 1146 } 1147 } else { 1148 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 1149 return nullptr; 1150 } 1151 1152 if (isUnranked) 1153 return UnrankedMemRefType::getChecked(elementType, memorySpace, 1154 getEncodedSourceLocation(typeLoc)); 1155 1156 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 1157 memorySpace, getEncodedSourceLocation(typeLoc)); 1158 } 1159 1160 /// Parse any type except the function type. 1161 /// 1162 /// non-function-type ::= integer-type 1163 /// | index-type 1164 /// | float-type 1165 /// | extended-type 1166 /// | vector-type 1167 /// | tensor-type 1168 /// | memref-type 1169 /// | complex-type 1170 /// | tuple-type 1171 /// | none-type 1172 /// 1173 /// index-type ::= `index` 1174 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1175 /// none-type ::= `none` 1176 /// 1177 Type Parser::parseNonFunctionType() { 1178 switch (getToken().getKind()) { 1179 default: 1180 return (emitError("expected non-function type"), nullptr); 1181 case Token::kw_memref: 1182 return parseMemRefType(); 1183 case Token::kw_tensor: 1184 return parseTensorType(); 1185 case Token::kw_complex: 1186 return parseComplexType(); 1187 case Token::kw_tuple: 1188 return parseTupleType(); 1189 case Token::kw_vector: 1190 return parseVectorType(); 1191 // integer-type 1192 case Token::inttype: { 1193 auto width = getToken().getIntTypeBitwidth(); 1194 if (!width.hasValue()) 1195 return (emitError("invalid integer width"), nullptr); 1196 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1197 consumeToken(Token::inttype); 1198 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 1199 } 1200 1201 // float-type 1202 case Token::kw_bf16: 1203 consumeToken(Token::kw_bf16); 1204 return builder.getBF16Type(); 1205 case Token::kw_f16: 1206 consumeToken(Token::kw_f16); 1207 return builder.getF16Type(); 1208 case Token::kw_f32: 1209 consumeToken(Token::kw_f32); 1210 return builder.getF32Type(); 1211 case Token::kw_f64: 1212 consumeToken(Token::kw_f64); 1213 return builder.getF64Type(); 1214 1215 // index-type 1216 case Token::kw_index: 1217 consumeToken(Token::kw_index); 1218 return builder.getIndexType(); 1219 1220 // none-type 1221 case Token::kw_none: 1222 consumeToken(Token::kw_none); 1223 return builder.getNoneType(); 1224 1225 // extended type 1226 case Token::exclamation_identifier: 1227 return parseExtendedType(); 1228 } 1229 } 1230 1231 /// Parse a tensor type. 1232 /// 1233 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1234 /// dimension-list ::= dimension-list-ranked | `*x` 1235 /// 1236 Type Parser::parseTensorType() { 1237 consumeToken(Token::kw_tensor); 1238 1239 if (parseToken(Token::less, "expected '<' in tensor type")) 1240 return nullptr; 1241 1242 bool isUnranked; 1243 SmallVector<int64_t, 4> dimensions; 1244 1245 if (consumeIf(Token::star)) { 1246 // This is an unranked tensor type. 1247 isUnranked = true; 1248 1249 if (parseXInDimensionList()) 1250 return nullptr; 1251 1252 } else { 1253 isUnranked = false; 1254 if (parseDimensionListRanked(dimensions)) 1255 return nullptr; 1256 } 1257 1258 // Parse the element type. 1259 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 1260 auto elementType = parseType(); 1261 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1262 return nullptr; 1263 1264 if (isUnranked) 1265 return UnrankedTensorType::getChecked(elementType, typeLocation); 1266 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 1267 } 1268 1269 /// Parse a tuple type. 1270 /// 1271 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1272 /// 1273 Type Parser::parseTupleType() { 1274 consumeToken(Token::kw_tuple); 1275 1276 // Parse the '<'. 1277 if (parseToken(Token::less, "expected '<' in tuple type")) 1278 return nullptr; 1279 1280 // Check for an empty tuple by directly parsing '>'. 1281 if (consumeIf(Token::greater)) 1282 return TupleType::get(getContext()); 1283 1284 // Parse the element types and the '>'. 1285 SmallVector<Type, 4> types; 1286 if (parseTypeListNoParens(types) || 1287 parseToken(Token::greater, "expected '>' in tuple type")) 1288 return nullptr; 1289 1290 return TupleType::get(types, getContext()); 1291 } 1292 1293 /// Parse a vector type. 1294 /// 1295 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1296 /// non-empty-static-dimension-list ::= decimal-literal `x` 1297 /// static-dimension-list 1298 /// static-dimension-list ::= (decimal-literal `x`)* 1299 /// 1300 VectorType Parser::parseVectorType() { 1301 consumeToken(Token::kw_vector); 1302 1303 if (parseToken(Token::less, "expected '<' in vector type")) 1304 return nullptr; 1305 1306 SmallVector<int64_t, 4> dimensions; 1307 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1308 return nullptr; 1309 if (dimensions.empty()) 1310 return (emitError("expected dimension size in vector type"), nullptr); 1311 1312 // Parse the element type. 1313 auto typeLoc = getToken().getLoc(); 1314 auto elementType = parseType(); 1315 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1316 return nullptr; 1317 1318 return VectorType::getChecked(dimensions, elementType, 1319 getEncodedSourceLocation(typeLoc)); 1320 } 1321 1322 /// Parse a dimension list of a tensor or memref type. This populates the 1323 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1324 /// errors out on `?` otherwise. 1325 /// 1326 /// dimension-list-ranked ::= (dimension `x`)* 1327 /// dimension ::= `?` | decimal-literal 1328 /// 1329 /// When `allowDynamic` is not set, this is used to parse: 1330 /// 1331 /// static-dimension-list ::= (decimal-literal `x`)* 1332 ParseResult 1333 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1334 bool allowDynamic) { 1335 while (getToken().isAny(Token::integer, Token::question)) { 1336 if (consumeIf(Token::question)) { 1337 if (!allowDynamic) 1338 return emitError("expected static shape"); 1339 dimensions.push_back(-1); 1340 } else { 1341 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1342 // aggregate type declarations. Therefore, `0xf32` should be processed as 1343 // a sequence of separate elements `0`, `x`, `f32`. 1344 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1345 // We can get here only if the token is an integer literal. Hexadecimal 1346 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1347 // literal, just `1` would, at which point we don't get into this 1348 // branch). 1349 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1350 dimensions.push_back(0); 1351 state.lex.resetPointer(getTokenSpelling().data() + 1); 1352 consumeToken(); 1353 } else { 1354 // Make sure this integer value is in bound and valid. 1355 auto dimension = getToken().getUnsignedIntegerValue(); 1356 if (!dimension.hasValue()) 1357 return emitError("invalid dimension"); 1358 dimensions.push_back((int64_t)dimension.getValue()); 1359 consumeToken(Token::integer); 1360 } 1361 } 1362 1363 // Make sure we have an 'x' or something like 'xbf32'. 1364 if (parseXInDimensionList()) 1365 return failure(); 1366 } 1367 1368 return success(); 1369 } 1370 1371 /// Parse an 'x' token in a dimension list, handling the case where the x is 1372 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1373 /// token. 1374 ParseResult Parser::parseXInDimensionList() { 1375 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1376 return emitError("expected 'x' in dimension list"); 1377 1378 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1379 if (getTokenSpelling().size() != 1) 1380 state.lex.resetPointer(getTokenSpelling().data() + 1); 1381 1382 // Consume the 'x'. 1383 consumeToken(Token::bare_identifier); 1384 1385 return success(); 1386 } 1387 1388 // Parse a comma-separated list of dimensions, possibly empty: 1389 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1390 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1391 if (!consumeIf(Token::l_square)) 1392 return failure(); 1393 // Empty list early exit. 1394 if (consumeIf(Token::r_square)) 1395 return success(); 1396 while (true) { 1397 if (consumeIf(Token::question)) { 1398 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1399 } else { 1400 // This must be an integer value. 1401 int64_t val; 1402 if (getToken().getSpelling().getAsInteger(10, val)) 1403 return emitError("invalid integer value: ") << getToken().getSpelling(); 1404 // Make sure it is not the one value for `?`. 1405 if (ShapedType::isDynamic(val)) 1406 return emitError("invalid integer value: ") 1407 << getToken().getSpelling() 1408 << ", use `?` to specify a dynamic dimension"; 1409 dimensions.push_back(val); 1410 consumeToken(Token::integer); 1411 } 1412 if (!consumeIf(Token::comma)) 1413 break; 1414 } 1415 if (!consumeIf(Token::r_square)) 1416 return failure(); 1417 return success(); 1418 } 1419 1420 //===----------------------------------------------------------------------===// 1421 // Attribute parsing. 1422 //===----------------------------------------------------------------------===// 1423 1424 /// Return the symbol reference referred to by the given token, that is known to 1425 /// be an @-identifier. 1426 static std::string extractSymbolReference(Token tok) { 1427 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1428 StringRef nameStr = tok.getSpelling().drop_front(); 1429 1430 // Check to see if the reference is a string literal, or a bare identifier. 1431 if (nameStr.front() == '"') 1432 return tok.getStringValue(); 1433 return nameStr; 1434 } 1435 1436 /// Parse an arbitrary attribute. 1437 /// 1438 /// attribute-value ::= `unit` 1439 /// | bool-literal 1440 /// | integer-literal (`:` (index-type | integer-type))? 1441 /// | float-literal (`:` float-type)? 1442 /// | string-literal (`:` type)? 1443 /// | type 1444 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1445 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1446 /// | symbol-ref-id (`::` symbol-ref-id)* 1447 /// | `dense` `<` attribute-value `>` `:` 1448 /// (tensor-type | vector-type) 1449 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1450 /// `:` (tensor-type | vector-type) 1451 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1452 /// `>` `:` (tensor-type | vector-type) 1453 /// | extended-attribute 1454 /// 1455 Attribute Parser::parseAttribute(Type type) { 1456 switch (getToken().getKind()) { 1457 // Parse an AffineMap or IntegerSet attribute. 1458 case Token::l_paren: { 1459 // Try to parse an affine map or an integer set reference. 1460 AffineMap map; 1461 IntegerSet set; 1462 if (parseAffineMapOrIntegerSetReference(map, set)) 1463 return nullptr; 1464 if (map) 1465 return AffineMapAttr::get(map); 1466 assert(set); 1467 return IntegerSetAttr::get(set); 1468 } 1469 1470 // Parse an array attribute. 1471 case Token::l_square: { 1472 consumeToken(Token::l_square); 1473 1474 SmallVector<Attribute, 4> elements; 1475 auto parseElt = [&]() -> ParseResult { 1476 elements.push_back(parseAttribute()); 1477 return elements.back() ? success() : failure(); 1478 }; 1479 1480 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1481 return nullptr; 1482 return builder.getArrayAttr(elements); 1483 } 1484 1485 // Parse a boolean attribute. 1486 case Token::kw_false: 1487 consumeToken(Token::kw_false); 1488 return builder.getBoolAttr(false); 1489 case Token::kw_true: 1490 consumeToken(Token::kw_true); 1491 return builder.getBoolAttr(true); 1492 1493 // Parse a dense elements attribute. 1494 case Token::kw_dense: 1495 return parseDenseElementsAttr(); 1496 1497 // Parse a dictionary attribute. 1498 case Token::l_brace: { 1499 SmallVector<NamedAttribute, 4> elements; 1500 if (parseAttributeDict(elements)) 1501 return nullptr; 1502 return builder.getDictionaryAttr(elements); 1503 } 1504 1505 // Parse an extended attribute, i.e. alias or dialect attribute. 1506 case Token::hash_identifier: 1507 return parseExtendedAttr(type); 1508 1509 // Parse floating point and integer attributes. 1510 case Token::floatliteral: 1511 return parseFloatAttr(type, /*isNegative=*/false); 1512 case Token::integer: 1513 return parseDecOrHexAttr(type, /*isNegative=*/false); 1514 case Token::minus: { 1515 consumeToken(Token::minus); 1516 if (getToken().is(Token::integer)) 1517 return parseDecOrHexAttr(type, /*isNegative=*/true); 1518 if (getToken().is(Token::floatliteral)) 1519 return parseFloatAttr(type, /*isNegative=*/true); 1520 1521 return (emitError("expected constant integer or floating point value"), 1522 nullptr); 1523 } 1524 1525 // Parse a location attribute. 1526 case Token::kw_loc: { 1527 LocationAttr attr; 1528 return failed(parseLocation(attr)) ? Attribute() : attr; 1529 } 1530 1531 // Parse an opaque elements attribute. 1532 case Token::kw_opaque: 1533 return parseOpaqueElementsAttr(); 1534 1535 // Parse a sparse elements attribute. 1536 case Token::kw_sparse: 1537 return parseSparseElementsAttr(); 1538 1539 // Parse a string attribute. 1540 case Token::string: { 1541 auto val = getToken().getStringValue(); 1542 consumeToken(Token::string); 1543 // Parse the optional trailing colon type if one wasn't explicitly provided. 1544 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1545 return Attribute(); 1546 1547 return type ? StringAttr::get(val, type) 1548 : StringAttr::get(val, getContext()); 1549 } 1550 1551 // Parse a symbol reference attribute. 1552 case Token::at_identifier: { 1553 std::string nameStr = extractSymbolReference(getToken()); 1554 consumeToken(Token::at_identifier); 1555 1556 // Parse any nested references. 1557 std::vector<FlatSymbolRefAttr> nestedRefs; 1558 while (getToken().is(Token::colon)) { 1559 // Check for the '::' prefix. 1560 const char *curPointer = getToken().getLoc().getPointer(); 1561 consumeToken(Token::colon); 1562 if (!consumeIf(Token::colon)) { 1563 state.lex.resetPointer(curPointer); 1564 consumeToken(); 1565 break; 1566 } 1567 // Parse the reference itself. 1568 auto curLoc = getToken().getLoc(); 1569 if (getToken().isNot(Token::at_identifier)) { 1570 emitError(curLoc, "expected nested symbol reference identifier"); 1571 return Attribute(); 1572 } 1573 1574 std::string nameStr = extractSymbolReference(getToken()); 1575 consumeToken(Token::at_identifier); 1576 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1577 } 1578 1579 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1580 } 1581 1582 // Parse a 'unit' attribute. 1583 case Token::kw_unit: 1584 consumeToken(Token::kw_unit); 1585 return builder.getUnitAttr(); 1586 1587 default: 1588 // Parse a type attribute. 1589 if (Type type = parseType()) 1590 return TypeAttr::get(type); 1591 return nullptr; 1592 } 1593 } 1594 1595 /// Attribute dictionary. 1596 /// 1597 /// attribute-dict ::= `{` `}` 1598 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1599 /// attribute-entry ::= bare-id `=` attribute-value 1600 /// 1601 ParseResult 1602 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1603 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1604 return failure(); 1605 1606 auto parseElt = [&]() -> ParseResult { 1607 // We allow keywords as attribute names. 1608 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1609 !getToken().isKeyword()) 1610 return emitError("expected attribute name"); 1611 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1612 consumeToken(); 1613 1614 // Try to parse the '=' for the attribute value. 1615 if (!consumeIf(Token::equal)) { 1616 // If there is no '=', we treat this as a unit attribute. 1617 attributes.push_back({nameId, builder.getUnitAttr()}); 1618 return success(); 1619 } 1620 1621 auto attr = parseAttribute(); 1622 if (!attr) 1623 return failure(); 1624 1625 attributes.push_back({nameId, attr}); 1626 return success(); 1627 }; 1628 1629 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1630 return failure(); 1631 1632 return success(); 1633 } 1634 1635 /// Parse an extended attribute. 1636 /// 1637 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1638 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1639 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1640 /// attribute-alias ::= `#` alias-name 1641 /// 1642 Attribute Parser::parseExtendedAttr(Type type) { 1643 Attribute attr = parseExtendedSymbol<Attribute>( 1644 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1645 [&](StringRef dialectName, StringRef symbolData, 1646 llvm::SMLoc loc) -> Attribute { 1647 // Parse an optional trailing colon type. 1648 Type attrType = type; 1649 if (consumeIf(Token::colon) && !(attrType = parseType())) 1650 return Attribute(); 1651 1652 // If we found a registered dialect, then ask it to parse the attribute. 1653 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1654 return parseSymbol<Attribute>( 1655 symbolData, state.context, state.symbols, [&](Parser &parser) { 1656 CustomDialectAsmParser customParser(symbolData, parser); 1657 return dialect->parseAttribute(customParser, attrType); 1658 }); 1659 } 1660 1661 // Otherwise, form a new opaque attribute. 1662 return OpaqueAttr::getChecked( 1663 Identifier::get(dialectName, state.context), symbolData, 1664 attrType ? attrType : NoneType::get(state.context), 1665 getEncodedSourceLocation(loc)); 1666 }); 1667 1668 // Ensure that the attribute has the same type as requested. 1669 if (attr && type && attr.getType() != type) { 1670 emitError("attribute type different than expected: expected ") 1671 << type << ", but got " << attr.getType(); 1672 return nullptr; 1673 } 1674 return attr; 1675 } 1676 1677 /// Parse a float attribute. 1678 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1679 auto val = getToken().getFloatingPointValue(); 1680 if (!val.hasValue()) 1681 return (emitError("floating point value too large for attribute"), nullptr); 1682 consumeToken(Token::floatliteral); 1683 if (!type) { 1684 // Default to F64 when no type is specified. 1685 if (!consumeIf(Token::colon)) 1686 type = builder.getF64Type(); 1687 else if (!(type = parseType())) 1688 return nullptr; 1689 } 1690 if (!type.isa<FloatType>()) 1691 return (emitError("floating point value not valid for specified type"), 1692 nullptr); 1693 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1694 } 1695 1696 /// Construct a float attribute bitwise equivalent to the integer literal. 1697 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1698 uint64_t value) { 1699 int width = type.getIntOrFloatBitWidth(); 1700 APInt apInt(width, value); 1701 if (apInt != value) { 1702 p->emitError("hexadecimal float constant out of range for type"); 1703 return nullptr; 1704 } 1705 APFloat apFloat(type.getFloatSemantics(), apInt); 1706 return p->builder.getFloatAttr(type, apFloat); 1707 } 1708 1709 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1710 /// or a float attribute. 1711 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1712 auto val = getToken().getUInt64IntegerValue(); 1713 if (!val.hasValue()) 1714 return (emitError("integer constant out of range for attribute"), nullptr); 1715 1716 // Remember if the literal is hexadecimal. 1717 StringRef spelling = getToken().getSpelling(); 1718 auto loc = state.curToken.getLoc(); 1719 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1720 1721 consumeToken(Token::integer); 1722 if (!type) { 1723 // Default to i64 if not type is specified. 1724 if (!consumeIf(Token::colon)) 1725 type = builder.getIntegerType(64); 1726 else if (!(type = parseType())) 1727 return nullptr; 1728 } 1729 1730 if (auto floatType = type.dyn_cast<FloatType>()) { 1731 // TODO(zinenko): Update once hex format for bfloat16 is supported. 1732 if (type.isBF16()) 1733 return emitError(loc, 1734 "hexadecimal float literal not supported for bfloat16"), 1735 nullptr; 1736 if (isNegative) 1737 return emitError( 1738 loc, 1739 "hexadecimal float literal should not have a leading minus"), 1740 nullptr; 1741 if (!isHex) { 1742 emitError(loc, "unexpected decimal integer literal for a float attribute") 1743 .attachNote() 1744 << "add a trailing dot to make the literal a float"; 1745 return nullptr; 1746 } 1747 1748 // Construct a float attribute bitwise equivalent to the integer literal. 1749 return buildHexadecimalFloatLiteral(this, floatType, *val); 1750 } 1751 1752 if (!type.isIntOrIndex()) 1753 return emitError(loc, "integer literal not valid for specified type"), 1754 nullptr; 1755 1756 // Parse the integer literal. 1757 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1758 APInt apInt(width, *val, isNegative); 1759 if (apInt != *val) 1760 return emitError(loc, "integer constant out of range for attribute"), 1761 nullptr; 1762 1763 // Otherwise construct an integer attribute. 1764 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1765 return emitError(loc, "integer constant out of range for attribute"), 1766 nullptr; 1767 1768 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1769 } 1770 1771 /// Parse an opaque elements attribute. 1772 Attribute Parser::parseOpaqueElementsAttr() { 1773 consumeToken(Token::kw_opaque); 1774 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1775 return nullptr; 1776 1777 if (getToken().isNot(Token::string)) 1778 return (emitError("expected dialect namespace"), nullptr); 1779 1780 auto name = getToken().getStringValue(); 1781 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1782 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1783 // attribute. Otherwise, it can't be roundtripped without having the dialect 1784 // registered. 1785 if (!dialect) 1786 return (emitError("no registered dialect with namespace '" + name + "'"), 1787 nullptr); 1788 1789 consumeToken(Token::string); 1790 if (parseToken(Token::comma, "expected ','")) 1791 return nullptr; 1792 1793 if (getToken().getKind() != Token::string) 1794 return (emitError("opaque string should start with '0x'"), nullptr); 1795 1796 auto val = getToken().getStringValue(); 1797 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1798 return (emitError("opaque string should start with '0x'"), nullptr); 1799 1800 val = val.substr(2); 1801 if (!llvm::all_of(val, llvm::isHexDigit)) 1802 return (emitError("opaque string only contains hex digits"), nullptr); 1803 1804 consumeToken(Token::string); 1805 if (parseToken(Token::greater, "expected '>'") || 1806 parseToken(Token::colon, "expected ':'")) 1807 return nullptr; 1808 1809 auto type = parseElementsLiteralType(); 1810 if (!type) 1811 return nullptr; 1812 1813 return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); 1814 } 1815 1816 namespace { 1817 class TensorLiteralParser { 1818 public: 1819 TensorLiteralParser(Parser &p) : p(p) {} 1820 1821 ParseResult parse() { 1822 if (p.getToken().is(Token::l_square)) 1823 return parseList(shape); 1824 return parseElement(); 1825 } 1826 1827 /// Build a dense attribute instance with the parsed elements and the given 1828 /// shaped type. 1829 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1830 1831 ArrayRef<int64_t> getShape() const { return shape; } 1832 1833 private: 1834 enum class ElementKind { Boolean, Integer, Float }; 1835 1836 /// Return a string to represent the given element kind. 1837 const char *getElementKindStr(ElementKind kind) { 1838 switch (kind) { 1839 case ElementKind::Boolean: 1840 return "'boolean'"; 1841 case ElementKind::Integer: 1842 return "'integer'"; 1843 case ElementKind::Float: 1844 return "'float'"; 1845 } 1846 llvm_unreachable("unknown element kind"); 1847 } 1848 1849 /// Build a Dense Integer attribute for the given type. 1850 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1851 IntegerType eltTy); 1852 1853 /// Build a Dense Float attribute for the given type. 1854 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1855 FloatType eltTy); 1856 1857 /// Parse a single element, returning failure if it isn't a valid element 1858 /// literal. For example: 1859 /// parseElement(1) -> Success, 1 1860 /// parseElement([1]) -> Failure 1861 ParseResult parseElement(); 1862 1863 /// Parse a list of either lists or elements, returning the dimensions of the 1864 /// parsed sub-tensors in dims. For example: 1865 /// parseList([1, 2, 3]) -> Success, [3] 1866 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1867 /// parseList([[1, 2], 3]) -> Failure 1868 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1869 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1870 1871 Parser &p; 1872 1873 /// The shape inferred from the parsed elements. 1874 SmallVector<int64_t, 4> shape; 1875 1876 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1877 std::vector<std::pair<bool, Token>> storage; 1878 1879 /// A flag that indicates the type of elements that have been parsed. 1880 Optional<ElementKind> knownEltKind; 1881 }; 1882 } // namespace 1883 1884 /// Build a dense attribute instance with the parsed elements and the given 1885 /// shaped type. 1886 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1887 ShapedType type) { 1888 // Check that the parsed storage size has the same number of elements to the 1889 // type, or is a known splat. 1890 if (!shape.empty() && getShape() != type.getShape()) { 1891 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1892 << "]) does not match type ([" << type.getShape() << "])"; 1893 return nullptr; 1894 } 1895 1896 // If the type is an integer, build a set of APInt values from the storage 1897 // with the correct bitwidth. 1898 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1899 return getIntAttr(loc, type, intTy); 1900 1901 // Otherwise, this must be a floating point type. 1902 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1903 if (!floatTy) { 1904 p.emitError(loc) << "expected floating-point or integer element type, got " 1905 << type.getElementType(); 1906 return nullptr; 1907 } 1908 return getFloatAttr(loc, type, floatTy); 1909 } 1910 1911 /// Build a Dense Integer attribute for the given type. 1912 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1913 ShapedType type, 1914 IntegerType eltTy) { 1915 std::vector<APInt> intElements; 1916 intElements.reserve(storage.size()); 1917 for (const auto &signAndToken : storage) { 1918 bool isNegative = signAndToken.first; 1919 const Token &token = signAndToken.second; 1920 1921 // Check to see if floating point values were parsed. 1922 if (token.is(Token::floatliteral)) { 1923 p.emitError() << "expected integer elements, but parsed floating-point"; 1924 return nullptr; 1925 } 1926 1927 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1928 "unexpected token type"); 1929 if (token.isAny(Token::kw_true, Token::kw_false)) { 1930 if (!eltTy.isInteger(1)) 1931 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1932 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1933 /*isSigned=*/false); 1934 intElements.push_back(apInt); 1935 continue; 1936 } 1937 1938 // Create APInt values for each element with the correct bitwidth. 1939 auto val = token.getUInt64IntegerValue(); 1940 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1941 : (int64_t)val.getValue() < 0)) { 1942 p.emitError(token.getLoc(), 1943 "integer constant out of range for attribute"); 1944 return nullptr; 1945 } 1946 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1947 if (apInt != val.getValue()) 1948 return (p.emitError("integer constant out of range for type"), nullptr); 1949 intElements.push_back(isNegative ? -apInt : apInt); 1950 } 1951 1952 return DenseElementsAttr::get(type, intElements); 1953 } 1954 1955 /// Build a Dense Float attribute for the given type. 1956 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1957 ShapedType type, 1958 FloatType eltTy) { 1959 std::vector<Attribute> floatValues; 1960 floatValues.reserve(storage.size()); 1961 for (const auto &signAndToken : storage) { 1962 bool isNegative = signAndToken.first; 1963 const Token &token = signAndToken.second; 1964 1965 // Handle hexadecimal float literals. 1966 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1967 if (isNegative) { 1968 p.emitError(token.getLoc()) 1969 << "hexadecimal float literal should not have a leading minus"; 1970 return nullptr; 1971 } 1972 auto val = token.getUInt64IntegerValue(); 1973 if (!val.hasValue()) { 1974 p.emitError("hexadecimal float constant out of range for attribute"); 1975 return nullptr; 1976 } 1977 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1978 if (!attr) 1979 return nullptr; 1980 floatValues.push_back(attr); 1981 continue; 1982 } 1983 1984 // Check to see if any decimal integers or booleans were parsed. 1985 if (!token.is(Token::floatliteral)) { 1986 p.emitError() << "expected floating-point elements, but parsed integer"; 1987 return nullptr; 1988 } 1989 1990 // Build the float values from tokens. 1991 auto val = token.getFloatingPointValue(); 1992 if (!val.hasValue()) { 1993 p.emitError("floating point value too large for attribute"); 1994 return nullptr; 1995 } 1996 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 1997 } 1998 1999 return DenseElementsAttr::get(type, floatValues); 2000 } 2001 2002 ParseResult TensorLiteralParser::parseElement() { 2003 switch (p.getToken().getKind()) { 2004 // Parse a boolean element. 2005 case Token::kw_true: 2006 case Token::kw_false: 2007 case Token::floatliteral: 2008 case Token::integer: 2009 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2010 p.consumeToken(); 2011 break; 2012 2013 // Parse a signed integer or a negative floating-point element. 2014 case Token::minus: 2015 p.consumeToken(Token::minus); 2016 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2017 return p.emitError("expected integer or floating point literal"); 2018 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2019 p.consumeToken(); 2020 break; 2021 2022 default: 2023 return p.emitError("expected element literal of primitive type"); 2024 } 2025 2026 return success(); 2027 } 2028 2029 /// Parse a list of either lists or elements, returning the dimensions of the 2030 /// parsed sub-tensors in dims. For example: 2031 /// parseList([1, 2, 3]) -> Success, [3] 2032 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2033 /// parseList([[1, 2], 3]) -> Failure 2034 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2035 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2036 p.consumeToken(Token::l_square); 2037 2038 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2039 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2040 if (prevDims == newDims) 2041 return success(); 2042 return p.emitError("tensor literal is invalid; ranks are not consistent " 2043 "between elements"); 2044 }; 2045 2046 bool first = true; 2047 SmallVector<int64_t, 4> newDims; 2048 unsigned size = 0; 2049 auto parseCommaSeparatedList = [&]() -> ParseResult { 2050 SmallVector<int64_t, 4> thisDims; 2051 if (p.getToken().getKind() == Token::l_square) { 2052 if (parseList(thisDims)) 2053 return failure(); 2054 } else if (parseElement()) { 2055 return failure(); 2056 } 2057 ++size; 2058 if (!first) 2059 return checkDims(newDims, thisDims); 2060 newDims = thisDims; 2061 first = false; 2062 return success(); 2063 }; 2064 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2065 return failure(); 2066 2067 // Return the sublists' dimensions with 'size' prepended. 2068 dims.clear(); 2069 dims.push_back(size); 2070 dims.append(newDims.begin(), newDims.end()); 2071 return success(); 2072 } 2073 2074 /// Parse a dense elements attribute. 2075 Attribute Parser::parseDenseElementsAttr() { 2076 consumeToken(Token::kw_dense); 2077 if (parseToken(Token::less, "expected '<' after 'dense'")) 2078 return nullptr; 2079 2080 // Parse the literal data. 2081 TensorLiteralParser literalParser(*this); 2082 if (literalParser.parse()) 2083 return nullptr; 2084 2085 if (parseToken(Token::greater, "expected '>'") || 2086 parseToken(Token::colon, "expected ':'")) 2087 return nullptr; 2088 2089 auto typeLoc = getToken().getLoc(); 2090 auto type = parseElementsLiteralType(); 2091 if (!type) 2092 return nullptr; 2093 return literalParser.getAttr(typeLoc, type); 2094 } 2095 2096 /// Shaped type for elements attribute. 2097 /// 2098 /// elements-literal-type ::= vector-type | ranked-tensor-type 2099 /// 2100 /// This method also checks the type has static shape. 2101 ShapedType Parser::parseElementsLiteralType() { 2102 auto type = parseType(); 2103 if (!type) 2104 return nullptr; 2105 2106 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2107 emitError("elements literal must be a ranked tensor or vector type"); 2108 return nullptr; 2109 } 2110 2111 auto sType = type.cast<ShapedType>(); 2112 if (!sType.hasStaticShape()) 2113 return (emitError("elements literal type must have static shape"), nullptr); 2114 2115 return sType; 2116 } 2117 2118 /// Parse a sparse elements attribute. 2119 Attribute Parser::parseSparseElementsAttr() { 2120 consumeToken(Token::kw_sparse); 2121 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2122 return nullptr; 2123 2124 /// Parse indices 2125 auto indicesLoc = getToken().getLoc(); 2126 TensorLiteralParser indiceParser(*this); 2127 if (indiceParser.parse()) 2128 return nullptr; 2129 2130 if (parseToken(Token::comma, "expected ','")) 2131 return nullptr; 2132 2133 /// Parse values. 2134 auto valuesLoc = getToken().getLoc(); 2135 TensorLiteralParser valuesParser(*this); 2136 if (valuesParser.parse()) 2137 return nullptr; 2138 2139 if (parseToken(Token::greater, "expected '>'") || 2140 parseToken(Token::colon, "expected ':'")) 2141 return nullptr; 2142 2143 auto type = parseElementsLiteralType(); 2144 if (!type) 2145 return nullptr; 2146 2147 // If the indices are a splat, i.e. the literal parser parsed an element and 2148 // not a list, we set the shape explicitly. The indices are represented by a 2149 // 2-dimensional shape where the second dimension is the rank of the type. 2150 // Given that the parsed indices is a splat, we know that we only have one 2151 // indice and thus one for the first dimension. 2152 auto indiceEltType = builder.getIntegerType(64); 2153 ShapedType indicesType; 2154 if (indiceParser.getShape().empty()) { 2155 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2156 } else { 2157 // Otherwise, set the shape to the one parsed by the literal parser. 2158 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2159 } 2160 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2161 2162 // If the values are a splat, set the shape explicitly based on the number of 2163 // indices. The number of indices is encoded in the first dimension of the 2164 // indice shape type. 2165 auto valuesEltType = type.getElementType(); 2166 ShapedType valuesType = 2167 valuesParser.getShape().empty() 2168 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2169 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2170 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2171 2172 /// Sanity check. 2173 if (valuesType.getRank() != 1) 2174 return (emitError("expected 1-d tensor for values"), nullptr); 2175 2176 auto sameShape = (indicesType.getRank() == 1) || 2177 (type.getRank() == indicesType.getDimSize(1)); 2178 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2179 if (!sameShape || !sameElementNum) { 2180 emitError() << "expected shape ([" << type.getShape() 2181 << "]); inferred shape of indices literal ([" 2182 << indicesType.getShape() 2183 << "]); inferred shape of values literal ([" 2184 << valuesType.getShape() << "])"; 2185 return nullptr; 2186 } 2187 2188 // Build the sparse elements attribute by the indices and values. 2189 return SparseElementsAttr::get(type, indices, values); 2190 } 2191 2192 //===----------------------------------------------------------------------===// 2193 // Location parsing. 2194 //===----------------------------------------------------------------------===// 2195 2196 /// Parse a location. 2197 /// 2198 /// location ::= `loc` inline-location 2199 /// inline-location ::= '(' location-inst ')' 2200 /// 2201 ParseResult Parser::parseLocation(LocationAttr &loc) { 2202 // Check for 'loc' identifier. 2203 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2204 return emitError(); 2205 2206 // Parse the inline-location. 2207 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2208 parseLocationInstance(loc) || 2209 parseToken(Token::r_paren, "expected ')' in inline location")) 2210 return failure(); 2211 return success(); 2212 } 2213 2214 /// Specific location instances. 2215 /// 2216 /// location-inst ::= filelinecol-location | 2217 /// name-location | 2218 /// callsite-location | 2219 /// fused-location | 2220 /// unknown-location 2221 /// filelinecol-location ::= string-literal ':' integer-literal 2222 /// ':' integer-literal 2223 /// name-location ::= string-literal 2224 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2225 /// fused-location ::= fused ('<' attribute-value '>')? 2226 /// '[' location-inst (location-inst ',')* ']' 2227 /// unknown-location ::= 'unknown' 2228 /// 2229 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2230 consumeToken(Token::bare_identifier); 2231 2232 // Parse the '('. 2233 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2234 return failure(); 2235 2236 // Parse the callee location. 2237 LocationAttr calleeLoc; 2238 if (parseLocationInstance(calleeLoc)) 2239 return failure(); 2240 2241 // Parse the 'at'. 2242 if (getToken().isNot(Token::bare_identifier) || 2243 getToken().getSpelling() != "at") 2244 return emitError("expected 'at' in callsite location"); 2245 consumeToken(Token::bare_identifier); 2246 2247 // Parse the caller location. 2248 LocationAttr callerLoc; 2249 if (parseLocationInstance(callerLoc)) 2250 return failure(); 2251 2252 // Parse the ')'. 2253 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2254 return failure(); 2255 2256 // Return the callsite location. 2257 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2258 return success(); 2259 } 2260 2261 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2262 consumeToken(Token::bare_identifier); 2263 2264 // Try to parse the optional metadata. 2265 Attribute metadata; 2266 if (consumeIf(Token::less)) { 2267 metadata = parseAttribute(); 2268 if (!metadata) 2269 return emitError("expected valid attribute metadata"); 2270 // Parse the '>' token. 2271 if (parseToken(Token::greater, 2272 "expected '>' after fused location metadata")) 2273 return failure(); 2274 } 2275 2276 SmallVector<Location, 4> locations; 2277 auto parseElt = [&] { 2278 LocationAttr newLoc; 2279 if (parseLocationInstance(newLoc)) 2280 return failure(); 2281 locations.push_back(newLoc); 2282 return success(); 2283 }; 2284 2285 if (parseToken(Token::l_square, "expected '[' in fused location") || 2286 parseCommaSeparatedList(parseElt) || 2287 parseToken(Token::r_square, "expected ']' in fused location")) 2288 return failure(); 2289 2290 // Return the fused location. 2291 loc = FusedLoc::get(locations, metadata, getContext()); 2292 return success(); 2293 } 2294 2295 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2296 auto *ctx = getContext(); 2297 auto str = getToken().getStringValue(); 2298 consumeToken(Token::string); 2299 2300 // If the next token is ':' this is a filelinecol location. 2301 if (consumeIf(Token::colon)) { 2302 // Parse the line number. 2303 if (getToken().isNot(Token::integer)) 2304 return emitError("expected integer line number in FileLineColLoc"); 2305 auto line = getToken().getUnsignedIntegerValue(); 2306 if (!line.hasValue()) 2307 return emitError("expected integer line number in FileLineColLoc"); 2308 consumeToken(Token::integer); 2309 2310 // Parse the ':'. 2311 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2312 return failure(); 2313 2314 // Parse the column number. 2315 if (getToken().isNot(Token::integer)) 2316 return emitError("expected integer column number in FileLineColLoc"); 2317 auto column = getToken().getUnsignedIntegerValue(); 2318 if (!column.hasValue()) 2319 return emitError("expected integer column number in FileLineColLoc"); 2320 consumeToken(Token::integer); 2321 2322 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2323 return success(); 2324 } 2325 2326 // Otherwise, this is a NameLoc. 2327 2328 // Check for a child location. 2329 if (consumeIf(Token::l_paren)) { 2330 auto childSourceLoc = getToken().getLoc(); 2331 2332 // Parse the child location. 2333 LocationAttr childLoc; 2334 if (parseLocationInstance(childLoc)) 2335 return failure(); 2336 2337 // The child must not be another NameLoc. 2338 if (childLoc.isa<NameLoc>()) 2339 return emitError(childSourceLoc, 2340 "child of NameLoc cannot be another NameLoc"); 2341 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2342 2343 // Parse the closing ')'. 2344 if (parseToken(Token::r_paren, 2345 "expected ')' after child location of NameLoc")) 2346 return failure(); 2347 } else { 2348 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2349 } 2350 2351 return success(); 2352 } 2353 2354 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2355 // Handle either name or filelinecol locations. 2356 if (getToken().is(Token::string)) 2357 return parseNameOrFileLineColLocation(loc); 2358 2359 // Bare tokens required for other cases. 2360 if (!getToken().is(Token::bare_identifier)) 2361 return emitError("expected location instance"); 2362 2363 // Check for the 'callsite' signifying a callsite location. 2364 if (getToken().getSpelling() == "callsite") 2365 return parseCallSiteLocation(loc); 2366 2367 // If the token is 'fused', then this is a fused location. 2368 if (getToken().getSpelling() == "fused") 2369 return parseFusedLocation(loc); 2370 2371 // Check for a 'unknown' for an unknown location. 2372 if (getToken().getSpelling() == "unknown") { 2373 consumeToken(Token::bare_identifier); 2374 loc = UnknownLoc::get(getContext()); 2375 return success(); 2376 } 2377 2378 return emitError("expected location instance"); 2379 } 2380 2381 //===----------------------------------------------------------------------===// 2382 // Affine parsing. 2383 //===----------------------------------------------------------------------===// 2384 2385 /// Lower precedence ops (all at the same precedence level). LNoOp is false in 2386 /// the boolean sense. 2387 enum AffineLowPrecOp { 2388 /// Null value. 2389 LNoOp, 2390 Add, 2391 Sub 2392 }; 2393 2394 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2395 /// in the boolean sense. 2396 enum AffineHighPrecOp { 2397 /// Null value. 2398 HNoOp, 2399 Mul, 2400 FloorDiv, 2401 CeilDiv, 2402 Mod 2403 }; 2404 2405 namespace { 2406 /// This is a specialized parser for affine structures (affine maps, affine 2407 /// expressions, and integer sets), maintaining the state transient to their 2408 /// bodies. 2409 class AffineParser : public Parser { 2410 public: 2411 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2412 function_ref<ParseResult(bool)> parseElement = nullptr) 2413 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2414 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2415 2416 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2417 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2418 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2419 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 2420 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2421 unsigned &numDims); 2422 2423 private: 2424 // Binary affine op parsing. 2425 AffineLowPrecOp consumeIfLowPrecOp(); 2426 AffineHighPrecOp consumeIfHighPrecOp(); 2427 2428 // Identifier lists for polyhedral structures. 2429 ParseResult parseDimIdList(unsigned &numDims); 2430 ParseResult parseSymbolIdList(unsigned &numSymbols); 2431 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2432 unsigned &numSymbols); 2433 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2434 2435 AffineExpr parseAffineExpr(); 2436 AffineExpr parseParentheticalExpr(); 2437 AffineExpr parseNegateExpression(AffineExpr lhs); 2438 AffineExpr parseIntegerExpr(); 2439 AffineExpr parseBareIdExpr(); 2440 AffineExpr parseSSAIdExpr(bool isSymbol); 2441 AffineExpr parseSymbolSSAIdExpr(); 2442 2443 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2444 AffineExpr rhs, SMLoc opLoc); 2445 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2446 AffineExpr rhs); 2447 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2448 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2449 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2450 SMLoc llhsOpLoc); 2451 AffineExpr parseAffineConstraint(bool *isEq); 2452 2453 private: 2454 bool allowParsingSSAIds; 2455 function_ref<ParseResult(bool)> parseElement; 2456 unsigned numDimOperands; 2457 unsigned numSymbolOperands; 2458 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2459 }; 2460 } // end anonymous namespace 2461 2462 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2463 /// opLoc is the location of the op token to be used to report errors 2464 /// for non-conforming expressions. 2465 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2466 AffineExpr lhs, AffineExpr rhs, 2467 SMLoc opLoc) { 2468 // TODO: make the error location info accurate. 2469 switch (op) { 2470 case Mul: 2471 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2472 emitError(opLoc, "non-affine expression: at least one of the multiply " 2473 "operands has to be either a constant or symbolic"); 2474 return nullptr; 2475 } 2476 return lhs * rhs; 2477 case FloorDiv: 2478 if (!rhs.isSymbolicOrConstant()) { 2479 emitError(opLoc, "non-affine expression: right operand of floordiv " 2480 "has to be either a constant or symbolic"); 2481 return nullptr; 2482 } 2483 return lhs.floorDiv(rhs); 2484 case CeilDiv: 2485 if (!rhs.isSymbolicOrConstant()) { 2486 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2487 "has to be either a constant or symbolic"); 2488 return nullptr; 2489 } 2490 return lhs.ceilDiv(rhs); 2491 case Mod: 2492 if (!rhs.isSymbolicOrConstant()) { 2493 emitError(opLoc, "non-affine expression: right operand of mod " 2494 "has to be either a constant or symbolic"); 2495 return nullptr; 2496 } 2497 return lhs % rhs; 2498 case HNoOp: 2499 llvm_unreachable("can't create affine expression for null high prec op"); 2500 return nullptr; 2501 } 2502 llvm_unreachable("Unknown AffineHighPrecOp"); 2503 } 2504 2505 /// Create an affine binary low precedence op expression (add, sub). 2506 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2507 AffineExpr lhs, AffineExpr rhs) { 2508 switch (op) { 2509 case AffineLowPrecOp::Add: 2510 return lhs + rhs; 2511 case AffineLowPrecOp::Sub: 2512 return lhs - rhs; 2513 case AffineLowPrecOp::LNoOp: 2514 llvm_unreachable("can't create affine expression for null low prec op"); 2515 return nullptr; 2516 } 2517 llvm_unreachable("Unknown AffineLowPrecOp"); 2518 } 2519 2520 /// Consume this token if it is a lower precedence affine op (there are only 2521 /// two precedence levels). 2522 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2523 switch (getToken().getKind()) { 2524 case Token::plus: 2525 consumeToken(Token::plus); 2526 return AffineLowPrecOp::Add; 2527 case Token::minus: 2528 consumeToken(Token::minus); 2529 return AffineLowPrecOp::Sub; 2530 default: 2531 return AffineLowPrecOp::LNoOp; 2532 } 2533 } 2534 2535 /// Consume this token if it is a higher precedence affine op (there are only 2536 /// two precedence levels) 2537 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2538 switch (getToken().getKind()) { 2539 case Token::star: 2540 consumeToken(Token::star); 2541 return Mul; 2542 case Token::kw_floordiv: 2543 consumeToken(Token::kw_floordiv); 2544 return FloorDiv; 2545 case Token::kw_ceildiv: 2546 consumeToken(Token::kw_ceildiv); 2547 return CeilDiv; 2548 case Token::kw_mod: 2549 consumeToken(Token::kw_mod); 2550 return Mod; 2551 default: 2552 return HNoOp; 2553 } 2554 } 2555 2556 /// Parse a high precedence op expression list: mul, div, and mod are high 2557 /// precedence binary ops, i.e., parse a 2558 /// expr_1 op_1 expr_2 op_2 ... expr_n 2559 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2560 /// All affine binary ops are left associative. 2561 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2562 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2563 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2564 /// report an error for non-conforming expressions. 2565 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2566 AffineHighPrecOp llhsOp, 2567 SMLoc llhsOpLoc) { 2568 AffineExpr lhs = parseAffineOperandExpr(llhs); 2569 if (!lhs) 2570 return nullptr; 2571 2572 // Found an LHS. Parse the remaining expression. 2573 auto opLoc = getToken().getLoc(); 2574 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2575 if (llhs) { 2576 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2577 if (!expr) 2578 return nullptr; 2579 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2580 } 2581 // No LLHS, get RHS 2582 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2583 } 2584 2585 // This is the last operand in this expression. 2586 if (llhs) 2587 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2588 2589 // No llhs, 'lhs' itself is the expression. 2590 return lhs; 2591 } 2592 2593 /// Parse an affine expression inside parentheses. 2594 /// 2595 /// affine-expr ::= `(` affine-expr `)` 2596 AffineExpr AffineParser::parseParentheticalExpr() { 2597 if (parseToken(Token::l_paren, "expected '('")) 2598 return nullptr; 2599 if (getToken().is(Token::r_paren)) 2600 return (emitError("no expression inside parentheses"), nullptr); 2601 2602 auto expr = parseAffineExpr(); 2603 if (!expr) 2604 return nullptr; 2605 if (parseToken(Token::r_paren, "expected ')'")) 2606 return nullptr; 2607 2608 return expr; 2609 } 2610 2611 /// Parse the negation expression. 2612 /// 2613 /// affine-expr ::= `-` affine-expr 2614 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2615 if (parseToken(Token::minus, "expected '-'")) 2616 return nullptr; 2617 2618 AffineExpr operand = parseAffineOperandExpr(lhs); 2619 // Since negation has the highest precedence of all ops (including high 2620 // precedence ops) but lower than parentheses, we are only going to use 2621 // parseAffineOperandExpr instead of parseAffineExpr here. 2622 if (!operand) 2623 // Extra error message although parseAffineOperandExpr would have 2624 // complained. Leads to a better diagnostic. 2625 return (emitError("missing operand of negation"), nullptr); 2626 return (-1) * operand; 2627 } 2628 2629 /// Parse a bare id that may appear in an affine expression. 2630 /// 2631 /// affine-expr ::= bare-id 2632 AffineExpr AffineParser::parseBareIdExpr() { 2633 if (getToken().isNot(Token::bare_identifier)) 2634 return (emitError("expected bare identifier"), nullptr); 2635 2636 StringRef sRef = getTokenSpelling(); 2637 for (auto entry : dimsAndSymbols) { 2638 if (entry.first == sRef) { 2639 consumeToken(Token::bare_identifier); 2640 return entry.second; 2641 } 2642 } 2643 2644 return (emitError("use of undeclared identifier"), nullptr); 2645 } 2646 2647 /// Parse an SSA id which may appear in an affine expression. 2648 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2649 if (!allowParsingSSAIds) 2650 return (emitError("unexpected ssa identifier"), nullptr); 2651 if (getToken().isNot(Token::percent_identifier)) 2652 return (emitError("expected ssa identifier"), nullptr); 2653 auto name = getTokenSpelling(); 2654 // Check if we already parsed this SSA id. 2655 for (auto entry : dimsAndSymbols) { 2656 if (entry.first == name) { 2657 consumeToken(Token::percent_identifier); 2658 return entry.second; 2659 } 2660 } 2661 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2662 if (parseElement(isSymbol)) 2663 return (emitError("failed to parse ssa identifier"), nullptr); 2664 auto idExpr = isSymbol 2665 ? getAffineSymbolExpr(numSymbolOperands++, getContext()) 2666 : getAffineDimExpr(numDimOperands++, getContext()); 2667 dimsAndSymbols.push_back({name, idExpr}); 2668 return idExpr; 2669 } 2670 2671 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2672 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2673 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2674 return nullptr; 2675 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2676 if (!symbolExpr) 2677 return nullptr; 2678 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2679 return nullptr; 2680 return symbolExpr; 2681 } 2682 2683 /// Parse a positive integral constant appearing in an affine expression. 2684 /// 2685 /// affine-expr ::= integer-literal 2686 AffineExpr AffineParser::parseIntegerExpr() { 2687 auto val = getToken().getUInt64IntegerValue(); 2688 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2689 return (emitError("constant too large for index"), nullptr); 2690 2691 consumeToken(Token::integer); 2692 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2693 } 2694 2695 /// Parses an expression that can be a valid operand of an affine expression. 2696 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2697 /// operator, the rhs of which is being parsed. This is used to determine 2698 /// whether an error should be emitted for a missing right operand. 2699 // Eg: for an expression without parentheses (like i + j + k + l), each 2700 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2701 // operand expression, it's an op expression and will be parsed via 2702 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2703 // -l are valid operands that will be parsed by this function. 2704 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2705 switch (getToken().getKind()) { 2706 case Token::bare_identifier: 2707 return parseBareIdExpr(); 2708 case Token::kw_symbol: 2709 return parseSymbolSSAIdExpr(); 2710 case Token::percent_identifier: 2711 return parseSSAIdExpr(/*isSymbol=*/false); 2712 case Token::integer: 2713 return parseIntegerExpr(); 2714 case Token::l_paren: 2715 return parseParentheticalExpr(); 2716 case Token::minus: 2717 return parseNegateExpression(lhs); 2718 case Token::kw_ceildiv: 2719 case Token::kw_floordiv: 2720 case Token::kw_mod: 2721 case Token::plus: 2722 case Token::star: 2723 if (lhs) 2724 emitError("missing right operand of binary operator"); 2725 else 2726 emitError("missing left operand of binary operator"); 2727 return nullptr; 2728 default: 2729 if (lhs) 2730 emitError("missing right operand of binary operator"); 2731 else 2732 emitError("expected affine expression"); 2733 return nullptr; 2734 } 2735 } 2736 2737 /// Parse affine expressions that are bare-id's, integer constants, 2738 /// parenthetical affine expressions, and affine op expressions that are a 2739 /// composition of those. 2740 /// 2741 /// All binary op's associate from left to right. 2742 /// 2743 /// {add, sub} have lower precedence than {mul, div, and mod}. 2744 /// 2745 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2746 /// ceildiv, and mod are at the same higher precedence level. Negation has 2747 /// higher precedence than any binary op. 2748 /// 2749 /// llhs: the affine expression appearing on the left of the one being parsed. 2750 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2751 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2752 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2753 /// associativity. 2754 /// 2755 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2756 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2757 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2758 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2759 AffineLowPrecOp llhsOp) { 2760 AffineExpr lhs; 2761 if (!(lhs = parseAffineOperandExpr(llhs))) 2762 return nullptr; 2763 2764 // Found an LHS. Deal with the ops. 2765 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2766 if (llhs) { 2767 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2768 return parseAffineLowPrecOpExpr(sum, lOp); 2769 } 2770 // No LLHS, get RHS and form the expression. 2771 return parseAffineLowPrecOpExpr(lhs, lOp); 2772 } 2773 auto opLoc = getToken().getLoc(); 2774 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2775 // We have a higher precedence op here. Get the rhs operand for the llhs 2776 // through parseAffineHighPrecOpExpr. 2777 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2778 if (!highRes) 2779 return nullptr; 2780 2781 // If llhs is null, the product forms the first operand of the yet to be 2782 // found expression. If non-null, the op to associate with llhs is llhsOp. 2783 AffineExpr expr = 2784 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2785 2786 // Recurse for subsequent low prec op's after the affine high prec op 2787 // expression. 2788 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2789 return parseAffineLowPrecOpExpr(expr, nextOp); 2790 return expr; 2791 } 2792 // Last operand in the expression list. 2793 if (llhs) 2794 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2795 // No llhs, 'lhs' itself is the expression. 2796 return lhs; 2797 } 2798 2799 /// Parse an affine expression. 2800 /// affine-expr ::= `(` affine-expr `)` 2801 /// | `-` affine-expr 2802 /// | affine-expr `+` affine-expr 2803 /// | affine-expr `-` affine-expr 2804 /// | affine-expr `*` affine-expr 2805 /// | affine-expr `floordiv` affine-expr 2806 /// | affine-expr `ceildiv` affine-expr 2807 /// | affine-expr `mod` affine-expr 2808 /// | bare-id 2809 /// | integer-literal 2810 /// 2811 /// Additional conditions are checked depending on the production. For eg., 2812 /// one of the operands for `*` has to be either constant/symbolic; the second 2813 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2814 AffineExpr AffineParser::parseAffineExpr() { 2815 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2816 } 2817 2818 /// Parse a dim or symbol from the lists appearing before the actual 2819 /// expressions of the affine map. Update our state to store the 2820 /// dimensional/symbolic identifier. 2821 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2822 if (getToken().isNot(Token::bare_identifier)) 2823 return emitError("expected bare identifier"); 2824 2825 auto name = getTokenSpelling(); 2826 for (auto entry : dimsAndSymbols) { 2827 if (entry.first == name) 2828 return emitError("redefinition of identifier '" + name + "'"); 2829 } 2830 consumeToken(Token::bare_identifier); 2831 2832 dimsAndSymbols.push_back({name, idExpr}); 2833 return success(); 2834 } 2835 2836 /// Parse the list of dimensional identifiers to an affine map. 2837 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2838 if (parseToken(Token::l_paren, 2839 "expected '(' at start of dimensional identifiers list")) { 2840 return failure(); 2841 } 2842 2843 auto parseElt = [&]() -> ParseResult { 2844 auto dimension = getAffineDimExpr(numDims++, getContext()); 2845 return parseIdentifierDefinition(dimension); 2846 }; 2847 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2848 } 2849 2850 /// Parse the list of symbolic identifiers to an affine map. 2851 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2852 consumeToken(Token::l_square); 2853 auto parseElt = [&]() -> ParseResult { 2854 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2855 return parseIdentifierDefinition(symbol); 2856 }; 2857 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2858 } 2859 2860 /// Parse the list of symbolic identifiers to an affine map. 2861 ParseResult 2862 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2863 unsigned &numSymbols) { 2864 if (parseDimIdList(numDims)) { 2865 return failure(); 2866 } 2867 if (!getToken().is(Token::l_square)) { 2868 numSymbols = 0; 2869 return success(); 2870 } 2871 return parseSymbolIdList(numSymbols); 2872 } 2873 2874 /// Parses an ambiguous affine map or integer set definition inline. 2875 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2876 IntegerSet &set) { 2877 unsigned numDims = 0, numSymbols = 0; 2878 2879 // List of dimensional and optional symbol identifiers. 2880 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2881 return failure(); 2882 } 2883 2884 // This is needed for parsing attributes as we wouldn't know whether we would 2885 // be parsing an integer set attribute or an affine map attribute. 2886 bool isArrow = getToken().is(Token::arrow); 2887 bool isColon = getToken().is(Token::colon); 2888 if (!isArrow && !isColon) { 2889 return emitError("expected '->' or ':'"); 2890 } else if (isArrow) { 2891 parseToken(Token::arrow, "expected '->' or '['"); 2892 map = parseAffineMapRange(numDims, numSymbols); 2893 return map ? success() : failure(); 2894 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2895 return failure(); 2896 } 2897 2898 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2899 return success(); 2900 2901 return failure(); 2902 } 2903 2904 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2905 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2906 if (parseToken(Token::l_square, "expected '['")) 2907 return failure(); 2908 2909 SmallVector<AffineExpr, 4> exprs; 2910 auto parseElt = [&]() -> ParseResult { 2911 auto elt = parseAffineExpr(); 2912 exprs.push_back(elt); 2913 return elt ? success() : failure(); 2914 }; 2915 2916 // Parse a multi-dimensional affine expression (a comma-separated list of 2917 // 1-d affine expressions); the list cannot be empty. Grammar: 2918 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2919 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2920 /*allowEmptyList=*/true)) 2921 return failure(); 2922 // Parsed a valid affine map. 2923 if (exprs.empty()) 2924 map = AffineMap::get(getContext()); 2925 else 2926 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 2927 exprs); 2928 return success(); 2929 } 2930 2931 /// Parse the range and sizes affine map definition inline. 2932 /// 2933 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2934 /// 2935 /// multi-dim-affine-expr ::= `(` `)` 2936 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 2937 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2938 unsigned numSymbols) { 2939 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2940 2941 SmallVector<AffineExpr, 4> exprs; 2942 auto parseElt = [&]() -> ParseResult { 2943 auto elt = parseAffineExpr(); 2944 ParseResult res = elt ? success() : failure(); 2945 exprs.push_back(elt); 2946 return res; 2947 }; 2948 2949 // Parse a multi-dimensional affine expression (a comma-separated list of 2950 // 1-d affine expressions); the list cannot be empty. Grammar: 2951 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2952 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2953 return AffineMap(); 2954 2955 if (exprs.empty()) 2956 return AffineMap::get(getContext()); 2957 2958 // Parsed a valid affine map. 2959 return AffineMap::get(numDims, numSymbols, exprs); 2960 } 2961 2962 /// Parse an affine constraint. 2963 /// affine-constraint ::= affine-expr `>=` `0` 2964 /// | affine-expr `==` `0` 2965 /// 2966 /// isEq is set to true if the parsed constraint is an equality, false if it 2967 /// is an inequality (greater than or equal). 2968 /// 2969 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2970 AffineExpr expr = parseAffineExpr(); 2971 if (!expr) 2972 return nullptr; 2973 2974 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2975 getToken().is(Token::integer)) { 2976 auto dim = getToken().getUnsignedIntegerValue(); 2977 if (dim.hasValue() && dim.getValue() == 0) { 2978 consumeToken(Token::integer); 2979 *isEq = false; 2980 return expr; 2981 } 2982 return (emitError("expected '0' after '>='"), nullptr); 2983 } 2984 2985 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 2986 getToken().is(Token::integer)) { 2987 auto dim = getToken().getUnsignedIntegerValue(); 2988 if (dim.hasValue() && dim.getValue() == 0) { 2989 consumeToken(Token::integer); 2990 *isEq = true; 2991 return expr; 2992 } 2993 return (emitError("expected '0' after '=='"), nullptr); 2994 } 2995 2996 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 2997 nullptr); 2998 } 2999 3000 /// Parse the constraints that are part of an integer set definition. 3001 /// integer-set-inline 3002 /// ::= dim-and-symbol-id-lists `:` 3003 /// '(' affine-constraint-conjunction? ')' 3004 /// affine-constraint-conjunction ::= affine-constraint (`,` 3005 /// affine-constraint)* 3006 /// 3007 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3008 unsigned numSymbols) { 3009 if (parseToken(Token::l_paren, 3010 "expected '(' at start of integer set constraint list")) 3011 return IntegerSet(); 3012 3013 SmallVector<AffineExpr, 4> constraints; 3014 SmallVector<bool, 4> isEqs; 3015 auto parseElt = [&]() -> ParseResult { 3016 bool isEq; 3017 auto elt = parseAffineConstraint(&isEq); 3018 ParseResult res = elt ? success() : failure(); 3019 if (elt) { 3020 constraints.push_back(elt); 3021 isEqs.push_back(isEq); 3022 } 3023 return res; 3024 }; 3025 3026 // Parse a list of affine constraints (comma-separated). 3027 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3028 return IntegerSet(); 3029 3030 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3031 if (constraints.empty()) { 3032 /* 0 == 0 */ 3033 auto zero = getAffineConstantExpr(0, getContext()); 3034 return IntegerSet::get(numDims, numSymbols, zero, true); 3035 } 3036 3037 // Parsed a valid integer set. 3038 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3039 } 3040 3041 /// Parse an ambiguous reference to either and affine map or an integer set. 3042 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3043 IntegerSet &set) { 3044 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3045 } 3046 3047 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3048 /// parse SSA value uses encountered while parsing affine expressions. 3049 ParseResult 3050 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3051 function_ref<ParseResult(bool)> parseElement) { 3052 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3053 .parseAffineMapOfSSAIds(map); 3054 } 3055 3056 //===----------------------------------------------------------------------===// 3057 // OperationParser 3058 //===----------------------------------------------------------------------===// 3059 3060 namespace { 3061 /// This class provides support for parsing operations and regions of 3062 /// operations. 3063 class OperationParser : public Parser { 3064 public: 3065 OperationParser(ParserState &state, ModuleOp moduleOp) 3066 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3067 } 3068 3069 ~OperationParser(); 3070 3071 /// After parsing is finished, this function must be called to see if there 3072 /// are any remaining issues. 3073 ParseResult finalize(); 3074 3075 //===--------------------------------------------------------------------===// 3076 // SSA Value Handling 3077 //===--------------------------------------------------------------------===// 3078 3079 /// This represents a use of an SSA value in the program. The first two 3080 /// entries in the tuple are the name and result number of a reference. The 3081 /// third is the location of the reference, which is used in case this ends 3082 /// up being a use of an undefined value. 3083 struct SSAUseInfo { 3084 StringRef name; // Value name, e.g. %42 or %abc 3085 unsigned number; // Number, specified with #12 3086 SMLoc loc; // Location of first definition or use. 3087 }; 3088 3089 /// Push a new SSA name scope to the parser. 3090 void pushSSANameScope(bool isIsolated); 3091 3092 /// Pop the last SSA name scope from the parser. 3093 ParseResult popSSANameScope(); 3094 3095 /// Register a definition of a value with the symbol table. 3096 ParseResult addDefinition(SSAUseInfo useInfo, Value *value); 3097 3098 /// Parse an optional list of SSA uses into 'results'. 3099 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3100 3101 /// Parse a single SSA use into 'result'. 3102 ParseResult parseSSAUse(SSAUseInfo &result); 3103 3104 /// Given a reference to an SSA value and its type, return a reference. This 3105 /// returns null on failure. 3106 Value *resolveSSAUse(SSAUseInfo useInfo, Type type); 3107 3108 ParseResult parseSSADefOrUseAndType( 3109 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3110 3111 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results); 3112 3113 /// Return the location of the value identified by its name and number if it 3114 /// has been already reference. 3115 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3116 auto &values = isolatedNameScopes.back().values; 3117 if (!values.count(name) || number >= values[name].size()) 3118 return {}; 3119 if (values[name][number].first) 3120 return values[name][number].second; 3121 return {}; 3122 } 3123 3124 //===--------------------------------------------------------------------===// 3125 // Operation Parsing 3126 //===--------------------------------------------------------------------===// 3127 3128 /// Parse an operation instance. 3129 ParseResult parseOperation(); 3130 3131 /// Parse a single operation successor and its operand list. 3132 ParseResult parseSuccessorAndUseList(Block *&dest, 3133 SmallVectorImpl<Value *> &operands); 3134 3135 /// Parse a comma-separated list of operation successors in brackets. 3136 ParseResult 3137 parseSuccessors(SmallVectorImpl<Block *> &destinations, 3138 SmallVectorImpl<SmallVector<Value *, 4>> &operands); 3139 3140 /// Parse an operation instance that is in the generic form. 3141 Operation *parseGenericOperation(); 3142 3143 /// Parse an operation instance that is in the generic form and insert it at 3144 /// the provided insertion point. 3145 Operation *parseGenericOperation(Block *insertBlock, 3146 Block::iterator insertPt); 3147 3148 /// Parse an operation instance that is in the op-defined custom form. 3149 Operation *parseCustomOperation(); 3150 3151 //===--------------------------------------------------------------------===// 3152 // Region Parsing 3153 //===--------------------------------------------------------------------===// 3154 3155 /// Parse a region into 'region' with the provided entry block arguments. 3156 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3157 /// isolated from those above. 3158 ParseResult parseRegion(Region ®ion, 3159 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3160 bool isIsolatedNameScope = false); 3161 3162 /// Parse a region body into 'region'. 3163 ParseResult parseRegionBody(Region ®ion); 3164 3165 //===--------------------------------------------------------------------===// 3166 // Block Parsing 3167 //===--------------------------------------------------------------------===// 3168 3169 /// Parse a new block into 'block'. 3170 ParseResult parseBlock(Block *&block); 3171 3172 /// Parse a list of operations into 'block'. 3173 ParseResult parseBlockBody(Block *block); 3174 3175 /// Parse a (possibly empty) list of block arguments. 3176 ParseResult 3177 parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, 3178 Block *owner); 3179 3180 /// Get the block with the specified name, creating it if it doesn't 3181 /// already exist. The location specified is the point of use, which allows 3182 /// us to diagnose references to blocks that are not defined precisely. 3183 Block *getBlockNamed(StringRef name, SMLoc loc); 3184 3185 /// Define the block with the specified name. Returns the Block* or nullptr in 3186 /// the case of redefinition. 3187 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3188 3189 private: 3190 /// Returns the info for a block at the current scope for the given name. 3191 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3192 return blocksByName.back()[name]; 3193 } 3194 3195 /// Insert a new forward reference to the given block. 3196 void insertForwardRef(Block *block, SMLoc loc) { 3197 forwardRef.back().try_emplace(block, loc); 3198 } 3199 3200 /// Erase any forward reference to the given block. 3201 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3202 3203 /// Record that a definition was added at the current scope. 3204 void recordDefinition(StringRef def); 3205 3206 /// Get the value entry for the given SSA name. 3207 SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name); 3208 3209 /// Create a forward reference placeholder value with the given location and 3210 /// result type. 3211 Value *createForwardRefPlaceholder(SMLoc loc, Type type); 3212 3213 /// Return true if this is a forward reference. 3214 bool isForwardRefPlaceholder(Value *value) { 3215 return forwardRefPlaceholders.count(value); 3216 } 3217 3218 /// This struct represents an isolated SSA name scope. This scope may contain 3219 /// other nested non-isolated scopes. These scopes are used for operations 3220 /// that are known to be isolated to allow for reusing names within their 3221 /// regions, even if those names are used above. 3222 struct IsolatedSSANameScope { 3223 /// Record that a definition was added at the current scope. 3224 void recordDefinition(StringRef def) { 3225 definitionsPerScope.back().insert(def); 3226 } 3227 3228 /// Push a nested name scope. 3229 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3230 3231 /// Pop a nested name scope. 3232 void popSSANameScope() { 3233 for (auto &def : definitionsPerScope.pop_back_val()) 3234 values.erase(def.getKey()); 3235 } 3236 3237 /// This keeps track of all of the SSA values we are tracking for each name 3238 /// scope, indexed by their name. This has one entry per result number. 3239 llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values; 3240 3241 /// This keeps track of all of the values defined by a specific name scope. 3242 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3243 }; 3244 3245 /// A list of isolated name scopes. 3246 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3247 3248 /// This keeps track of the block names as well as the location of the first 3249 /// reference for each nested name scope. This is used to diagnose invalid 3250 /// block references and memorize them. 3251 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3252 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3253 3254 /// These are all of the placeholders we've made along with the location of 3255 /// their first reference, to allow checking for use of undefined values. 3256 DenseMap<Value *, SMLoc> forwardRefPlaceholders; 3257 3258 /// The builder used when creating parsed operation instances. 3259 OpBuilder opBuilder; 3260 3261 /// The top level module operation. 3262 ModuleOp moduleOp; 3263 }; 3264 } // end anonymous namespace 3265 3266 OperationParser::~OperationParser() { 3267 for (auto &fwd : forwardRefPlaceholders) { 3268 // Drop all uses of undefined forward declared reference and destroy 3269 // defining operation. 3270 fwd.first->dropAllUses(); 3271 fwd.first->getDefiningOp()->destroy(); 3272 } 3273 } 3274 3275 /// After parsing is finished, this function must be called to see if there are 3276 /// any remaining issues. 3277 ParseResult OperationParser::finalize() { 3278 // Check for any forward references that are left. If we find any, error 3279 // out. 3280 if (!forwardRefPlaceholders.empty()) { 3281 SmallVector<std::pair<const char *, Value *>, 4> errors; 3282 // Iteration over the map isn't deterministic, so sort by source location. 3283 for (auto entry : forwardRefPlaceholders) 3284 errors.push_back({entry.second.getPointer(), entry.first}); 3285 llvm::array_pod_sort(errors.begin(), errors.end()); 3286 3287 for (auto entry : errors) { 3288 auto loc = SMLoc::getFromPointer(entry.first); 3289 emitError(loc, "use of undeclared SSA value name"); 3290 } 3291 return failure(); 3292 } 3293 3294 return success(); 3295 } 3296 3297 //===----------------------------------------------------------------------===// 3298 // SSA Value Handling 3299 //===----------------------------------------------------------------------===// 3300 3301 void OperationParser::pushSSANameScope(bool isIsolated) { 3302 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3303 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3304 3305 // Push back a new name definition scope. 3306 if (isIsolated) 3307 isolatedNameScopes.push_back({}); 3308 isolatedNameScopes.back().pushSSANameScope(); 3309 } 3310 3311 ParseResult OperationParser::popSSANameScope() { 3312 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3313 3314 // Verify that all referenced blocks were defined. 3315 if (!forwardRefInCurrentScope.empty()) { 3316 SmallVector<std::pair<const char *, Block *>, 4> errors; 3317 // Iteration over the map isn't deterministic, so sort by source location. 3318 for (auto entry : forwardRefInCurrentScope) { 3319 errors.push_back({entry.second.getPointer(), entry.first}); 3320 // Add this block to the top-level region to allow for automatic cleanup. 3321 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3322 } 3323 llvm::array_pod_sort(errors.begin(), errors.end()); 3324 3325 for (auto entry : errors) { 3326 auto loc = SMLoc::getFromPointer(entry.first); 3327 emitError(loc, "reference to an undefined block"); 3328 } 3329 return failure(); 3330 } 3331 3332 // Pop the next nested namescope. If there is only one internal namescope, 3333 // just pop the isolated scope. 3334 auto ¤tNameScope = isolatedNameScopes.back(); 3335 if (currentNameScope.definitionsPerScope.size() == 1) 3336 isolatedNameScopes.pop_back(); 3337 else 3338 currentNameScope.popSSANameScope(); 3339 3340 blocksByName.pop_back(); 3341 return success(); 3342 } 3343 3344 /// Register a definition of a value with the symbol table. 3345 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { 3346 auto &entries = getSSAValueEntry(useInfo.name); 3347 3348 // Make sure there is a slot for this value. 3349 if (entries.size() <= useInfo.number) 3350 entries.resize(useInfo.number + 1); 3351 3352 // If we already have an entry for this, check to see if it was a definition 3353 // or a forward reference. 3354 if (auto *existing = entries[useInfo.number].first) { 3355 if (!isForwardRefPlaceholder(existing)) { 3356 return emitError(useInfo.loc) 3357 .append("redefinition of SSA value '", useInfo.name, "'") 3358 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3359 .append("previously defined here"); 3360 } 3361 3362 // If it was a forward reference, update everything that used it to use 3363 // the actual definition instead, delete the forward ref, and remove it 3364 // from our set of forward references we track. 3365 existing->replaceAllUsesWith(value); 3366 existing->getDefiningOp()->destroy(); 3367 forwardRefPlaceholders.erase(existing); 3368 } 3369 3370 /// Record this definition for the current scope. 3371 entries[useInfo.number] = {value, useInfo.loc}; 3372 recordDefinition(useInfo.name); 3373 return success(); 3374 } 3375 3376 /// Parse a (possibly empty) list of SSA operands. 3377 /// 3378 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3379 /// ssa-use-list-opt ::= ssa-use-list? 3380 /// 3381 ParseResult 3382 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3383 if (getToken().isNot(Token::percent_identifier)) 3384 return success(); 3385 return parseCommaSeparatedList([&]() -> ParseResult { 3386 SSAUseInfo result; 3387 if (parseSSAUse(result)) 3388 return failure(); 3389 results.push_back(result); 3390 return success(); 3391 }); 3392 } 3393 3394 /// Parse a SSA operand for an operation. 3395 /// 3396 /// ssa-use ::= ssa-id 3397 /// 3398 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3399 result.name = getTokenSpelling(); 3400 result.number = 0; 3401 result.loc = getToken().getLoc(); 3402 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3403 return failure(); 3404 3405 // If we have an attribute ID, it is a result number. 3406 if (getToken().is(Token::hash_identifier)) { 3407 if (auto value = getToken().getHashIdentifierNumber()) 3408 result.number = value.getValue(); 3409 else 3410 return emitError("invalid SSA value result number"); 3411 consumeToken(Token::hash_identifier); 3412 } 3413 3414 return success(); 3415 } 3416 3417 /// Given an unbound reference to an SSA value and its type, return the value 3418 /// it specifies. This returns null on failure. 3419 Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3420 auto &entries = getSSAValueEntry(useInfo.name); 3421 3422 // If we have already seen a value of this name, return it. 3423 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3424 auto *result = entries[useInfo.number].first; 3425 // Check that the type matches the other uses. 3426 if (result->getType() == type) 3427 return result; 3428 3429 emitError(useInfo.loc, "use of value '") 3430 .append(useInfo.name, 3431 "' expects different type than prior uses: ", type, " vs ", 3432 result->getType()) 3433 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3434 .append("prior use here"); 3435 return nullptr; 3436 } 3437 3438 // Make sure we have enough slots for this. 3439 if (entries.size() <= useInfo.number) 3440 entries.resize(useInfo.number + 1); 3441 3442 // If the value has already been defined and this is an overly large result 3443 // number, diagnose that. 3444 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3445 return (emitError(useInfo.loc, "reference to invalid result number"), 3446 nullptr); 3447 3448 // Otherwise, this is a forward reference. Create a placeholder and remember 3449 // that we did so. 3450 auto *result = createForwardRefPlaceholder(useInfo.loc, type); 3451 entries[useInfo.number].first = result; 3452 entries[useInfo.number].second = useInfo.loc; 3453 return result; 3454 } 3455 3456 /// Parse an SSA use with an associated type. 3457 /// 3458 /// ssa-use-and-type ::= ssa-use `:` type 3459 ParseResult OperationParser::parseSSADefOrUseAndType( 3460 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3461 SSAUseInfo useInfo; 3462 if (parseSSAUse(useInfo) || 3463 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3464 return failure(); 3465 3466 auto type = parseType(); 3467 if (!type) 3468 return failure(); 3469 3470 return action(useInfo, type); 3471 } 3472 3473 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3474 /// followed by a type list. 3475 /// 3476 /// ssa-use-and-type-list 3477 /// ::= ssa-use-list ':' type-list-no-parens 3478 /// 3479 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3480 SmallVectorImpl<Value *> &results) { 3481 SmallVector<SSAUseInfo, 4> valueIDs; 3482 if (parseOptionalSSAUseList(valueIDs)) 3483 return failure(); 3484 3485 // If there were no operands, then there is no colon or type lists. 3486 if (valueIDs.empty()) 3487 return success(); 3488 3489 SmallVector<Type, 4> types; 3490 if (parseToken(Token::colon, "expected ':' in operand list") || 3491 parseTypeListNoParens(types)) 3492 return failure(); 3493 3494 if (valueIDs.size() != types.size()) 3495 return emitError("expected ") 3496 << valueIDs.size() << " types to match operand list"; 3497 3498 results.reserve(valueIDs.size()); 3499 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3500 if (auto *value = resolveSSAUse(valueIDs[i], types[i])) 3501 results.push_back(value); 3502 else 3503 return failure(); 3504 } 3505 3506 return success(); 3507 } 3508 3509 /// Record that a definition was added at the current scope. 3510 void OperationParser::recordDefinition(StringRef def) { 3511 isolatedNameScopes.back().recordDefinition(def); 3512 } 3513 3514 /// Get the value entry for the given SSA name. 3515 SmallVectorImpl<std::pair<Value *, SMLoc>> & 3516 OperationParser::getSSAValueEntry(StringRef name) { 3517 return isolatedNameScopes.back().values[name]; 3518 } 3519 3520 /// Create and remember a new placeholder for a forward reference. 3521 Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3522 // Forward references are always created as operations, because we just need 3523 // something with a def/use chain. 3524 // 3525 // We create these placeholders as having an empty name, which we know 3526 // cannot be created through normal user input, allowing us to distinguish 3527 // them. 3528 auto name = OperationName("placeholder", getContext()); 3529 auto *op = Operation::create( 3530 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3531 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3532 /*resizableOperandList=*/false); 3533 forwardRefPlaceholders[op->getResult(0)] = loc; 3534 return op->getResult(0); 3535 } 3536 3537 //===----------------------------------------------------------------------===// 3538 // Operation Parsing 3539 //===----------------------------------------------------------------------===// 3540 3541 /// Parse an operation. 3542 /// 3543 /// operation ::= op-result-list? 3544 /// (generic-operation | custom-operation) 3545 /// trailing-location? 3546 /// generic-operation ::= string-literal '(' ssa-use-list? ')' attribute-dict? 3547 /// `:` function-type 3548 /// custom-operation ::= bare-id custom-operation-format 3549 /// op-result-list ::= op-result (`,` op-result)* `=` 3550 /// op-result ::= ssa-id (`:` integer-literal) 3551 /// 3552 ParseResult OperationParser::parseOperation() { 3553 auto loc = getToken().getLoc(); 3554 SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; 3555 size_t numExpectedResults = 0; 3556 if (getToken().is(Token::percent_identifier)) { 3557 // Parse the group of result ids. 3558 auto parseNextResult = [&]() -> ParseResult { 3559 // Parse the next result id. 3560 if (!getToken().is(Token::percent_identifier)) 3561 return emitError("expected valid ssa identifier"); 3562 3563 Token nameTok = getToken(); 3564 consumeToken(Token::percent_identifier); 3565 3566 // If the next token is a ':', we parse the expected result count. 3567 size_t expectedSubResults = 1; 3568 if (consumeIf(Token::colon)) { 3569 // Check that the next token is an integer. 3570 if (!getToken().is(Token::integer)) 3571 return emitError("expected integer number of results"); 3572 3573 // Check that number of results is > 0. 3574 auto val = getToken().getUInt64IntegerValue(); 3575 if (!val.hasValue() || val.getValue() < 1) 3576 return emitError("expected named operation to have atleast 1 result"); 3577 consumeToken(Token::integer); 3578 expectedSubResults = *val; 3579 } 3580 3581 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3582 nameTok.getLoc()); 3583 numExpectedResults += expectedSubResults; 3584 return success(); 3585 }; 3586 if (parseCommaSeparatedList(parseNextResult)) 3587 return failure(); 3588 3589 if (parseToken(Token::equal, "expected '=' after SSA name")) 3590 return failure(); 3591 } 3592 3593 Operation *op; 3594 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3595 op = parseCustomOperation(); 3596 else if (getToken().is(Token::string)) 3597 op = parseGenericOperation(); 3598 else 3599 return emitError("expected operation name in quotes"); 3600 3601 // If parsing of the basic operation failed, then this whole thing fails. 3602 if (!op) 3603 return failure(); 3604 3605 // If the operation had a name, register it. 3606 if (!resultIDs.empty()) { 3607 if (op->getNumResults() == 0) 3608 return emitError(loc, "cannot name an operation with no results"); 3609 if (numExpectedResults != op->getNumResults()) 3610 return emitError(loc, "operation defines ") 3611 << op->getNumResults() << " results but was provided " 3612 << numExpectedResults << " to bind"; 3613 3614 // Add definitions for each of the result groups. 3615 unsigned opResI = 0; 3616 for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { 3617 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3618 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3619 op->getResult(opResI++))) 3620 return failure(); 3621 } 3622 } 3623 } 3624 3625 return success(); 3626 } 3627 3628 /// Parse a single operation successor and its operand list. 3629 /// 3630 /// successor ::= block-id branch-use-list? 3631 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3632 /// 3633 ParseResult 3634 OperationParser::parseSuccessorAndUseList(Block *&dest, 3635 SmallVectorImpl<Value *> &operands) { 3636 // Verify branch is identifier and get the matching block. 3637 if (!getToken().is(Token::caret_identifier)) 3638 return emitError("expected block name"); 3639 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3640 consumeToken(); 3641 3642 // Handle optional arguments. 3643 if (consumeIf(Token::l_paren) && 3644 (parseOptionalSSAUseAndTypeList(operands) || 3645 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3646 return failure(); 3647 } 3648 3649 return success(); 3650 } 3651 3652 /// Parse a comma-separated list of operation successors in brackets. 3653 /// 3654 /// successor-list ::= `[` successor (`,` successor )* `]` 3655 /// 3656 ParseResult OperationParser::parseSuccessors( 3657 SmallVectorImpl<Block *> &destinations, 3658 SmallVectorImpl<SmallVector<Value *, 4>> &operands) { 3659 if (parseToken(Token::l_square, "expected '['")) 3660 return failure(); 3661 3662 auto parseElt = [this, &destinations, &operands]() { 3663 Block *dest; 3664 SmallVector<Value *, 4> destOperands; 3665 auto res = parseSuccessorAndUseList(dest, destOperands); 3666 destinations.push_back(dest); 3667 operands.push_back(destOperands); 3668 return res; 3669 }; 3670 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3671 /*allowEmptyList=*/false); 3672 } 3673 3674 namespace { 3675 // RAII-style guard for cleaning up the regions in the operation state before 3676 // deleting them. Within the parser, regions may get deleted if parsing failed, 3677 // and other errors may be present, in particular undominated uses. This makes 3678 // sure such uses are deleted. 3679 struct CleanupOpStateRegions { 3680 ~CleanupOpStateRegions() { 3681 SmallVector<Region *, 4> regionsToClean; 3682 regionsToClean.reserve(state.regions.size()); 3683 for (auto ®ion : state.regions) 3684 if (region) 3685 for (auto &block : *region) 3686 block.dropAllDefinedValueUses(); 3687 } 3688 OperationState &state; 3689 }; 3690 } // namespace 3691 3692 Operation *OperationParser::parseGenericOperation() { 3693 // Get location information for the operation. 3694 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3695 3696 auto name = getToken().getStringValue(); 3697 if (name.empty()) 3698 return (emitError("empty operation name is invalid"), nullptr); 3699 if (name.find('\0') != StringRef::npos) 3700 return (emitError("null character not allowed in operation name"), nullptr); 3701 3702 consumeToken(Token::string); 3703 3704 OperationState result(srcLocation, name); 3705 3706 // Generic operations have a resizable operation list. 3707 result.setOperandListToResizable(); 3708 3709 // Parse the operand list. 3710 SmallVector<SSAUseInfo, 8> operandInfos; 3711 3712 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3713 parseOptionalSSAUseList(operandInfos) || 3714 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3715 return nullptr; 3716 } 3717 3718 // Parse the successor list but don't add successors to the result yet to 3719 // avoid messing up with the argument order. 3720 SmallVector<Block *, 2> successors; 3721 SmallVector<SmallVector<Value *, 4>, 2> successorOperands; 3722 if (getToken().is(Token::l_square)) { 3723 // Check if the operation is a known terminator. 3724 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3725 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3726 return emitError("successors in non-terminator"), nullptr; 3727 if (parseSuccessors(successors, successorOperands)) 3728 return nullptr; 3729 } 3730 3731 // Parse the region list. 3732 CleanupOpStateRegions guard{result}; 3733 if (consumeIf(Token::l_paren)) { 3734 do { 3735 // Create temporary regions with the top level region as parent. 3736 result.regions.emplace_back(new Region(moduleOp)); 3737 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3738 return nullptr; 3739 } while (consumeIf(Token::comma)); 3740 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3741 return nullptr; 3742 } 3743 3744 if (getToken().is(Token::l_brace)) { 3745 if (parseAttributeDict(result.attributes)) 3746 return nullptr; 3747 } 3748 3749 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3750 return nullptr; 3751 3752 auto typeLoc = getToken().getLoc(); 3753 auto type = parseType(); 3754 if (!type) 3755 return nullptr; 3756 auto fnType = type.dyn_cast<FunctionType>(); 3757 if (!fnType) 3758 return (emitError(typeLoc, "expected function type"), nullptr); 3759 3760 result.addTypes(fnType.getResults()); 3761 3762 // Check that we have the right number of types for the operands. 3763 auto operandTypes = fnType.getInputs(); 3764 if (operandTypes.size() != operandInfos.size()) { 3765 auto plural = "s"[operandInfos.size() == 1]; 3766 return (emitError(typeLoc, "expected ") 3767 << operandInfos.size() << " operand type" << plural 3768 << " but had " << operandTypes.size(), 3769 nullptr); 3770 } 3771 3772 // Resolve all of the operands. 3773 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3774 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3775 if (!result.operands.back()) 3776 return nullptr; 3777 } 3778 3779 // Add the successors, and their operands after the proper operands. 3780 for (const auto &succ : llvm::zip(successors, successorOperands)) { 3781 Block *successor = std::get<0>(succ); 3782 const SmallVector<Value *, 4> &operands = std::get<1>(succ); 3783 result.addSuccessor(successor, operands); 3784 } 3785 3786 // Parse a location if one is present. 3787 if (parseOptionalTrailingLocation(result.location)) 3788 return nullptr; 3789 3790 return opBuilder.createOperation(result); 3791 } 3792 3793 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3794 Block::iterator insertPt) { 3795 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3796 opBuilder.setInsertionPoint(insertBlock, insertPt); 3797 return parseGenericOperation(); 3798 } 3799 3800 namespace { 3801 class CustomOpAsmParser : public OpAsmParser { 3802 public: 3803 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3804 OperationParser &parser) 3805 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3806 3807 /// Parse an instance of the operation described by 'opDefinition' into the 3808 /// provided operation state. 3809 ParseResult parseOperation(OperationState &opState) { 3810 if (opDefinition->parseAssembly(*this, opState)) 3811 return failure(); 3812 return success(); 3813 } 3814 3815 Operation *parseGenericOperation(Block *insertBlock, 3816 Block::iterator insertPt) final { 3817 return parser.parseGenericOperation(insertBlock, insertPt); 3818 } 3819 3820 //===--------------------------------------------------------------------===// 3821 // Utilities 3822 //===--------------------------------------------------------------------===// 3823 3824 /// Return if any errors were emitted during parsing. 3825 bool didEmitError() const { return emittedError; } 3826 3827 /// Emit a diagnostic at the specified location and return failure. 3828 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3829 emittedError = true; 3830 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3831 message); 3832 } 3833 3834 llvm::SMLoc getCurrentLocation() override { 3835 return parser.getToken().getLoc(); 3836 } 3837 3838 Builder &getBuilder() const override { return parser.builder; } 3839 3840 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3841 3842 //===--------------------------------------------------------------------===// 3843 // Token Parsing 3844 //===--------------------------------------------------------------------===// 3845 3846 /// Parse a `->` token. 3847 ParseResult parseArrow() override { 3848 return parser.parseToken(Token::arrow, "expected '->'"); 3849 } 3850 3851 /// Parses a `->` if present. 3852 ParseResult parseOptionalArrow() override { 3853 return success(parser.consumeIf(Token::arrow)); 3854 } 3855 3856 /// Parse a `:` token. 3857 ParseResult parseColon() override { 3858 return parser.parseToken(Token::colon, "expected ':'"); 3859 } 3860 3861 /// Parse a `:` token if present. 3862 ParseResult parseOptionalColon() override { 3863 return success(parser.consumeIf(Token::colon)); 3864 } 3865 3866 /// Parse a `,` token. 3867 ParseResult parseComma() override { 3868 return parser.parseToken(Token::comma, "expected ','"); 3869 } 3870 3871 /// Parse a `,` token if present. 3872 ParseResult parseOptionalComma() override { 3873 return success(parser.consumeIf(Token::comma)); 3874 } 3875 3876 /// Parses a `...` if present. 3877 ParseResult parseOptionalEllipsis() override { 3878 return success(parser.consumeIf(Token::ellipsis)); 3879 } 3880 3881 /// Parse a `=` token. 3882 ParseResult parseEqual() override { 3883 return parser.parseToken(Token::equal, "expected '='"); 3884 } 3885 3886 /// Parse a '<' token. 3887 ParseResult parseLess() override { 3888 return parser.parseToken(Token::less, "expected '<'"); 3889 } 3890 3891 /// Parse a '>' token. 3892 ParseResult parseGreater() override { 3893 return parser.parseToken(Token::greater, "expected '>'"); 3894 } 3895 3896 /// Parse a `(` token. 3897 ParseResult parseLParen() override { 3898 return parser.parseToken(Token::l_paren, "expected '('"); 3899 } 3900 3901 /// Parses a '(' if present. 3902 ParseResult parseOptionalLParen() override { 3903 return success(parser.consumeIf(Token::l_paren)); 3904 } 3905 3906 /// Parse a `)` token. 3907 ParseResult parseRParen() override { 3908 return parser.parseToken(Token::r_paren, "expected ')'"); 3909 } 3910 3911 /// Parses a ')' if present. 3912 ParseResult parseOptionalRParen() override { 3913 return success(parser.consumeIf(Token::r_paren)); 3914 } 3915 3916 /// Parse a `[` token. 3917 ParseResult parseLSquare() override { 3918 return parser.parseToken(Token::l_square, "expected '['"); 3919 } 3920 3921 /// Parses a '[' if present. 3922 ParseResult parseOptionalLSquare() override { 3923 return success(parser.consumeIf(Token::l_square)); 3924 } 3925 3926 /// Parse a `]` token. 3927 ParseResult parseRSquare() override { 3928 return parser.parseToken(Token::r_square, "expected ']'"); 3929 } 3930 3931 /// Parses a ']' if present. 3932 ParseResult parseOptionalRSquare() override { 3933 return success(parser.consumeIf(Token::r_square)); 3934 } 3935 3936 //===--------------------------------------------------------------------===// 3937 // Attribute Parsing 3938 //===--------------------------------------------------------------------===// 3939 3940 /// Parse an arbitrary attribute of a given type and return it in result. This 3941 /// also adds the attribute to the specified attribute list with the specified 3942 /// name. 3943 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3944 SmallVectorImpl<NamedAttribute> &attrs) override { 3945 result = parser.parseAttribute(type); 3946 if (!result) 3947 return failure(); 3948 3949 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3950 return success(); 3951 } 3952 3953 /// Parse a named dictionary into 'result' if it is present. 3954 ParseResult 3955 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 3956 if (parser.getToken().isNot(Token::l_brace)) 3957 return success(); 3958 return parser.parseAttributeDict(result); 3959 } 3960 3961 /// Parse a named dictionary into 'result' if the `attributes` keyword is 3962 /// present. 3963 ParseResult parseOptionalAttrDictWithKeyword( 3964 SmallVectorImpl<NamedAttribute> &result) override { 3965 if (failed(parseOptionalKeyword("attributes"))) 3966 return success(); 3967 return parser.parseAttributeDict(result); 3968 } 3969 3970 //===--------------------------------------------------------------------===// 3971 // Identifier Parsing 3972 //===--------------------------------------------------------------------===// 3973 3974 /// Returns if the current token corresponds to a keyword. 3975 bool isCurrentTokenAKeyword() const { 3976 return parser.getToken().is(Token::bare_identifier) || 3977 parser.getToken().isKeyword(); 3978 } 3979 3980 /// Parse the given keyword if present. 3981 ParseResult parseOptionalKeyword(StringRef keyword) override { 3982 // Check that the current token has the same spelling. 3983 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 3984 return failure(); 3985 parser.consumeToken(); 3986 return success(); 3987 } 3988 3989 /// Parse a keyword, if present, into 'keyword'. 3990 ParseResult parseOptionalKeyword(StringRef *keyword) override { 3991 // Check that the current token is a keyword. 3992 if (!isCurrentTokenAKeyword()) 3993 return failure(); 3994 3995 *keyword = parser.getTokenSpelling(); 3996 parser.consumeToken(); 3997 return success(); 3998 } 3999 4000 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4001 /// string attribute named 'attrName'. 4002 ParseResult 4003 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4004 SmallVectorImpl<NamedAttribute> &attrs) override { 4005 Token atToken = parser.getToken(); 4006 if (atToken.isNot(Token::at_identifier)) 4007 return failure(); 4008 4009 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4010 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4011 parser.consumeToken(); 4012 return success(); 4013 } 4014 4015 //===--------------------------------------------------------------------===// 4016 // Operand Parsing 4017 //===--------------------------------------------------------------------===// 4018 4019 /// Parse a single operand. 4020 ParseResult parseOperand(OperandType &result) override { 4021 OperationParser::SSAUseInfo useInfo; 4022 if (parser.parseSSAUse(useInfo)) 4023 return failure(); 4024 4025 result = {useInfo.loc, useInfo.name, useInfo.number}; 4026 return success(); 4027 } 4028 4029 /// Parse zero or more SSA comma-separated operand references with a specified 4030 /// surrounding delimiter, and an optional required operand count. 4031 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4032 int requiredOperandCount = -1, 4033 Delimiter delimiter = Delimiter::None) override { 4034 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4035 requiredOperandCount, delimiter); 4036 } 4037 4038 /// Parse zero or more SSA comma-separated operand or region arguments with 4039 /// optional surrounding delimiter and required operand count. 4040 ParseResult 4041 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4042 bool isOperandList, int requiredOperandCount = -1, 4043 Delimiter delimiter = Delimiter::None) { 4044 auto startLoc = parser.getToken().getLoc(); 4045 4046 // Handle delimiters. 4047 switch (delimiter) { 4048 case Delimiter::None: 4049 // Don't check for the absence of a delimiter if the number of operands 4050 // is unknown (and hence the operand list could be empty). 4051 if (requiredOperandCount == -1) 4052 break; 4053 // Token already matches an identifier and so can't be a delimiter. 4054 if (parser.getToken().is(Token::percent_identifier)) 4055 break; 4056 // Test against known delimiters. 4057 if (parser.getToken().is(Token::l_paren) || 4058 parser.getToken().is(Token::l_square)) 4059 return emitError(startLoc, "unexpected delimiter"); 4060 return emitError(startLoc, "invalid operand"); 4061 case Delimiter::OptionalParen: 4062 if (parser.getToken().isNot(Token::l_paren)) 4063 return success(); 4064 LLVM_FALLTHROUGH; 4065 case Delimiter::Paren: 4066 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4067 return failure(); 4068 break; 4069 case Delimiter::OptionalSquare: 4070 if (parser.getToken().isNot(Token::l_square)) 4071 return success(); 4072 LLVM_FALLTHROUGH; 4073 case Delimiter::Square: 4074 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4075 return failure(); 4076 break; 4077 } 4078 4079 // Check for zero operands. 4080 if (parser.getToken().is(Token::percent_identifier)) { 4081 do { 4082 OperandType operandOrArg; 4083 if (isOperandList ? parseOperand(operandOrArg) 4084 : parseRegionArgument(operandOrArg)) 4085 return failure(); 4086 result.push_back(operandOrArg); 4087 } while (parser.consumeIf(Token::comma)); 4088 } 4089 4090 // Handle delimiters. If we reach here, the optional delimiters were 4091 // present, so we need to parse their closing one. 4092 switch (delimiter) { 4093 case Delimiter::None: 4094 break; 4095 case Delimiter::OptionalParen: 4096 case Delimiter::Paren: 4097 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4098 return failure(); 4099 break; 4100 case Delimiter::OptionalSquare: 4101 case Delimiter::Square: 4102 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4103 return failure(); 4104 break; 4105 } 4106 4107 if (requiredOperandCount != -1 && 4108 result.size() != static_cast<size_t>(requiredOperandCount)) 4109 return emitError(startLoc, "expected ") 4110 << requiredOperandCount << " operands"; 4111 return success(); 4112 } 4113 4114 /// Parse zero or more trailing SSA comma-separated trailing operand 4115 /// references with a specified surrounding delimiter, and an optional 4116 /// required operand count. A leading comma is expected before the operands. 4117 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4118 int requiredOperandCount, 4119 Delimiter delimiter) override { 4120 if (parser.getToken().is(Token::comma)) { 4121 parseComma(); 4122 return parseOperandList(result, requiredOperandCount, delimiter); 4123 } 4124 if (requiredOperandCount != -1) 4125 return emitError(parser.getToken().getLoc(), "expected ") 4126 << requiredOperandCount << " operands"; 4127 return success(); 4128 } 4129 4130 /// Resolve an operand to an SSA value, emitting an error on failure. 4131 ParseResult resolveOperand(const OperandType &operand, Type type, 4132 SmallVectorImpl<Value *> &result) override { 4133 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4134 operand.location}; 4135 if (auto *value = parser.resolveSSAUse(operandInfo, type)) { 4136 result.push_back(value); 4137 return success(); 4138 } 4139 return failure(); 4140 } 4141 4142 /// Parse an AffineMap of SSA ids. 4143 ParseResult 4144 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4145 Attribute &mapAttr, StringRef attrName, 4146 SmallVectorImpl<NamedAttribute> &attrs) override { 4147 SmallVector<OperandType, 2> dimOperands; 4148 SmallVector<OperandType, 1> symOperands; 4149 4150 auto parseElement = [&](bool isSymbol) -> ParseResult { 4151 OperandType operand; 4152 if (parseOperand(operand)) 4153 return failure(); 4154 if (isSymbol) 4155 symOperands.push_back(operand); 4156 else 4157 dimOperands.push_back(operand); 4158 return success(); 4159 }; 4160 4161 AffineMap map; 4162 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 4163 return failure(); 4164 // Add AffineMap attribute. 4165 if (map) { 4166 mapAttr = AffineMapAttr::get(map); 4167 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4168 } 4169 4170 // Add dim operands before symbol operands in 'operands'. 4171 operands.assign(dimOperands.begin(), dimOperands.end()); 4172 operands.append(symOperands.begin(), symOperands.end()); 4173 return success(); 4174 } 4175 4176 //===--------------------------------------------------------------------===// 4177 // Region Parsing 4178 //===--------------------------------------------------------------------===// 4179 4180 /// Parse a region that takes `arguments` of `argTypes` types. This 4181 /// effectively defines the SSA values of `arguments` and assigns their type. 4182 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4183 ArrayRef<Type> argTypes, 4184 bool enableNameShadowing) override { 4185 assert(arguments.size() == argTypes.size() && 4186 "mismatching number of arguments and types"); 4187 4188 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4189 regionArguments; 4190 for (const auto &pair : llvm::zip(arguments, argTypes)) { 4191 const OperandType &operand = std::get<0>(pair); 4192 Type type = std::get<1>(pair); 4193 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4194 operand.location}; 4195 regionArguments.emplace_back(operandInfo, type); 4196 } 4197 4198 // Try to parse the region. 4199 assert((!enableNameShadowing || 4200 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4201 "name shadowing is only allowed on isolated regions"); 4202 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4203 return failure(); 4204 return success(); 4205 } 4206 4207 /// Parses a region if present. 4208 ParseResult parseOptionalRegion(Region ®ion, 4209 ArrayRef<OperandType> arguments, 4210 ArrayRef<Type> argTypes, 4211 bool enableNameShadowing) override { 4212 if (parser.getToken().isNot(Token::l_brace)) 4213 return success(); 4214 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4215 } 4216 4217 /// Parse a region argument. The type of the argument will be resolved later 4218 /// by a call to `parseRegion`. 4219 ParseResult parseRegionArgument(OperandType &argument) override { 4220 return parseOperand(argument); 4221 } 4222 4223 /// Parse a region argument if present. 4224 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4225 if (parser.getToken().isNot(Token::percent_identifier)) 4226 return success(); 4227 return parseRegionArgument(argument); 4228 } 4229 4230 ParseResult 4231 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4232 int requiredOperandCount = -1, 4233 Delimiter delimiter = Delimiter::None) override { 4234 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4235 requiredOperandCount, delimiter); 4236 } 4237 4238 //===--------------------------------------------------------------------===// 4239 // Successor Parsing 4240 //===--------------------------------------------------------------------===// 4241 4242 /// Parse a single operation successor and its operand list. 4243 ParseResult 4244 parseSuccessorAndUseList(Block *&dest, 4245 SmallVectorImpl<Value *> &operands) override { 4246 return parser.parseSuccessorAndUseList(dest, operands); 4247 } 4248 4249 //===--------------------------------------------------------------------===// 4250 // Type Parsing 4251 //===--------------------------------------------------------------------===// 4252 4253 /// Parse a type. 4254 ParseResult parseType(Type &result) override { 4255 return failure(!(result = parser.parseType())); 4256 } 4257 4258 /// Parse an optional arrow followed by a type list. 4259 ParseResult 4260 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4261 if (!parser.consumeIf(Token::arrow)) 4262 return success(); 4263 return parser.parseFunctionResultTypes(result); 4264 } 4265 4266 /// Parse a colon followed by a type. 4267 ParseResult parseColonType(Type &result) override { 4268 return failure(parser.parseToken(Token::colon, "expected ':'") || 4269 !(result = parser.parseType())); 4270 } 4271 4272 /// Parse a colon followed by a type list, which must have at least one type. 4273 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4274 if (parser.parseToken(Token::colon, "expected ':'")) 4275 return failure(); 4276 return parser.parseTypeListNoParens(result); 4277 } 4278 4279 /// Parse an optional colon followed by a type list, which if present must 4280 /// have at least one type. 4281 ParseResult 4282 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4283 if (!parser.consumeIf(Token::colon)) 4284 return success(); 4285 return parser.parseTypeListNoParens(result); 4286 } 4287 4288 private: 4289 /// The source location of the operation name. 4290 SMLoc nameLoc; 4291 4292 /// The abstract information of the operation. 4293 const AbstractOperation *opDefinition; 4294 4295 /// The main operation parser. 4296 OperationParser &parser; 4297 4298 /// A flag that indicates if any errors were emitted during parsing. 4299 bool emittedError = false; 4300 }; 4301 } // end anonymous namespace. 4302 4303 Operation *OperationParser::parseCustomOperation() { 4304 auto opLoc = getToken().getLoc(); 4305 auto opName = getTokenSpelling(); 4306 4307 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4308 if (!opDefinition && !opName.contains('.')) { 4309 // If the operation name has no namespace prefix we treat it as a standard 4310 // operation and prefix it with "std". 4311 // TODO: Would it be better to just build a mapping of the registered 4312 // operations in the standard dialect? 4313 opDefinition = 4314 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4315 } 4316 4317 if (!opDefinition) { 4318 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4319 return nullptr; 4320 } 4321 4322 consumeToken(); 4323 4324 // If the custom op parser crashes, produce some indication to help 4325 // debugging. 4326 std::string opNameStr = opName.str(); 4327 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4328 opNameStr.c_str()); 4329 4330 // Get location information for the operation. 4331 auto srcLocation = getEncodedSourceLocation(opLoc); 4332 4333 // Have the op implementation take a crack and parsing this. 4334 OperationState opState(srcLocation, opDefinition->name); 4335 CleanupOpStateRegions guard{opState}; 4336 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 4337 if (opAsmParser.parseOperation(opState)) 4338 return nullptr; 4339 4340 // If it emitted an error, we failed. 4341 if (opAsmParser.didEmitError()) 4342 return nullptr; 4343 4344 // Parse a location if one is present. 4345 if (parseOptionalTrailingLocation(opState.location)) 4346 return nullptr; 4347 4348 // Otherwise, we succeeded. Use the state it parsed as our op information. 4349 return opBuilder.createOperation(opState); 4350 } 4351 4352 //===----------------------------------------------------------------------===// 4353 // Region Parsing 4354 //===----------------------------------------------------------------------===// 4355 4356 /// Region. 4357 /// 4358 /// region ::= '{' region-body 4359 /// 4360 ParseResult OperationParser::parseRegion( 4361 Region ®ion, 4362 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4363 bool isIsolatedNameScope) { 4364 // Parse the '{'. 4365 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4366 return failure(); 4367 4368 // Check for an empty region. 4369 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4370 return success(); 4371 auto currentPt = opBuilder.saveInsertionPoint(); 4372 4373 // Push a new named value scope. 4374 pushSSANameScope(isIsolatedNameScope); 4375 4376 // Parse the first block directly to allow for it to be unnamed. 4377 Block *block = new Block(); 4378 4379 // Add arguments to the entry block. 4380 if (!entryArguments.empty()) { 4381 for (auto &placeholderArgPair : entryArguments) { 4382 auto &argInfo = placeholderArgPair.first; 4383 // Ensure that the argument was not already defined. 4384 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4385 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4386 "' is already in use") 4387 .attachNote(getEncodedSourceLocation(*defLoc)) 4388 << "previously referenced here"; 4389 } 4390 if (addDefinition(placeholderArgPair.first, 4391 block->addArgument(placeholderArgPair.second))) { 4392 delete block; 4393 return failure(); 4394 } 4395 } 4396 4397 // If we had named arguments, then don't allow a block name. 4398 if (getToken().is(Token::caret_identifier)) 4399 return emitError("invalid block name in region with named arguments"); 4400 } 4401 4402 if (parseBlock(block)) { 4403 delete block; 4404 return failure(); 4405 } 4406 4407 // Verify that no other arguments were parsed. 4408 if (!entryArguments.empty() && 4409 block->getNumArguments() > entryArguments.size()) { 4410 delete block; 4411 return emitError("entry block arguments were already defined"); 4412 } 4413 4414 // Parse the rest of the region. 4415 region.push_back(block); 4416 if (parseRegionBody(region)) 4417 return failure(); 4418 4419 // Pop the SSA value scope for this region. 4420 if (popSSANameScope()) 4421 return failure(); 4422 4423 // Reset the original insertion point. 4424 opBuilder.restoreInsertionPoint(currentPt); 4425 return success(); 4426 } 4427 4428 /// Region. 4429 /// 4430 /// region-body ::= block* '}' 4431 /// 4432 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4433 // Parse the list of blocks. 4434 while (!consumeIf(Token::r_brace)) { 4435 Block *newBlock = nullptr; 4436 if (parseBlock(newBlock)) 4437 return failure(); 4438 region.push_back(newBlock); 4439 } 4440 return success(); 4441 } 4442 4443 //===----------------------------------------------------------------------===// 4444 // Block Parsing 4445 //===----------------------------------------------------------------------===// 4446 4447 /// Block declaration. 4448 /// 4449 /// block ::= block-label? operation* 4450 /// block-label ::= block-id block-arg-list? `:` 4451 /// block-id ::= caret-id 4452 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4453 /// 4454 ParseResult OperationParser::parseBlock(Block *&block) { 4455 // The first block of a region may already exist, if it does the caret 4456 // identifier is optional. 4457 if (block && getToken().isNot(Token::caret_identifier)) 4458 return parseBlockBody(block); 4459 4460 SMLoc nameLoc = getToken().getLoc(); 4461 auto name = getTokenSpelling(); 4462 if (parseToken(Token::caret_identifier, "expected block name")) 4463 return failure(); 4464 4465 block = defineBlockNamed(name, nameLoc, block); 4466 4467 // Fail if the block was already defined. 4468 if (!block) 4469 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4470 4471 // If an argument list is present, parse it. 4472 if (consumeIf(Token::l_paren)) { 4473 SmallVector<BlockArgument *, 8> bbArgs; 4474 if (parseOptionalBlockArgList(bbArgs, block) || 4475 parseToken(Token::r_paren, "expected ')' to end argument list")) 4476 return failure(); 4477 } 4478 4479 if (parseToken(Token::colon, "expected ':' after block name")) 4480 return failure(); 4481 4482 return parseBlockBody(block); 4483 } 4484 4485 ParseResult OperationParser::parseBlockBody(Block *block) { 4486 // Set the insertion point to the end of the block to parse. 4487 opBuilder.setInsertionPointToEnd(block); 4488 4489 // Parse the list of operations that make up the body of the block. 4490 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4491 if (parseOperation()) 4492 return failure(); 4493 4494 return success(); 4495 } 4496 4497 /// Get the block with the specified name, creating it if it doesn't already 4498 /// exist. The location specified is the point of use, which allows 4499 /// us to diagnose references to blocks that are not defined precisely. 4500 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4501 auto &blockAndLoc = getBlockInfoByName(name); 4502 if (!blockAndLoc.first) { 4503 blockAndLoc = {new Block(), loc}; 4504 insertForwardRef(blockAndLoc.first, loc); 4505 } 4506 4507 return blockAndLoc.first; 4508 } 4509 4510 /// Define the block with the specified name. Returns the Block* or nullptr in 4511 /// the case of redefinition. 4512 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4513 Block *existing) { 4514 auto &blockAndLoc = getBlockInfoByName(name); 4515 if (!blockAndLoc.first) { 4516 // If the caller provided a block, use it. Otherwise create a new one. 4517 if (!existing) 4518 existing = new Block(); 4519 blockAndLoc.first = existing; 4520 blockAndLoc.second = loc; 4521 return blockAndLoc.first; 4522 } 4523 4524 // Forward declarations are removed once defined, so if we are defining a 4525 // existing block and it is not a forward declaration, then it is a 4526 // redeclaration. 4527 if (!eraseForwardRef(blockAndLoc.first)) 4528 return nullptr; 4529 return blockAndLoc.first; 4530 } 4531 4532 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4533 /// 4534 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4535 /// 4536 ParseResult OperationParser::parseOptionalBlockArgList( 4537 SmallVectorImpl<BlockArgument *> &results, Block *owner) { 4538 if (getToken().is(Token::r_brace)) 4539 return success(); 4540 4541 // If the block already has arguments, then we're handling the entry block. 4542 // Parse and register the names for the arguments, but do not add them. 4543 bool definingExistingArgs = owner->getNumArguments() != 0; 4544 unsigned nextArgument = 0; 4545 4546 return parseCommaSeparatedList([&]() -> ParseResult { 4547 return parseSSADefOrUseAndType( 4548 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4549 // If this block did not have existing arguments, define a new one. 4550 if (!definingExistingArgs) 4551 return addDefinition(useInfo, owner->addArgument(type)); 4552 4553 // Otherwise, ensure that this argument has already been created. 4554 if (nextArgument >= owner->getNumArguments()) 4555 return emitError("too many arguments specified in argument list"); 4556 4557 // Finally, make sure the existing argument has the correct type. 4558 auto *arg = owner->getArgument(nextArgument++); 4559 if (arg->getType() != type) 4560 return emitError("argument and block argument type mismatch"); 4561 return addDefinition(useInfo, arg); 4562 }); 4563 }); 4564 } 4565 4566 //===----------------------------------------------------------------------===// 4567 // Top-level entity parsing. 4568 //===----------------------------------------------------------------------===// 4569 4570 namespace { 4571 /// This parser handles entities that are only valid at the top level of the 4572 /// file. 4573 class ModuleParser : public Parser { 4574 public: 4575 explicit ModuleParser(ParserState &state) : Parser(state) {} 4576 4577 ParseResult parseModule(ModuleOp module); 4578 4579 private: 4580 /// Parse an attribute alias declaration. 4581 ParseResult parseAttributeAliasDef(); 4582 4583 /// Parse an attribute alias declaration. 4584 ParseResult parseTypeAliasDef(); 4585 }; 4586 } // end anonymous namespace 4587 4588 /// Parses an attribute alias declaration. 4589 /// 4590 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4591 /// 4592 ParseResult ModuleParser::parseAttributeAliasDef() { 4593 assert(getToken().is(Token::hash_identifier)); 4594 StringRef aliasName = getTokenSpelling().drop_front(); 4595 4596 // Check for redefinitions. 4597 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4598 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4599 4600 // Make sure this isn't invading the dialect attribute namespace. 4601 if (aliasName.contains('.')) 4602 return emitError("attribute names with a '.' are reserved for " 4603 "dialect-defined names"); 4604 4605 consumeToken(Token::hash_identifier); 4606 4607 // Parse the '='. 4608 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4609 return failure(); 4610 4611 // Parse the attribute value. 4612 Attribute attr = parseAttribute(); 4613 if (!attr) 4614 return failure(); 4615 4616 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4617 return success(); 4618 } 4619 4620 /// Parse a type alias declaration. 4621 /// 4622 /// type-alias-def ::= '!' alias-name `=` 'type' type 4623 /// 4624 ParseResult ModuleParser::parseTypeAliasDef() { 4625 assert(getToken().is(Token::exclamation_identifier)); 4626 StringRef aliasName = getTokenSpelling().drop_front(); 4627 4628 // Check for redefinitions. 4629 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4630 return emitError("redefinition of type alias id '" + aliasName + "'"); 4631 4632 // Make sure this isn't invading the dialect type namespace. 4633 if (aliasName.contains('.')) 4634 return emitError("type names with a '.' are reserved for " 4635 "dialect-defined names"); 4636 4637 consumeToken(Token::exclamation_identifier); 4638 4639 // Parse the '=' and 'type'. 4640 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4641 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4642 return failure(); 4643 4644 // Parse the type. 4645 Type aliasedType = parseType(); 4646 if (!aliasedType) 4647 return failure(); 4648 4649 // Register this alias with the parser state. 4650 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4651 return success(); 4652 } 4653 4654 /// This is the top-level module parser. 4655 ParseResult ModuleParser::parseModule(ModuleOp module) { 4656 OperationParser opParser(getState(), module); 4657 4658 // Module itself is a name scope. 4659 opParser.pushSSANameScope(/*isIsolated=*/true); 4660 4661 while (true) { 4662 switch (getToken().getKind()) { 4663 default: 4664 // Parse a top-level operation. 4665 if (opParser.parseOperation()) 4666 return failure(); 4667 break; 4668 4669 // If we got to the end of the file, then we're done. 4670 case Token::eof: { 4671 if (opParser.finalize()) 4672 return failure(); 4673 4674 // Handle the case where the top level module was explicitly defined. 4675 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4676 auto &operations = bodyBlocks.front().getOperations(); 4677 assert(!operations.empty() && "expected a valid module terminator"); 4678 4679 // Check that the first operation is a module, and it is the only 4680 // non-terminator operation. 4681 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4682 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4683 // Merge the data of the nested module operation into 'module'. 4684 module.setLoc(nested.getLoc()); 4685 module.setAttrs(nested.getOperation()->getAttrList()); 4686 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4687 4688 // Erase the original module body. 4689 bodyBlocks.pop_front(); 4690 } 4691 4692 return opParser.popSSANameScope(); 4693 } 4694 4695 // If we got an error token, then the lexer already emitted an error, just 4696 // stop. Someday we could introduce error recovery if there was demand 4697 // for it. 4698 case Token::error: 4699 return failure(); 4700 4701 // Parse an attribute alias. 4702 case Token::hash_identifier: 4703 if (parseAttributeAliasDef()) 4704 return failure(); 4705 break; 4706 4707 // Parse a type alias. 4708 case Token::exclamation_identifier: 4709 if (parseTypeAliasDef()) 4710 return failure(); 4711 break; 4712 } 4713 } 4714 } 4715 4716 //===----------------------------------------------------------------------===// 4717 4718 /// This parses the file specified by the indicated SourceMgr and returns an 4719 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4720 /// null. 4721 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4722 MLIRContext *context) { 4723 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4724 4725 // This is the result module we are parsing into. 4726 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4727 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4728 4729 SymbolState aliasState; 4730 ParserState state(sourceMgr, context, aliasState); 4731 if (ModuleParser(state).parseModule(*module)) 4732 return nullptr; 4733 4734 // Make sure the parse module has no other structural problems detected by 4735 // the verifier. 4736 if (failed(verify(*module))) 4737 return nullptr; 4738 4739 return module; 4740 } 4741 4742 /// This parses the file specified by the indicated filename and returns an 4743 /// MLIR module if it was valid. If not, the error message is emitted through 4744 /// the error handler registered in the context, and a null pointer is returned. 4745 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4746 MLIRContext *context) { 4747 llvm::SourceMgr sourceMgr; 4748 return parseSourceFile(filename, sourceMgr, context); 4749 } 4750 4751 /// This parses the file specified by the indicated filename using the provided 4752 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4753 /// message is emitted through the error handler registered in the context, and 4754 /// a null pointer is returned. 4755 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4756 llvm::SourceMgr &sourceMgr, 4757 MLIRContext *context) { 4758 if (sourceMgr.getNumBuffers() != 0) { 4759 // TODO(b/136086478): Extend to support multiple buffers. 4760 emitError(mlir::UnknownLoc::get(context), 4761 "only main buffer parsed at the moment"); 4762 return nullptr; 4763 } 4764 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4765 if (std::error_code error = file_or_err.getError()) { 4766 emitError(mlir::UnknownLoc::get(context), 4767 "could not open input file " + filename); 4768 return nullptr; 4769 } 4770 4771 // Load the MLIR module. 4772 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4773 return parseSourceFile(sourceMgr, context); 4774 } 4775 4776 /// This parses the program string to a MLIR module if it was valid. If not, 4777 /// it emits diagnostics and returns null. 4778 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4779 MLIRContext *context) { 4780 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4781 if (!memBuffer) 4782 return nullptr; 4783 4784 SourceMgr sourceMgr; 4785 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4786 return parseSourceFile(sourceMgr, context); 4787 } 4788 4789 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 4790 /// parsing failed, nullptr is returned. The number of bytes read from the input 4791 /// string is returned in 'numRead'. 4792 template <typename T, typename ParserFn> 4793 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 4794 ParserFn &&parserFn) { 4795 SymbolState aliasState; 4796 return parseSymbol<T>( 4797 inputStr, context, aliasState, 4798 [&](Parser &parser) { 4799 SourceMgrDiagnosticHandler handler( 4800 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 4801 parser.getContext()); 4802 return parserFn(parser); 4803 }, 4804 &numRead); 4805 } 4806 4807 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 4808 size_t numRead = 0; 4809 return parseAttribute(attrStr, context, numRead); 4810 } 4811 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 4812 size_t numRead = 0; 4813 return parseAttribute(attrStr, type, numRead); 4814 } 4815 4816 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 4817 size_t &numRead) { 4818 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 4819 return parser.parseAttribute(); 4820 }); 4821 } 4822 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 4823 return parseSymbol<Attribute>( 4824 attrStr, type.getContext(), numRead, 4825 [type](Parser &parser) { return parser.parseAttribute(type); }); 4826 } 4827 4828 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 4829 size_t numRead = 0; 4830 return parseType(typeStr, context, numRead); 4831 } 4832 4833 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 4834 return parseSymbol<Type>(typeStr, context, numRead, 4835 [](Parser &parser) { return parser.parseType(); }); 4836 } 4837