1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR textual form. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Parser.h" 14 #include "Lexer.h" 15 #include "mlir/Analysis/Verifier.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Attributes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/Module.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include "mlir/Support/STLExtras.h" 29 #include "llvm/ADT/APInt.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/ADT/bit.h" 33 #include "llvm/Support/PrettyStackTrace.h" 34 #include "llvm/Support/SMLoc.h" 35 #include "llvm/Support/SourceMgr.h" 36 #include <algorithm> 37 using namespace mlir; 38 using llvm::MemoryBuffer; 39 using llvm::SMLoc; 40 using llvm::SourceMgr; 41 42 namespace { 43 class Parser; 44 45 //===----------------------------------------------------------------------===// 46 // SymbolState 47 //===----------------------------------------------------------------------===// 48 49 /// This class contains record of any parsed top-level symbols. 50 struct SymbolState { 51 // A map from attribute alias identifier to Attribute. 52 llvm::StringMap<Attribute> attributeAliasDefinitions; 53 54 // A map from type alias identifier to Type. 
55 llvm::StringMap<Type> typeAliasDefinitions; 56 57 /// A set of locations into the main parser memory buffer for each of the 58 /// active nested parsers. Given that some nested parsers, i.e. custom dialect 59 /// parsers, operate on a temporary memory buffer, this provides an anchor 60 /// point for emitting diagnostics. 61 SmallVector<llvm::SMLoc, 1> nestedParserLocs; 62 63 /// The top-level lexer that contains the original memory buffer provided by 64 /// the user. This is used by nested parsers to get a properly encoded source 65 /// location. 66 Lexer *topLevelLexer = nullptr; 67 }; 68 69 //===----------------------------------------------------------------------===// 70 // ParserState 71 //===----------------------------------------------------------------------===// 72 73 /// This class refers to all of the state maintained globally by the parser, 74 /// such as the current lexer position etc. 75 struct ParserState { 76 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, 77 SymbolState &symbols) 78 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), 79 symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { 80 // Set the top level lexer for the symbol state if one doesn't exist. 81 if (!symbols.topLevelLexer) 82 symbols.topLevelLexer = &lex; 83 } 84 ~ParserState() { 85 // Reset the top level lexer if it refers the lexer in our state. 86 if (symbols.topLevelLexer == &lex) 87 symbols.topLevelLexer = nullptr; 88 } 89 ParserState(const ParserState &) = delete; 90 void operator=(const ParserState &) = delete; 91 92 /// The context we're parsing into. 93 MLIRContext *const context; 94 95 /// The lexer for the source file we're parsing. 96 Lexer lex; 97 98 /// This is the next token that hasn't been consumed yet. 99 Token curToken; 100 101 /// The current state for symbol parsing. 102 SymbolState &symbols; 103 104 /// The depth of this parser in the nested parsing stack. 
105 size_t parserDepth; 106 }; 107 108 //===----------------------------------------------------------------------===// 109 // Parser 110 //===----------------------------------------------------------------------===// 111 112 /// This class implement support for parsing global entities like types and 113 /// shared entities like SSA names. It is intended to be subclassed by 114 /// specialized subparsers that include state, e.g. when a local symbol table. 115 class Parser { 116 public: 117 Builder builder; 118 119 Parser(ParserState &state) : builder(state.context), state(state) {} 120 121 // Helper methods to get stuff from the parser-global state. 122 ParserState &getState() const { return state; } 123 MLIRContext *getContext() const { return state.context; } 124 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 125 126 /// Parse a comma-separated list of elements up until the specified end token. 127 ParseResult 128 parseCommaSeparatedListUntil(Token::Kind rightToken, 129 const std::function<ParseResult()> &parseElement, 130 bool allowEmptyList = true); 131 132 /// Parse a comma separated list of elements that must have at least one entry 133 /// in it. 134 ParseResult 135 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 136 137 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 138 139 // We have two forms of parsing methods - those that return a non-null 140 // pointer on success, and those that return a ParseResult to indicate whether 141 // they returned a failure. The second class fills in by-reference arguments 142 // as the results of their action. 143 144 //===--------------------------------------------------------------------===// 145 // Error Handling 146 //===--------------------------------------------------------------------===// 147 148 /// Emit an error and return failure. 
149 InFlightDiagnostic emitError(const Twine &message = {}) { 150 return emitError(state.curToken.getLoc(), message); 151 } 152 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 153 154 /// Encode the specified source location information into an attribute for 155 /// attachment to the IR. 156 Location getEncodedSourceLocation(llvm::SMLoc loc) { 157 // If there are no active nested parsers, we can get the encoded source 158 // location directly. 159 if (state.parserDepth == 0) 160 return state.lex.getEncodedSourceLocation(loc); 161 // Otherwise, we need to re-encode it to point to the top level buffer. 162 return state.symbols.topLevelLexer->getEncodedSourceLocation( 163 remapLocationToTopLevelBuffer(loc)); 164 } 165 166 /// Remaps the given SMLoc to the top level lexer of the parser. This is used 167 /// to adjust locations of potentially nested parsers to ensure that they can 168 /// be emitted properly as diagnostics. 169 llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { 170 // If there are no active nested parsers, we can return location directly. 171 SymbolState &symbols = state.symbols; 172 if (state.parserDepth == 0) 173 return loc; 174 assert(symbols.topLevelLexer && "expected valid top-level lexer"); 175 176 // Otherwise, we need to remap the location to the main parser. This is 177 // simply offseting the location onto the location of the last nested 178 // parser. 179 size_t offset = loc.getPointer() - state.lex.getBufferBegin(); 180 auto *rawLoc = 181 symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; 182 return llvm::SMLoc::getFromPointer(rawLoc); 183 } 184 185 //===--------------------------------------------------------------------===// 186 // Token Parsing 187 //===--------------------------------------------------------------------===// 188 189 /// Return the current token the parser is inspecting. 
190 const Token &getToken() const { return state.curToken; } 191 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 192 193 /// If the current token has the specified kind, consume it and return true. 194 /// If not, return false. 195 bool consumeIf(Token::Kind kind) { 196 if (state.curToken.isNot(kind)) 197 return false; 198 consumeToken(kind); 199 return true; 200 } 201 202 /// Advance the current lexer onto the next token. 203 void consumeToken() { 204 assert(state.curToken.isNot(Token::eof, Token::error) && 205 "shouldn't advance past EOF or errors"); 206 state.curToken = state.lex.lexToken(); 207 } 208 209 /// Advance the current lexer onto the next token, asserting what the expected 210 /// current token is. This is preferred to the above method because it leads 211 /// to more self-documenting code with better checking. 212 void consumeToken(Token::Kind kind) { 213 assert(state.curToken.is(kind) && "consumed an unexpected token"); 214 consumeToken(); 215 } 216 217 /// Consume the specified token if present and return success. On failure, 218 /// output a diagnostic and return failure. 219 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 220 221 //===--------------------------------------------------------------------===// 222 // Type Parsing 223 //===--------------------------------------------------------------------===// 224 225 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 226 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 227 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 228 229 /// Parse an arbitrary type. 230 Type parseType(); 231 232 /// Parse a complex type. 233 Type parseComplexType(); 234 235 /// Parse an extended type. 236 Type parseExtendedType(); 237 238 /// Parse a function type. 239 Type parseFunctionType(); 240 241 /// Parse a memref type. 242 Type parseMemRefType(); 243 244 /// Parse a non function type. 
245 Type parseNonFunctionType(); 246 247 /// Parse a tensor type. 248 Type parseTensorType(); 249 250 /// Parse a tuple type. 251 Type parseTupleType(); 252 253 /// Parse a vector type. 254 VectorType parseVectorType(); 255 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 256 bool allowDynamic = true); 257 ParseResult parseXInDimensionList(); 258 259 /// Parse strided layout specification. 260 ParseResult parseStridedLayout(int64_t &offset, 261 SmallVectorImpl<int64_t> &strides); 262 263 // Parse a brace-delimiter list of comma-separated integers with `?` as an 264 // unknown marker. 265 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 266 267 //===--------------------------------------------------------------------===// 268 // Attribute Parsing 269 //===--------------------------------------------------------------------===// 270 271 /// Parse an arbitrary attribute with an optional type. 272 Attribute parseAttribute(Type type = {}); 273 274 /// Parse an attribute dictionary. 275 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 276 277 /// Parse an extended attribute. 278 Attribute parseExtendedAttr(Type type); 279 280 /// Parse a float attribute. 281 Attribute parseFloatAttr(Type type, bool isNegative); 282 283 /// Parse a decimal or a hexadecimal literal, which can be either an integer 284 /// or a float attribute. 285 Attribute parseDecOrHexAttr(Type type, bool isNegative); 286 287 /// Parse an opaque elements attribute. 288 Attribute parseOpaqueElementsAttr(); 289 290 /// Parse a dense elements attribute. 291 Attribute parseDenseElementsAttr(); 292 ShapedType parseElementsLiteralType(); 293 294 /// Parse a sparse elements attribute. 
295 Attribute parseSparseElementsAttr(); 296 297 //===--------------------------------------------------------------------===// 298 // Location Parsing 299 //===--------------------------------------------------------------------===// 300 301 /// Parse an inline location. 302 ParseResult parseLocation(LocationAttr &loc); 303 304 /// Parse a raw location instance. 305 ParseResult parseLocationInstance(LocationAttr &loc); 306 307 /// Parse a callsite location instance. 308 ParseResult parseCallSiteLocation(LocationAttr &loc); 309 310 /// Parse a fused location instance. 311 ParseResult parseFusedLocation(LocationAttr &loc); 312 313 /// Parse a name or FileLineCol location instance. 314 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 315 316 /// Parse an optional trailing location. 317 /// 318 /// trailing-location ::= (`loc` `(` location `)`)? 319 /// 320 ParseResult parseOptionalTrailingLocation(Location &loc) { 321 // If there is a 'loc' we parse a trailing location. 322 if (!getToken().is(Token::kw_loc)) 323 return success(); 324 325 // Parse the location. 326 LocationAttr directLoc; 327 if (parseLocation(directLoc)) 328 return failure(); 329 loc = directLoc; 330 return success(); 331 } 332 333 //===--------------------------------------------------------------------===// 334 // Affine Parsing 335 //===--------------------------------------------------------------------===// 336 337 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 338 IntegerSet &set); 339 340 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 341 ParseResult 342 parseAffineMapOfSSAIds(AffineMap &map, 343 function_ref<ParseResult(bool)> parseElement); 344 345 private: 346 /// The Parser is subclassed and reinstantiated. Do not add additional 347 /// non-trivial state here, add it to the ParserState class. 
348 ParserState &state; 349 }; 350 } // end anonymous namespace 351 352 //===----------------------------------------------------------------------===// 353 // Helper methods. 354 //===----------------------------------------------------------------------===// 355 356 /// Parse a comma separated list of elements that must have at least one entry 357 /// in it. 358 ParseResult Parser::parseCommaSeparatedList( 359 const std::function<ParseResult()> &parseElement) { 360 // Non-empty case starts with an element. 361 if (parseElement()) 362 return failure(); 363 364 // Otherwise we have a list of comma separated elements. 365 while (consumeIf(Token::comma)) { 366 if (parseElement()) 367 return failure(); 368 } 369 return success(); 370 } 371 372 /// Parse a comma-separated list of elements, terminated with an arbitrary 373 /// token. This allows empty lists if allowEmptyList is true. 374 /// 375 /// abstract-list ::= rightToken // if allowEmptyList == true 376 /// abstract-list ::= element (',' element)* rightToken 377 /// 378 ParseResult Parser::parseCommaSeparatedListUntil( 379 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 380 bool allowEmptyList) { 381 // Handle the empty case. 382 if (getToken().is(rightToken)) { 383 if (!allowEmptyList) 384 return emitError("expected list element"); 385 consumeToken(rightToken); 386 return success(); 387 } 388 389 if (parseCommaSeparatedList(parseElement) || 390 parseToken(rightToken, "expected ',' or '" + 391 Token::getTokenSpelling(rightToken) + "'")) 392 return failure(); 393 394 return success(); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // DialectAsmParser 399 //===----------------------------------------------------------------------===// 400 401 namespace { 402 /// This class provides the main implementation of the DialectAsmParser that 403 /// allows for dialects to parse attributes and types. 
This allows for dialect 404 /// hooking into the main MLIR parsing logic. 405 class CustomDialectAsmParser : public DialectAsmParser { 406 public: 407 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 408 : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), 409 parser(parser) {} 410 ~CustomDialectAsmParser() override {} 411 412 /// Emit a diagnostic at the specified location and return failure. 413 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 414 return parser.emitError(loc, message); 415 } 416 417 /// Return a builder which provides useful access to MLIRContext, global 418 /// objects like types and attributes. 419 Builder &getBuilder() const override { return parser.builder; } 420 421 /// Get the location of the next token and store it into the argument. This 422 /// always succeeds. 423 llvm::SMLoc getCurrentLocation() override { 424 return parser.getToken().getLoc(); 425 } 426 427 /// Return the location of the original name token. 428 llvm::SMLoc getNameLoc() const override { return nameLoc; } 429 430 /// Re-encode the given source location as an MLIR location and return it. 431 Location getEncodedSourceLoc(llvm::SMLoc loc) override { 432 return parser.getEncodedSourceLocation(loc); 433 } 434 435 /// Returns the full specification of the symbol being parsed. This allows 436 /// for using a separate parser if necessary. 437 StringRef getFullSymbolSpec() const override { return fullSpec; } 438 439 /// Parse a floating point value from the stream. 440 ParseResult parseFloat(double &result) override { 441 bool negative = parser.consumeIf(Token::minus); 442 Token curTok = parser.getToken(); 443 444 // Check for a floating point value. 445 if (curTok.is(Token::floatliteral)) { 446 auto val = curTok.getFloatingPointValue(); 447 if (!val.hasValue()) 448 return emitError(curTok.getLoc(), "floating point value too large"); 449 parser.consumeToken(Token::floatliteral); 450 result = negative ? 
-*val : *val; 451 return success(); 452 } 453 454 // TODO(riverriddle) support hex floating point values. 455 return emitError(getCurrentLocation(), "expected floating point literal"); 456 } 457 458 /// Parse an optional integer value from the stream. 459 OptionalParseResult parseOptionalInteger(uint64_t &result) override { 460 Token curToken = parser.getToken(); 461 if (curToken.isNot(Token::integer, Token::minus)) 462 return llvm::None; 463 464 bool negative = parser.consumeIf(Token::minus); 465 Token curTok = parser.getToken(); 466 if (parser.parseToken(Token::integer, "expected integer value")) 467 return failure(); 468 469 auto val = curTok.getUInt64IntegerValue(); 470 if (!val) 471 return emitError(curTok.getLoc(), "integer value too large"); 472 result = negative ? -*val : *val; 473 return success(); 474 } 475 476 //===--------------------------------------------------------------------===// 477 // Token Parsing 478 //===--------------------------------------------------------------------===// 479 480 /// Parse a `->` token. 481 ParseResult parseArrow() override { 482 return parser.parseToken(Token::arrow, "expected '->'"); 483 } 484 485 /// Parses a `->` if present. 486 ParseResult parseOptionalArrow() override { 487 return success(parser.consumeIf(Token::arrow)); 488 } 489 490 /// Parse a '{' token. 491 ParseResult parseLBrace() override { 492 return parser.parseToken(Token::l_brace, "expected '{'"); 493 } 494 495 /// Parse a '{' token if present 496 ParseResult parseOptionalLBrace() override { 497 return success(parser.consumeIf(Token::l_brace)); 498 } 499 500 /// Parse a `}` token. 501 ParseResult parseRBrace() override { 502 return parser.parseToken(Token::r_brace, "expected '}'"); 503 } 504 505 /// Parse a `}` token if present 506 ParseResult parseOptionalRBrace() override { 507 return success(parser.consumeIf(Token::r_brace)); 508 } 509 510 /// Parse a `:` token. 
511 ParseResult parseColon() override { 512 return parser.parseToken(Token::colon, "expected ':'"); 513 } 514 515 /// Parse a `:` token if present. 516 ParseResult parseOptionalColon() override { 517 return success(parser.consumeIf(Token::colon)); 518 } 519 520 /// Parse a `,` token. 521 ParseResult parseComma() override { 522 return parser.parseToken(Token::comma, "expected ','"); 523 } 524 525 /// Parse a `,` token if present. 526 ParseResult parseOptionalComma() override { 527 return success(parser.consumeIf(Token::comma)); 528 } 529 530 /// Parses a `...` if present. 531 ParseResult parseOptionalEllipsis() override { 532 return success(parser.consumeIf(Token::ellipsis)); 533 } 534 535 /// Parse a `=` token. 536 ParseResult parseEqual() override { 537 return parser.parseToken(Token::equal, "expected '='"); 538 } 539 540 /// Parse a '<' token. 541 ParseResult parseLess() override { 542 return parser.parseToken(Token::less, "expected '<'"); 543 } 544 545 /// Parse a `<` token if present. 546 ParseResult parseOptionalLess() override { 547 return success(parser.consumeIf(Token::less)); 548 } 549 550 /// Parse a '>' token. 551 ParseResult parseGreater() override { 552 return parser.parseToken(Token::greater, "expected '>'"); 553 } 554 555 /// Parse a `>` token if present. 556 ParseResult parseOptionalGreater() override { 557 return success(parser.consumeIf(Token::greater)); 558 } 559 560 /// Parse a `(` token. 561 ParseResult parseLParen() override { 562 return parser.parseToken(Token::l_paren, "expected '('"); 563 } 564 565 /// Parses a '(' if present. 566 ParseResult parseOptionalLParen() override { 567 return success(parser.consumeIf(Token::l_paren)); 568 } 569 570 /// Parse a `)` token. 571 ParseResult parseRParen() override { 572 return parser.parseToken(Token::r_paren, "expected ')'"); 573 } 574 575 /// Parses a ')' if present. 
576 ParseResult parseOptionalRParen() override { 577 return success(parser.consumeIf(Token::r_paren)); 578 } 579 580 /// Parse a `[` token. 581 ParseResult parseLSquare() override { 582 return parser.parseToken(Token::l_square, "expected '['"); 583 } 584 585 /// Parses a '[' if present. 586 ParseResult parseOptionalLSquare() override { 587 return success(parser.consumeIf(Token::l_square)); 588 } 589 590 /// Parse a `]` token. 591 ParseResult parseRSquare() override { 592 return parser.parseToken(Token::r_square, "expected ']'"); 593 } 594 595 /// Parses a ']' if present. 596 ParseResult parseOptionalRSquare() override { 597 return success(parser.consumeIf(Token::r_square)); 598 } 599 600 /// Parses a '?' if present. 601 ParseResult parseOptionalQuestion() override { 602 return success(parser.consumeIf(Token::question)); 603 } 604 605 /// Parses a '*' if present. 606 ParseResult parseOptionalStar() override { 607 return success(parser.consumeIf(Token::star)); 608 } 609 610 /// Returns if the current token corresponds to a keyword. 611 bool isCurrentTokenAKeyword() const { 612 return parser.getToken().is(Token::bare_identifier) || 613 parser.getToken().isKeyword(); 614 } 615 616 /// Parse the given keyword if present. 617 ParseResult parseOptionalKeyword(StringRef keyword) override { 618 // Check that the current token has the same spelling. 619 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 620 return failure(); 621 parser.consumeToken(); 622 return success(); 623 } 624 625 /// Parse a keyword, if present, into 'keyword'. 626 ParseResult parseOptionalKeyword(StringRef *keyword) override { 627 // Check that the current token is a keyword. 
628 if (!isCurrentTokenAKeyword()) 629 return failure(); 630 631 *keyword = parser.getTokenSpelling(); 632 parser.consumeToken(); 633 return success(); 634 } 635 636 //===--------------------------------------------------------------------===// 637 // Attribute Parsing 638 //===--------------------------------------------------------------------===// 639 640 /// Parse an arbitrary attribute and return it in result. 641 ParseResult parseAttribute(Attribute &result, Type type) override { 642 result = parser.parseAttribute(type); 643 return success(static_cast<bool>(result)); 644 } 645 646 //===--------------------------------------------------------------------===// 647 // Type Parsing 648 //===--------------------------------------------------------------------===// 649 650 ParseResult parseType(Type &result) override { 651 result = parser.parseType(); 652 return success(static_cast<bool>(result)); 653 } 654 655 ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 656 bool allowDynamic) override { 657 return parser.parseDimensionListRanked(dimensions, allowDynamic); 658 } 659 660 private: 661 /// The full symbol specification. 662 StringRef fullSpec; 663 664 /// The source location of the dialect symbol. 665 SMLoc nameLoc; 666 667 /// The main parser. 668 Parser &parser; 669 }; 670 } // namespace 671 672 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 673 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 674 /// the entire pretty name. 
675 /// 676 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 677 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 678 /// | '(' pretty-dialect-sym-contents+ ')' 679 /// | '[' pretty-dialect-sym-contents+ ']' 680 /// | '{' pretty-dialect-sym-contents+ '}' 681 /// | '[^[<({>\])}\0]+' 682 /// 683 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 684 // Pretty symbol names are a relatively unstructured format that contains a 685 // series of properly nested punctuation, with anything else in the middle. 686 // Scan ahead to find it and consume it if successful, otherwise emit an 687 // error. 688 auto *curPtr = getTokenSpelling().data(); 689 690 SmallVector<char, 8> nestedPunctuation; 691 692 // Scan over the nested punctuation, bailing out on error and consuming until 693 // we find the end. We know that we're currently looking at the '<', so we 694 // can go until we find the matching '>' character. 695 assert(*curPtr == '<'); 696 do { 697 char c = *curPtr++; 698 switch (c) { 699 case '\0': 700 // This also handles the EOF case. 701 return emitError("unexpected nul or EOF in pretty dialect name"); 702 case '<': 703 case '[': 704 case '(': 705 case '{': 706 nestedPunctuation.push_back(c); 707 continue; 708 709 case '-': 710 // The sequence `->` is treated as special token. 
711 if (*curPtr == '>') 712 ++curPtr; 713 continue; 714 715 case '>': 716 if (nestedPunctuation.pop_back_val() != '<') 717 return emitError("unbalanced '>' character in pretty dialect name"); 718 break; 719 case ']': 720 if (nestedPunctuation.pop_back_val() != '[') 721 return emitError("unbalanced ']' character in pretty dialect name"); 722 break; 723 case ')': 724 if (nestedPunctuation.pop_back_val() != '(') 725 return emitError("unbalanced ')' character in pretty dialect name"); 726 break; 727 case '}': 728 if (nestedPunctuation.pop_back_val() != '{') 729 return emitError("unbalanced '}' character in pretty dialect name"); 730 break; 731 732 default: 733 continue; 734 } 735 } while (!nestedPunctuation.empty()); 736 737 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 738 // consuming all this stuff, and return. 739 state.lex.resetPointer(curPtr); 740 741 unsigned length = curPtr - prettyName.begin(); 742 prettyName = StringRef(prettyName.begin(), length); 743 consumeToken(); 744 return success(); 745 } 746 747 /// Parse an extended dialect symbol. 748 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 749 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 750 SymbolAliasMap &aliases, 751 CreateFn &&createSymbol) { 752 // Parse the dialect namespace. 753 StringRef identifier = p.getTokenSpelling().drop_front(); 754 auto loc = p.getToken().getLoc(); 755 p.consumeToken(identifierTok); 756 757 // If there is no '<' token following this, and if the typename contains no 758 // dot, then we are parsing a symbol alias. 759 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 760 // Check for an alias for this type. 761 auto aliasIt = aliases.find(identifier); 762 if (aliasIt == aliases.end()) 763 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 764 nullptr); 765 return aliasIt->second; 766 } 767 768 // Otherwise, we are parsing a dialect-specific symbol. 
If the name contains 769 // a dot, then this is the "pretty" form. If not, it is the verbose form that 770 // looks like <"...">. 771 std::string symbolData; 772 auto dialectName = identifier; 773 774 // Handle the verbose form, where "identifier" is a simple dialect name. 775 if (!identifier.contains('.')) { 776 // Consume the '<'. 777 if (p.parseToken(Token::less, "expected '<' in dialect type")) 778 return nullptr; 779 780 // Parse the symbol specific data. 781 if (p.getToken().isNot(Token::string)) 782 return (p.emitError("expected string literal data in dialect symbol"), 783 nullptr); 784 symbolData = p.getToken().getStringValue(); 785 loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); 786 p.consumeToken(Token::string); 787 788 // Consume the '>'. 789 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 790 return nullptr; 791 } else { 792 // Ok, the dialect name is the part of the identifier before the dot, the 793 // part after the dot is the dialect's symbol, or the start thereof. 794 auto dotHalves = identifier.split('.'); 795 dialectName = dotHalves.first; 796 auto prettyName = dotHalves.second; 797 loc = llvm::SMLoc::getFromPointer(prettyName.data()); 798 799 // If the dialect's symbol is followed immediately by a <, then lex the body 800 // of it into prettyName. 801 if (p.getToken().is(Token::less) && 802 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 803 if (p.parsePrettyDialectSymbolName(prettyName)) 804 return nullptr; 805 } 806 807 symbolData = prettyName.str(); 808 } 809 810 // Record the name location of the type remapped to the top level buffer. 811 llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); 812 p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); 813 814 // Call into the provided symbol construction function. 815 Symbol sym = createSymbol(dialectName, symbolData, loc); 816 817 // Pop the last parser location. 
818 p.getState().symbols.nestedParserLocs.pop_back(); 819 return sym; 820 } 821 822 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 823 /// parsing failed, nullptr is returned. The number of bytes read from the input 824 /// string is returned in 'numRead'. 825 template <typename T, typename ParserFn> 826 static T parseSymbol(StringRef inputStr, MLIRContext *context, 827 SymbolState &symbolState, ParserFn &&parserFn, 828 size_t *numRead = nullptr) { 829 SourceMgr sourceMgr; 830 auto memBuffer = MemoryBuffer::getMemBuffer( 831 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 832 /*RequiresNullTerminator=*/false); 833 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 834 ParserState state(sourceMgr, context, symbolState); 835 Parser parser(state); 836 837 Token startTok = parser.getToken(); 838 T symbol = parserFn(parser); 839 if (!symbol) 840 return T(); 841 842 // If 'numRead' is valid, then provide the number of bytes that were read. 843 Token endTok = parser.getToken(); 844 if (numRead) { 845 *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 846 startTok.getLoc().getPointer()); 847 848 // Otherwise, ensure that all of the tokens were parsed. 849 } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { 850 parser.emitError(endTok.getLoc(), "encountered unexpected token"); 851 return T(); 852 } 853 return symbol; 854 } 855 856 //===----------------------------------------------------------------------===// 857 // Error Handling 858 //===----------------------------------------------------------------------===// 859 860 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 861 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 862 863 // If we hit a parse error in response to a lexer error, then the lexer 864 // already reported the error. 
865 if (getToken().is(Token::error)) 866 diag.abandon(); 867 return diag; 868 } 869 870 //===----------------------------------------------------------------------===// 871 // Token Parsing 872 //===----------------------------------------------------------------------===// 873 874 /// Consume the specified token if present and return success. On failure, 875 /// output a diagnostic and return failure. 876 ParseResult Parser::parseToken(Token::Kind expectedToken, 877 const Twine &message) { 878 if (consumeIf(expectedToken)) 879 return success(); 880 return emitError(message); 881 } 882 883 //===----------------------------------------------------------------------===// 884 // Type Parsing 885 //===----------------------------------------------------------------------===// 886 887 /// Parse an arbitrary type. 888 /// 889 /// type ::= function-type 890 /// | non-function-type 891 /// 892 Type Parser::parseType() { 893 if (getToken().is(Token::l_paren)) 894 return parseFunctionType(); 895 return parseNonFunctionType(); 896 } 897 898 /// Parse a function result type. 899 /// 900 /// function-result-type ::= type-list-parens 901 /// | non-function-type 902 /// 903 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 904 if (getToken().is(Token::l_paren)) 905 return parseTypeListParens(elements); 906 907 Type t = parseNonFunctionType(); 908 if (!t) 909 return failure(); 910 elements.push_back(t); 911 return success(); 912 } 913 914 /// Parse a list of types without an enclosing parenthesis. The list must have 915 /// at least one member. 916 /// 917 /// type-list-no-parens ::= type (`,` type)* 918 /// 919 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 920 auto parseElt = [&]() -> ParseResult { 921 auto elt = parseType(); 922 elements.push_back(elt); 923 return elt ? success() : failure(); 924 }; 925 926 return parseCommaSeparatedList(parseElt); 927 } 928 929 /// Parse a parenthesized list of types. 
930 /// 931 /// type-list-parens ::= `(` `)` 932 /// | `(` type-list-no-parens `)` 933 /// 934 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 935 if (parseToken(Token::l_paren, "expected '('")) 936 return failure(); 937 938 // Handle empty lists. 939 if (getToken().is(Token::r_paren)) 940 return consumeToken(), success(); 941 942 if (parseTypeListNoParens(elements) || 943 parseToken(Token::r_paren, "expected ')'")) 944 return failure(); 945 return success(); 946 } 947 948 /// Parse a complex type. 949 /// 950 /// complex-type ::= `complex` `<` type `>` 951 /// 952 Type Parser::parseComplexType() { 953 consumeToken(Token::kw_complex); 954 955 // Parse the '<'. 956 if (parseToken(Token::less, "expected '<' in complex type")) 957 return nullptr; 958 959 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 960 auto elementType = parseType(); 961 if (!elementType || 962 parseToken(Token::greater, "expected '>' in complex type")) 963 return nullptr; 964 965 return ComplexType::getChecked(elementType, typeLocation); 966 } 967 968 /// Parse an extended type. 969 /// 970 /// extended-type ::= (dialect-type | type-alias) 971 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 972 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 973 /// type-alias ::= `!` alias-name 974 /// 975 Type Parser::parseExtendedType() { 976 return parseExtendedSymbol<Type>( 977 *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, 978 [&](StringRef dialectName, StringRef symbolData, 979 llvm::SMLoc loc) -> Type { 980 // If we found a registered dialect, then ask it to parse the type. 
981 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 982 return parseSymbol<Type>( 983 symbolData, state.context, state.symbols, [&](Parser &parser) { 984 CustomDialectAsmParser customParser(symbolData, parser); 985 return dialect->parseType(customParser); 986 }); 987 } 988 989 // Otherwise, form a new opaque type. 990 return OpaqueType::getChecked( 991 Identifier::get(dialectName, state.context), symbolData, 992 state.context, getEncodedSourceLocation(loc)); 993 }); 994 } 995 996 /// Parse a function type. 997 /// 998 /// function-type ::= type-list-parens `->` function-result-type 999 /// 1000 Type Parser::parseFunctionType() { 1001 assert(getToken().is(Token::l_paren)); 1002 1003 SmallVector<Type, 4> arguments, results; 1004 if (parseTypeListParens(arguments) || 1005 parseToken(Token::arrow, "expected '->' in function type") || 1006 parseFunctionResultTypes(results)) 1007 return nullptr; 1008 1009 return builder.getFunctionType(arguments, results); 1010 } 1011 1012 /// Parse the offset and strides from a strided layout specification. 1013 /// 1014 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1015 /// 1016 ParseResult Parser::parseStridedLayout(int64_t &offset, 1017 SmallVectorImpl<int64_t> &strides) { 1018 // Parse offset. 1019 consumeToken(Token::kw_offset); 1020 if (!consumeIf(Token::colon)) 1021 return emitError("expected colon after `offset` keyword"); 1022 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1023 bool question = getToken().is(Token::question); 1024 if (!maybeOffset && !question) 1025 return emitError("invalid offset"); 1026 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1027 : MemRefType::getDynamicStrideOrOffset(); 1028 consumeToken(); 1029 1030 if (!consumeIf(Token::comma)) 1031 return emitError("expected comma after offset value"); 1032 1033 // Parse stride list. 
1034 if (!consumeIf(Token::kw_strides)) 1035 return emitError("expected `strides` keyword after offset specification"); 1036 if (!consumeIf(Token::colon)) 1037 return emitError("expected colon after `strides` keyword"); 1038 if (failed(parseStrideList(strides))) 1039 return emitError("invalid braces-enclosed stride list"); 1040 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1041 return emitError("invalid memref stride"); 1042 1043 return success(); 1044 } 1045 1046 /// Parse a memref type. 1047 /// 1048 /// memref-type ::= ranked-memref-type | unranked-memref-type 1049 /// 1050 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1051 /// (`,` semi-affine-map-composition)? (`,` 1052 /// memory-space)? `>` 1053 /// 1054 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1055 /// 1056 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1057 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1058 /// 1059 Type Parser::parseMemRefType() { 1060 consumeToken(Token::kw_memref); 1061 1062 if (parseToken(Token::less, "expected '<' in memref type")) 1063 return nullptr; 1064 1065 bool isUnranked; 1066 SmallVector<int64_t, 4> dimensions; 1067 1068 if (consumeIf(Token::star)) { 1069 // This is an unranked memref type. 1070 isUnranked = true; 1071 if (parseXInDimensionList()) 1072 return nullptr; 1073 1074 } else { 1075 isUnranked = false; 1076 if (parseDimensionListRanked(dimensions)) 1077 return nullptr; 1078 } 1079 1080 // Parse the element type. 1081 auto typeLoc = getToken().getLoc(); 1082 auto elementType = parseType(); 1083 if (!elementType) 1084 return nullptr; 1085 1086 // Parse semi-affine-map-composition. 1087 SmallVector<AffineMap, 2> affineMapComposition; 1088 unsigned memorySpace = 0; 1089 bool parsedMemorySpace = false; 1090 1091 auto parseElt = [&]() -> ParseResult { 1092 if (getToken().is(Token::integer)) { 1093 // Parse memory space. 
1094 if (parsedMemorySpace) 1095 return emitError("multiple memory spaces specified in memref type"); 1096 auto v = getToken().getUnsignedIntegerValue(); 1097 if (!v.hasValue()) 1098 return emitError("invalid memory space in memref type"); 1099 memorySpace = v.getValue(); 1100 consumeToken(Token::integer); 1101 parsedMemorySpace = true; 1102 } else { 1103 if (isUnranked) 1104 return emitError("cannot have affine map for unranked memref type"); 1105 if (parsedMemorySpace) 1106 return emitError("expected memory space to be last in memref type"); 1107 if (getToken().is(Token::kw_offset)) { 1108 int64_t offset; 1109 SmallVector<int64_t, 4> strides; 1110 if (failed(parseStridedLayout(offset, strides))) 1111 return failure(); 1112 // Construct strided affine map. 1113 auto map = makeStridedLinearLayoutMap(strides, offset, 1114 elementType.getContext()); 1115 affineMapComposition.push_back(map); 1116 } else { 1117 // Parse affine map. 1118 auto affineMap = parseAttribute(); 1119 if (!affineMap) 1120 return failure(); 1121 // Verify that the parsed attribute is an affine map. 1122 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 1123 affineMapComposition.push_back(affineMapAttr.getValue()); 1124 else 1125 return emitError("expected affine map in memref type"); 1126 } 1127 } 1128 return success(); 1129 }; 1130 1131 // Parse a list of mappings and address space if present. 1132 if (consumeIf(Token::comma)) { 1133 // Parse comma separated list of affine maps, followed by memory space. 
1134 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 1135 /*allowEmptyList=*/false)) { 1136 return nullptr; 1137 } 1138 } else { 1139 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 1140 return nullptr; 1141 } 1142 1143 if (isUnranked) 1144 return UnrankedMemRefType::getChecked(elementType, memorySpace, 1145 getEncodedSourceLocation(typeLoc)); 1146 1147 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 1148 memorySpace, getEncodedSourceLocation(typeLoc)); 1149 } 1150 1151 /// Parse any type except the function type. 1152 /// 1153 /// non-function-type ::= integer-type 1154 /// | index-type 1155 /// | float-type 1156 /// | extended-type 1157 /// | vector-type 1158 /// | tensor-type 1159 /// | memref-type 1160 /// | complex-type 1161 /// | tuple-type 1162 /// | none-type 1163 /// 1164 /// index-type ::= `index` 1165 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1166 /// none-type ::= `none` 1167 /// 1168 Type Parser::parseNonFunctionType() { 1169 switch (getToken().getKind()) { 1170 default: 1171 return (emitError("expected non-function type"), nullptr); 1172 case Token::kw_memref: 1173 return parseMemRefType(); 1174 case Token::kw_tensor: 1175 return parseTensorType(); 1176 case Token::kw_complex: 1177 return parseComplexType(); 1178 case Token::kw_tuple: 1179 return parseTupleType(); 1180 case Token::kw_vector: 1181 return parseVectorType(); 1182 // integer-type 1183 case Token::inttype: { 1184 auto width = getToken().getIntTypeBitwidth(); 1185 if (!width.hasValue()) 1186 return (emitError("invalid integer width"), nullptr); 1187 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1188 consumeToken(Token::inttype); 1189 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 1190 } 1191 1192 // float-type 1193 case Token::kw_bf16: 1194 consumeToken(Token::kw_bf16); 1195 return builder.getBF16Type(); 1196 case Token::kw_f16: 1197 consumeToken(Token::kw_f16); 1198 return 
builder.getF16Type(); 1199 case Token::kw_f32: 1200 consumeToken(Token::kw_f32); 1201 return builder.getF32Type(); 1202 case Token::kw_f64: 1203 consumeToken(Token::kw_f64); 1204 return builder.getF64Type(); 1205 1206 // index-type 1207 case Token::kw_index: 1208 consumeToken(Token::kw_index); 1209 return builder.getIndexType(); 1210 1211 // none-type 1212 case Token::kw_none: 1213 consumeToken(Token::kw_none); 1214 return builder.getNoneType(); 1215 1216 // extended type 1217 case Token::exclamation_identifier: 1218 return parseExtendedType(); 1219 } 1220 } 1221 1222 /// Parse a tensor type. 1223 /// 1224 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1225 /// dimension-list ::= dimension-list-ranked | `*x` 1226 /// 1227 Type Parser::parseTensorType() { 1228 consumeToken(Token::kw_tensor); 1229 1230 if (parseToken(Token::less, "expected '<' in tensor type")) 1231 return nullptr; 1232 1233 bool isUnranked; 1234 SmallVector<int64_t, 4> dimensions; 1235 1236 if (consumeIf(Token::star)) { 1237 // This is an unranked tensor type. 1238 isUnranked = true; 1239 1240 if (parseXInDimensionList()) 1241 return nullptr; 1242 1243 } else { 1244 isUnranked = false; 1245 if (parseDimensionListRanked(dimensions)) 1246 return nullptr; 1247 } 1248 1249 // Parse the element type. 1250 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 1251 auto elementType = parseType(); 1252 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1253 return nullptr; 1254 1255 if (isUnranked) 1256 return UnrankedTensorType::getChecked(elementType, typeLocation); 1257 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 1258 } 1259 1260 /// Parse a tuple type. 1261 /// 1262 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1263 /// 1264 Type Parser::parseTupleType() { 1265 consumeToken(Token::kw_tuple); 1266 1267 // Parse the '<'. 
1268 if (parseToken(Token::less, "expected '<' in tuple type")) 1269 return nullptr; 1270 1271 // Check for an empty tuple by directly parsing '>'. 1272 if (consumeIf(Token::greater)) 1273 return TupleType::get(getContext()); 1274 1275 // Parse the element types and the '>'. 1276 SmallVector<Type, 4> types; 1277 if (parseTypeListNoParens(types) || 1278 parseToken(Token::greater, "expected '>' in tuple type")) 1279 return nullptr; 1280 1281 return TupleType::get(types, getContext()); 1282 } 1283 1284 /// Parse a vector type. 1285 /// 1286 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1287 /// non-empty-static-dimension-list ::= decimal-literal `x` 1288 /// static-dimension-list 1289 /// static-dimension-list ::= (decimal-literal `x`)* 1290 /// 1291 VectorType Parser::parseVectorType() { 1292 consumeToken(Token::kw_vector); 1293 1294 if (parseToken(Token::less, "expected '<' in vector type")) 1295 return nullptr; 1296 1297 SmallVector<int64_t, 4> dimensions; 1298 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1299 return nullptr; 1300 if (dimensions.empty()) 1301 return (emitError("expected dimension size in vector type"), nullptr); 1302 1303 // Parse the element type. 1304 auto typeLoc = getToken().getLoc(); 1305 auto elementType = parseType(); 1306 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1307 return nullptr; 1308 1309 return VectorType::getChecked(dimensions, elementType, 1310 getEncodedSourceLocation(typeLoc)); 1311 } 1312 1313 /// Parse a dimension list of a tensor or memref type. This populates the 1314 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1315 /// errors out on `?` otherwise. 
1316 /// 1317 /// dimension-list-ranked ::= (dimension `x`)* 1318 /// dimension ::= `?` | decimal-literal 1319 /// 1320 /// When `allowDynamic` is not set, this is used to parse: 1321 /// 1322 /// static-dimension-list ::= (decimal-literal `x`)* 1323 ParseResult 1324 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1325 bool allowDynamic) { 1326 while (getToken().isAny(Token::integer, Token::question)) { 1327 if (consumeIf(Token::question)) { 1328 if (!allowDynamic) 1329 return emitError("expected static shape"); 1330 dimensions.push_back(-1); 1331 } else { 1332 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1333 // aggregate type declarations. Therefore, `0xf32` should be processed as 1334 // a sequence of separate elements `0`, `x`, `f32`. 1335 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1336 // We can get here only if the token is an integer literal. Hexadecimal 1337 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1338 // literal, just `1` would, at which point we don't get into this 1339 // branch). 1340 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1341 dimensions.push_back(0); 1342 state.lex.resetPointer(getTokenSpelling().data() + 1); 1343 consumeToken(); 1344 } else { 1345 // Make sure this integer value is in bound and valid. 1346 auto dimension = getToken().getUnsignedIntegerValue(); 1347 if (!dimension.hasValue()) 1348 return emitError("invalid dimension"); 1349 dimensions.push_back((int64_t)dimension.getValue()); 1350 consumeToken(Token::integer); 1351 } 1352 } 1353 1354 // Make sure we have an 'x' or something like 'xbf32'. 1355 if (parseXInDimensionList()) 1356 return failure(); 1357 } 1358 1359 return success(); 1360 } 1361 1362 /// Parse an 'x' token in a dimension list, handling the case where the x is 1363 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1364 /// token. 
1365 ParseResult Parser::parseXInDimensionList() { 1366 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1367 return emitError("expected 'x' in dimension list"); 1368 1369 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1370 if (getTokenSpelling().size() != 1) 1371 state.lex.resetPointer(getTokenSpelling().data() + 1); 1372 1373 // Consume the 'x'. 1374 consumeToken(Token::bare_identifier); 1375 1376 return success(); 1377 } 1378 1379 // Parse a comma-separated list of dimensions, possibly empty: 1380 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1381 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1382 if (!consumeIf(Token::l_square)) 1383 return failure(); 1384 // Empty list early exit. 1385 if (consumeIf(Token::r_square)) 1386 return success(); 1387 while (true) { 1388 if (consumeIf(Token::question)) { 1389 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1390 } else { 1391 // This must be an integer value. 1392 int64_t val; 1393 if (getToken().getSpelling().getAsInteger(10, val)) 1394 return emitError("invalid integer value: ") << getToken().getSpelling(); 1395 // Make sure it is not the one value for `?`. 1396 if (ShapedType::isDynamic(val)) 1397 return emitError("invalid integer value: ") 1398 << getToken().getSpelling() 1399 << ", use `?` to specify a dynamic dimension"; 1400 dimensions.push_back(val); 1401 consumeToken(Token::integer); 1402 } 1403 if (!consumeIf(Token::comma)) 1404 break; 1405 } 1406 if (!consumeIf(Token::r_square)) 1407 return failure(); 1408 return success(); 1409 } 1410 1411 //===----------------------------------------------------------------------===// 1412 // Attribute parsing. 1413 //===----------------------------------------------------------------------===// 1414 1415 /// Return the symbol reference referred to by the given token, that is known to 1416 /// be an @-identifier. 
1417 static std::string extractSymbolReference(Token tok) { 1418 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1419 StringRef nameStr = tok.getSpelling().drop_front(); 1420 1421 // Check to see if the reference is a string literal, or a bare identifier. 1422 if (nameStr.front() == '"') 1423 return tok.getStringValue(); 1424 return nameStr; 1425 } 1426 1427 /// Parse an arbitrary attribute. 1428 /// 1429 /// attribute-value ::= `unit` 1430 /// | bool-literal 1431 /// | integer-literal (`:` (index-type | integer-type))? 1432 /// | float-literal (`:` float-type)? 1433 /// | string-literal (`:` type)? 1434 /// | type 1435 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1436 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1437 /// | symbol-ref-id (`::` symbol-ref-id)* 1438 /// | `dense` `<` attribute-value `>` `:` 1439 /// (tensor-type | vector-type) 1440 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1441 /// `:` (tensor-type | vector-type) 1442 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1443 /// `>` `:` (tensor-type | vector-type) 1444 /// | extended-attribute 1445 /// 1446 Attribute Parser::parseAttribute(Type type) { 1447 switch (getToken().getKind()) { 1448 // Parse an AffineMap or IntegerSet attribute. 1449 case Token::l_paren: { 1450 // Try to parse an affine map or an integer set reference. 1451 AffineMap map; 1452 IntegerSet set; 1453 if (parseAffineMapOrIntegerSetReference(map, set)) 1454 return nullptr; 1455 if (map) 1456 return AffineMapAttr::get(map); 1457 assert(set); 1458 return IntegerSetAttr::get(set); 1459 } 1460 1461 // Parse an array attribute. 1462 case Token::l_square: { 1463 consumeToken(Token::l_square); 1464 1465 SmallVector<Attribute, 4> elements; 1466 auto parseElt = [&]() -> ParseResult { 1467 elements.push_back(parseAttribute()); 1468 return elements.back() ? 
success() : failure(); 1469 }; 1470 1471 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1472 return nullptr; 1473 return builder.getArrayAttr(elements); 1474 } 1475 1476 // Parse a boolean attribute. 1477 case Token::kw_false: 1478 consumeToken(Token::kw_false); 1479 return builder.getBoolAttr(false); 1480 case Token::kw_true: 1481 consumeToken(Token::kw_true); 1482 return builder.getBoolAttr(true); 1483 1484 // Parse a dense elements attribute. 1485 case Token::kw_dense: 1486 return parseDenseElementsAttr(); 1487 1488 // Parse a dictionary attribute. 1489 case Token::l_brace: { 1490 SmallVector<NamedAttribute, 4> elements; 1491 if (parseAttributeDict(elements)) 1492 return nullptr; 1493 return builder.getDictionaryAttr(elements); 1494 } 1495 1496 // Parse an extended attribute, i.e. alias or dialect attribute. 1497 case Token::hash_identifier: 1498 return parseExtendedAttr(type); 1499 1500 // Parse floating point and integer attributes. 1501 case Token::floatliteral: 1502 return parseFloatAttr(type, /*isNegative=*/false); 1503 case Token::integer: 1504 return parseDecOrHexAttr(type, /*isNegative=*/false); 1505 case Token::minus: { 1506 consumeToken(Token::minus); 1507 if (getToken().is(Token::integer)) 1508 return parseDecOrHexAttr(type, /*isNegative=*/true); 1509 if (getToken().is(Token::floatliteral)) 1510 return parseFloatAttr(type, /*isNegative=*/true); 1511 1512 return (emitError("expected constant integer or floating point value"), 1513 nullptr); 1514 } 1515 1516 // Parse a location attribute. 1517 case Token::kw_loc: { 1518 LocationAttr attr; 1519 return failed(parseLocation(attr)) ? Attribute() : attr; 1520 } 1521 1522 // Parse an opaque elements attribute. 1523 case Token::kw_opaque: 1524 return parseOpaqueElementsAttr(); 1525 1526 // Parse a sparse elements attribute. 1527 case Token::kw_sparse: 1528 return parseSparseElementsAttr(); 1529 1530 // Parse a string attribute. 
1531 case Token::string: { 1532 auto val = getToken().getStringValue(); 1533 consumeToken(Token::string); 1534 // Parse the optional trailing colon type if one wasn't explicitly provided. 1535 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1536 return Attribute(); 1537 1538 return type ? StringAttr::get(val, type) 1539 : StringAttr::get(val, getContext()); 1540 } 1541 1542 // Parse a symbol reference attribute. 1543 case Token::at_identifier: { 1544 std::string nameStr = extractSymbolReference(getToken()); 1545 consumeToken(Token::at_identifier); 1546 1547 // Parse any nested references. 1548 std::vector<FlatSymbolRefAttr> nestedRefs; 1549 while (getToken().is(Token::colon)) { 1550 // Check for the '::' prefix. 1551 const char *curPointer = getToken().getLoc().getPointer(); 1552 consumeToken(Token::colon); 1553 if (!consumeIf(Token::colon)) { 1554 state.lex.resetPointer(curPointer); 1555 consumeToken(); 1556 break; 1557 } 1558 // Parse the reference itself. 1559 auto curLoc = getToken().getLoc(); 1560 if (getToken().isNot(Token::at_identifier)) { 1561 emitError(curLoc, "expected nested symbol reference identifier"); 1562 return Attribute(); 1563 } 1564 1565 std::string nameStr = extractSymbolReference(getToken()); 1566 consumeToken(Token::at_identifier); 1567 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1568 } 1569 1570 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1571 } 1572 1573 // Parse a 'unit' attribute. 1574 case Token::kw_unit: 1575 consumeToken(Token::kw_unit); 1576 return builder.getUnitAttr(); 1577 1578 default: 1579 // Parse a type attribute. 1580 if (Type type = parseType()) 1581 return TypeAttr::get(type); 1582 return nullptr; 1583 } 1584 } 1585 1586 /// Attribute dictionary. 
1587 /// 1588 /// attribute-dict ::= `{` `}` 1589 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1590 /// attribute-entry ::= bare-id `=` attribute-value 1591 /// 1592 ParseResult 1593 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1594 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1595 return failure(); 1596 1597 auto parseElt = [&]() -> ParseResult { 1598 // We allow keywords as attribute names. 1599 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1600 !getToken().isKeyword()) 1601 return emitError("expected attribute name"); 1602 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1603 consumeToken(); 1604 1605 // Try to parse the '=' for the attribute value. 1606 if (!consumeIf(Token::equal)) { 1607 // If there is no '=', we treat this as a unit attribute. 1608 attributes.push_back({nameId, builder.getUnitAttr()}); 1609 return success(); 1610 } 1611 1612 auto attr = parseAttribute(); 1613 if (!attr) 1614 return failure(); 1615 1616 attributes.push_back({nameId, attr}); 1617 return success(); 1618 }; 1619 1620 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1621 return failure(); 1622 1623 return success(); 1624 } 1625 1626 /// Parse an extended attribute. 1627 /// 1628 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1629 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1630 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1631 /// attribute-alias ::= `#` alias-name 1632 /// 1633 Attribute Parser::parseExtendedAttr(Type type) { 1634 Attribute attr = parseExtendedSymbol<Attribute>( 1635 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1636 [&](StringRef dialectName, StringRef symbolData, 1637 llvm::SMLoc loc) -> Attribute { 1638 // Parse an optional trailing colon type. 
1639 Type attrType = type; 1640 if (consumeIf(Token::colon) && !(attrType = parseType())) 1641 return Attribute(); 1642 1643 // If we found a registered dialect, then ask it to parse the attribute. 1644 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1645 return parseSymbol<Attribute>( 1646 symbolData, state.context, state.symbols, [&](Parser &parser) { 1647 CustomDialectAsmParser customParser(symbolData, parser); 1648 return dialect->parseAttribute(customParser, attrType); 1649 }); 1650 } 1651 1652 // Otherwise, form a new opaque attribute. 1653 return OpaqueAttr::getChecked( 1654 Identifier::get(dialectName, state.context), symbolData, 1655 attrType ? attrType : NoneType::get(state.context), 1656 getEncodedSourceLocation(loc)); 1657 }); 1658 1659 // Ensure that the attribute has the same type as requested. 1660 if (attr && type && attr.getType() != type) { 1661 emitError("attribute type different than expected: expected ") 1662 << type << ", but got " << attr.getType(); 1663 return nullptr; 1664 } 1665 return attr; 1666 } 1667 1668 /// Parse a float attribute. 1669 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1670 auto val = getToken().getFloatingPointValue(); 1671 if (!val.hasValue()) 1672 return (emitError("floating point value too large for attribute"), nullptr); 1673 consumeToken(Token::floatliteral); 1674 if (!type) { 1675 // Default to F64 when no type is specified. 1676 if (!consumeIf(Token::colon)) 1677 type = builder.getF64Type(); 1678 else if (!(type = parseType())) 1679 return nullptr; 1680 } 1681 if (!type.isa<FloatType>()) 1682 return (emitError("floating point value not valid for specified type"), 1683 nullptr); 1684 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1685 } 1686 1687 /// Construct a float attribute bitwise equivalent to the integer literal. 
1688 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1689 uint64_t value) { 1690 int width = type.getIntOrFloatBitWidth(); 1691 APInt apInt(width, value); 1692 if (apInt != value) { 1693 p->emitError("hexadecimal float constant out of range for type"); 1694 return nullptr; 1695 } 1696 APFloat apFloat(type.getFloatSemantics(), apInt); 1697 return p->builder.getFloatAttr(type, apFloat); 1698 } 1699 1700 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1701 /// or a float attribute. 1702 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1703 auto val = getToken().getUInt64IntegerValue(); 1704 if (!val.hasValue()) 1705 return (emitError("integer constant out of range for attribute"), nullptr); 1706 1707 // Remember if the literal is hexadecimal. 1708 StringRef spelling = getToken().getSpelling(); 1709 auto loc = state.curToken.getLoc(); 1710 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1711 1712 consumeToken(Token::integer); 1713 if (!type) { 1714 // Default to i64 if not type is specified. 1715 if (!consumeIf(Token::colon)) 1716 type = builder.getIntegerType(64); 1717 else if (!(type = parseType())) 1718 return nullptr; 1719 } 1720 1721 if (auto floatType = type.dyn_cast<FloatType>()) { 1722 // TODO(zinenko): Update once hex format for bfloat16 is supported. 1723 if (type.isBF16()) 1724 return emitError(loc, 1725 "hexadecimal float literal not supported for bfloat16"), 1726 nullptr; 1727 if (isNegative) 1728 return emitError( 1729 loc, 1730 "hexadecimal float literal should not have a leading minus"), 1731 nullptr; 1732 if (!isHex) { 1733 emitError(loc, "unexpected decimal integer literal for a float attribute") 1734 .attachNote() 1735 << "add a trailing dot to make the literal a float"; 1736 return nullptr; 1737 } 1738 1739 // Construct a float attribute bitwise equivalent to the integer literal. 
1740 return buildHexadecimalFloatLiteral(this, floatType, *val); 1741 } 1742 1743 if (!type.isIntOrIndex()) 1744 return emitError(loc, "integer literal not valid for specified type"), 1745 nullptr; 1746 1747 // Parse the integer literal. 1748 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1749 APInt apInt(width, *val, isNegative); 1750 if (apInt != *val) 1751 return emitError(loc, "integer constant out of range for attribute"), 1752 nullptr; 1753 1754 // Otherwise construct an integer attribute. 1755 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1756 return emitError(loc, "integer constant out of range for attribute"), 1757 nullptr; 1758 1759 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1760 } 1761 1762 /// Parse an opaque elements attribute. 1763 Attribute Parser::parseOpaqueElementsAttr() { 1764 consumeToken(Token::kw_opaque); 1765 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1766 return nullptr; 1767 1768 if (getToken().isNot(Token::string)) 1769 return (emitError("expected dialect namespace"), nullptr); 1770 1771 auto name = getToken().getStringValue(); 1772 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1773 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1774 // attribute. Otherwise, it can't be roundtripped without having the dialect 1775 // registered. 
1776 if (!dialect) 1777 return (emitError("no registered dialect with namespace '" + name + "'"), 1778 nullptr); 1779 1780 consumeToken(Token::string); 1781 if (parseToken(Token::comma, "expected ','")) 1782 return nullptr; 1783 1784 if (getToken().getKind() != Token::string) 1785 return (emitError("opaque string should start with '0x'"), nullptr); 1786 1787 auto val = getToken().getStringValue(); 1788 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1789 return (emitError("opaque string should start with '0x'"), nullptr); 1790 1791 val = val.substr(2); 1792 if (!llvm::all_of(val, llvm::isHexDigit)) 1793 return (emitError("opaque string only contains hex digits"), nullptr); 1794 1795 consumeToken(Token::string); 1796 if (parseToken(Token::greater, "expected '>'") || 1797 parseToken(Token::colon, "expected ':'")) 1798 return nullptr; 1799 1800 auto type = parseElementsLiteralType(); 1801 if (!type) 1802 return nullptr; 1803 1804 return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); 1805 } 1806 1807 namespace { 1808 class TensorLiteralParser { 1809 public: 1810 TensorLiteralParser(Parser &p) : p(p) {} 1811 1812 ParseResult parse() { 1813 if (p.getToken().is(Token::l_square)) 1814 return parseList(shape); 1815 return parseElement(); 1816 } 1817 1818 /// Build a dense attribute instance with the parsed elements and the given 1819 /// shaped type. 1820 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1821 1822 ArrayRef<int64_t> getShape() const { return shape; } 1823 1824 private: 1825 enum class ElementKind { Boolean, Integer, Float }; 1826 1827 /// Return a string to represent the given element kind. 
1828 const char *getElementKindStr(ElementKind kind) { 1829 switch (kind) { 1830 case ElementKind::Boolean: 1831 return "'boolean'"; 1832 case ElementKind::Integer: 1833 return "'integer'"; 1834 case ElementKind::Float: 1835 return "'float'"; 1836 } 1837 llvm_unreachable("unknown element kind"); 1838 } 1839 1840 /// Build a Dense Integer attribute for the given type. 1841 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1842 IntegerType eltTy); 1843 1844 /// Build a Dense Float attribute for the given type. 1845 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1846 FloatType eltTy); 1847 1848 /// Parse a single element, returning failure if it isn't a valid element 1849 /// literal. For example: 1850 /// parseElement(1) -> Success, 1 1851 /// parseElement([1]) -> Failure 1852 ParseResult parseElement(); 1853 1854 /// Parse a list of either lists or elements, returning the dimensions of the 1855 /// parsed sub-tensors in dims. For example: 1856 /// parseList([1, 2, 3]) -> Success, [3] 1857 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1858 /// parseList([[1, 2], 3]) -> Failure 1859 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1860 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1861 1862 Parser &p; 1863 1864 /// The shape inferred from the parsed elements. 1865 SmallVector<int64_t, 4> shape; 1866 1867 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1868 std::vector<std::pair<bool, Token>> storage; 1869 1870 /// A flag that indicates the type of elements that have been parsed. 1871 Optional<ElementKind> knownEltKind; 1872 }; 1873 } // namespace 1874 1875 /// Build a dense attribute instance with the parsed elements and the given 1876 /// shaped type. 1877 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1878 ShapedType type) { 1879 // Check that the parsed storage size has the same number of elements to the 1880 // type, or is a known splat. 
1881 if (!shape.empty() && getShape() != type.getShape()) { 1882 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1883 << "]) does not match type ([" << type.getShape() << "])"; 1884 return nullptr; 1885 } 1886 1887 // If the type is an integer, build a set of APInt values from the storage 1888 // with the correct bitwidth. 1889 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1890 return getIntAttr(loc, type, intTy); 1891 1892 // Otherwise, this must be a floating point type. 1893 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1894 if (!floatTy) { 1895 p.emitError(loc) << "expected floating-point or integer element type, got " 1896 << type.getElementType(); 1897 return nullptr; 1898 } 1899 return getFloatAttr(loc, type, floatTy); 1900 } 1901 1902 /// Build a Dense Integer attribute for the given type. 1903 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1904 ShapedType type, 1905 IntegerType eltTy) { 1906 std::vector<APInt> intElements; 1907 intElements.reserve(storage.size()); 1908 for (const auto &signAndToken : storage) { 1909 bool isNegative = signAndToken.first; 1910 const Token &token = signAndToken.second; 1911 1912 // Check to see if floating point values were parsed. 1913 if (token.is(Token::floatliteral)) { 1914 p.emitError() << "expected integer elements, but parsed floating-point"; 1915 return nullptr; 1916 } 1917 1918 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1919 "unexpected token type"); 1920 if (token.isAny(Token::kw_true, Token::kw_false)) { 1921 if (!eltTy.isInteger(1)) 1922 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1923 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1924 /*isSigned=*/false); 1925 intElements.push_back(apInt); 1926 continue; 1927 } 1928 1929 // Create APInt values for each element with the correct bitwidth. 1930 auto val = token.getUInt64IntegerValue(); 1931 if (!val.hasValue() || (isNegative ? 
(int64_t)-val.getValue() >= 0 1932 : (int64_t)val.getValue() < 0)) { 1933 p.emitError(token.getLoc(), 1934 "integer constant out of range for attribute"); 1935 return nullptr; 1936 } 1937 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1938 if (apInt != val.getValue()) 1939 return (p.emitError("integer constant out of range for type"), nullptr); 1940 intElements.push_back(isNegative ? -apInt : apInt); 1941 } 1942 1943 return DenseElementsAttr::get(type, intElements); 1944 } 1945 1946 /// Build a Dense Float attribute for the given type. 1947 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1948 ShapedType type, 1949 FloatType eltTy) { 1950 std::vector<Attribute> floatValues; 1951 floatValues.reserve(storage.size()); 1952 for (const auto &signAndToken : storage) { 1953 bool isNegative = signAndToken.first; 1954 const Token &token = signAndToken.second; 1955 1956 // Handle hexadecimal float literals. 1957 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1958 if (isNegative) { 1959 p.emitError(token.getLoc()) 1960 << "hexadecimal float literal should not have a leading minus"; 1961 return nullptr; 1962 } 1963 auto val = token.getUInt64IntegerValue(); 1964 if (!val.hasValue()) { 1965 p.emitError("hexadecimal float constant out of range for attribute"); 1966 return nullptr; 1967 } 1968 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1969 if (!attr) 1970 return nullptr; 1971 floatValues.push_back(attr); 1972 continue; 1973 } 1974 1975 // Check to see if any decimal integers or booleans were parsed. 1976 if (!token.is(Token::floatliteral)) { 1977 p.emitError() << "expected floating-point elements, but parsed integer"; 1978 return nullptr; 1979 } 1980 1981 // Build the float values from tokens. 
1982 auto val = token.getFloatingPointValue(); 1983 if (!val.hasValue()) { 1984 p.emitError("floating point value too large for attribute"); 1985 return nullptr; 1986 } 1987 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 1988 } 1989 1990 return DenseElementsAttr::get(type, floatValues); 1991 } 1992 1993 ParseResult TensorLiteralParser::parseElement() { 1994 switch (p.getToken().getKind()) { 1995 // Parse a boolean element. 1996 case Token::kw_true: 1997 case Token::kw_false: 1998 case Token::floatliteral: 1999 case Token::integer: 2000 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2001 p.consumeToken(); 2002 break; 2003 2004 // Parse a signed integer or a negative floating-point element. 2005 case Token::minus: 2006 p.consumeToken(Token::minus); 2007 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2008 return p.emitError("expected integer or floating point literal"); 2009 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2010 p.consumeToken(); 2011 break; 2012 2013 default: 2014 return p.emitError("expected element literal of primitive type"); 2015 } 2016 2017 return success(); 2018 } 2019 2020 /// Parse a list of either lists or elements, returning the dimensions of the 2021 /// parsed sub-tensors in dims. 
For example: 2022 /// parseList([1, 2, 3]) -> Success, [3] 2023 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2024 /// parseList([[1, 2], 3]) -> Failure 2025 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2026 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2027 p.consumeToken(Token::l_square); 2028 2029 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2030 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2031 if (prevDims == newDims) 2032 return success(); 2033 return p.emitError("tensor literal is invalid; ranks are not consistent " 2034 "between elements"); 2035 }; 2036 2037 bool first = true; 2038 SmallVector<int64_t, 4> newDims; 2039 unsigned size = 0; 2040 auto parseCommaSeparatedList = [&]() -> ParseResult { 2041 SmallVector<int64_t, 4> thisDims; 2042 if (p.getToken().getKind() == Token::l_square) { 2043 if (parseList(thisDims)) 2044 return failure(); 2045 } else if (parseElement()) { 2046 return failure(); 2047 } 2048 ++size; 2049 if (!first) 2050 return checkDims(newDims, thisDims); 2051 newDims = thisDims; 2052 first = false; 2053 return success(); 2054 }; 2055 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2056 return failure(); 2057 2058 // Return the sublists' dimensions with 'size' prepended. 2059 dims.clear(); 2060 dims.push_back(size); 2061 dims.append(newDims.begin(), newDims.end()); 2062 return success(); 2063 } 2064 2065 /// Parse a dense elements attribute. 2066 Attribute Parser::parseDenseElementsAttr() { 2067 consumeToken(Token::kw_dense); 2068 if (parseToken(Token::less, "expected '<' after 'dense'")) 2069 return nullptr; 2070 2071 // Parse the literal data. 
2072 TensorLiteralParser literalParser(*this); 2073 if (literalParser.parse()) 2074 return nullptr; 2075 2076 if (parseToken(Token::greater, "expected '>'") || 2077 parseToken(Token::colon, "expected ':'")) 2078 return nullptr; 2079 2080 auto typeLoc = getToken().getLoc(); 2081 auto type = parseElementsLiteralType(); 2082 if (!type) 2083 return nullptr; 2084 return literalParser.getAttr(typeLoc, type); 2085 } 2086 2087 /// Shaped type for elements attribute. 2088 /// 2089 /// elements-literal-type ::= vector-type | ranked-tensor-type 2090 /// 2091 /// This method also checks the type has static shape. 2092 ShapedType Parser::parseElementsLiteralType() { 2093 auto type = parseType(); 2094 if (!type) 2095 return nullptr; 2096 2097 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2098 emitError("elements literal must be a ranked tensor or vector type"); 2099 return nullptr; 2100 } 2101 2102 auto sType = type.cast<ShapedType>(); 2103 if (!sType.hasStaticShape()) 2104 return (emitError("elements literal type must have static shape"), nullptr); 2105 2106 return sType; 2107 } 2108 2109 /// Parse a sparse elements attribute. 2110 Attribute Parser::parseSparseElementsAttr() { 2111 consumeToken(Token::kw_sparse); 2112 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2113 return nullptr; 2114 2115 /// Parse indices 2116 auto indicesLoc = getToken().getLoc(); 2117 TensorLiteralParser indiceParser(*this); 2118 if (indiceParser.parse()) 2119 return nullptr; 2120 2121 if (parseToken(Token::comma, "expected ','")) 2122 return nullptr; 2123 2124 /// Parse values. 
2125 auto valuesLoc = getToken().getLoc(); 2126 TensorLiteralParser valuesParser(*this); 2127 if (valuesParser.parse()) 2128 return nullptr; 2129 2130 if (parseToken(Token::greater, "expected '>'") || 2131 parseToken(Token::colon, "expected ':'")) 2132 return nullptr; 2133 2134 auto type = parseElementsLiteralType(); 2135 if (!type) 2136 return nullptr; 2137 2138 // If the indices are a splat, i.e. the literal parser parsed an element and 2139 // not a list, we set the shape explicitly. The indices are represented by a 2140 // 2-dimensional shape where the second dimension is the rank of the type. 2141 // Given that the parsed indices is a splat, we know that we only have one 2142 // indice and thus one for the first dimension. 2143 auto indiceEltType = builder.getIntegerType(64); 2144 ShapedType indicesType; 2145 if (indiceParser.getShape().empty()) { 2146 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2147 } else { 2148 // Otherwise, set the shape to the one parsed by the literal parser. 2149 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2150 } 2151 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2152 2153 // If the values are a splat, set the shape explicitly based on the number of 2154 // indices. The number of indices is encoded in the first dimension of the 2155 // indice shape type. 2156 auto valuesEltType = type.getElementType(); 2157 ShapedType valuesType = 2158 valuesParser.getShape().empty() 2159 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2160 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2161 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2162 2163 /// Sanity check. 
2164 if (valuesType.getRank() != 1) 2165 return (emitError("expected 1-d tensor for values"), nullptr); 2166 2167 auto sameShape = (indicesType.getRank() == 1) || 2168 (type.getRank() == indicesType.getDimSize(1)); 2169 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2170 if (!sameShape || !sameElementNum) { 2171 emitError() << "expected shape ([" << type.getShape() 2172 << "]); inferred shape of indices literal ([" 2173 << indicesType.getShape() 2174 << "]); inferred shape of values literal ([" 2175 << valuesType.getShape() << "])"; 2176 return nullptr; 2177 } 2178 2179 // Build the sparse elements attribute by the indices and values. 2180 return SparseElementsAttr::get(type, indices, values); 2181 } 2182 2183 //===----------------------------------------------------------------------===// 2184 // Location parsing. 2185 //===----------------------------------------------------------------------===// 2186 2187 /// Parse a location. 2188 /// 2189 /// location ::= `loc` inline-location 2190 /// inline-location ::= '(' location-inst ')' 2191 /// 2192 ParseResult Parser::parseLocation(LocationAttr &loc) { 2193 // Check for 'loc' identifier. 2194 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2195 return emitError(); 2196 2197 // Parse the inline-location. 2198 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2199 parseLocationInstance(loc) || 2200 parseToken(Token::r_paren, "expected ')' in inline location")) 2201 return failure(); 2202 return success(); 2203 } 2204 2205 /// Specific location instances. 
2206 /// 2207 /// location-inst ::= filelinecol-location | 2208 /// name-location | 2209 /// callsite-location | 2210 /// fused-location | 2211 /// unknown-location 2212 /// filelinecol-location ::= string-literal ':' integer-literal 2213 /// ':' integer-literal 2214 /// name-location ::= string-literal 2215 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2216 /// fused-location ::= fused ('<' attribute-value '>')? 2217 /// '[' location-inst (location-inst ',')* ']' 2218 /// unknown-location ::= 'unknown' 2219 /// 2220 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2221 consumeToken(Token::bare_identifier); 2222 2223 // Parse the '('. 2224 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2225 return failure(); 2226 2227 // Parse the callee location. 2228 LocationAttr calleeLoc; 2229 if (parseLocationInstance(calleeLoc)) 2230 return failure(); 2231 2232 // Parse the 'at'. 2233 if (getToken().isNot(Token::bare_identifier) || 2234 getToken().getSpelling() != "at") 2235 return emitError("expected 'at' in callsite location"); 2236 consumeToken(Token::bare_identifier); 2237 2238 // Parse the caller location. 2239 LocationAttr callerLoc; 2240 if (parseLocationInstance(callerLoc)) 2241 return failure(); 2242 2243 // Parse the ')'. 2244 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2245 return failure(); 2246 2247 // Return the callsite location. 2248 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2249 return success(); 2250 } 2251 2252 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2253 consumeToken(Token::bare_identifier); 2254 2255 // Try to parse the optional metadata. 2256 Attribute metadata; 2257 if (consumeIf(Token::less)) { 2258 metadata = parseAttribute(); 2259 if (!metadata) 2260 return emitError("expected valid attribute metadata"); 2261 // Parse the '>' token. 
2262 if (parseToken(Token::greater, 2263 "expected '>' after fused location metadata")) 2264 return failure(); 2265 } 2266 2267 SmallVector<Location, 4> locations; 2268 auto parseElt = [&] { 2269 LocationAttr newLoc; 2270 if (parseLocationInstance(newLoc)) 2271 return failure(); 2272 locations.push_back(newLoc); 2273 return success(); 2274 }; 2275 2276 if (parseToken(Token::l_square, "expected '[' in fused location") || 2277 parseCommaSeparatedList(parseElt) || 2278 parseToken(Token::r_square, "expected ']' in fused location")) 2279 return failure(); 2280 2281 // Return the fused location. 2282 loc = FusedLoc::get(locations, metadata, getContext()); 2283 return success(); 2284 } 2285 2286 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2287 auto *ctx = getContext(); 2288 auto str = getToken().getStringValue(); 2289 consumeToken(Token::string); 2290 2291 // If the next token is ':' this is a filelinecol location. 2292 if (consumeIf(Token::colon)) { 2293 // Parse the line number. 2294 if (getToken().isNot(Token::integer)) 2295 return emitError("expected integer line number in FileLineColLoc"); 2296 auto line = getToken().getUnsignedIntegerValue(); 2297 if (!line.hasValue()) 2298 return emitError("expected integer line number in FileLineColLoc"); 2299 consumeToken(Token::integer); 2300 2301 // Parse the ':'. 2302 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2303 return failure(); 2304 2305 // Parse the column number. 2306 if (getToken().isNot(Token::integer)) 2307 return emitError("expected integer column number in FileLineColLoc"); 2308 auto column = getToken().getUnsignedIntegerValue(); 2309 if (!column.hasValue()) 2310 return emitError("expected integer column number in FileLineColLoc"); 2311 consumeToken(Token::integer); 2312 2313 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2314 return success(); 2315 } 2316 2317 // Otherwise, this is a NameLoc. 2318 2319 // Check for a child location. 
2320 if (consumeIf(Token::l_paren)) { 2321 auto childSourceLoc = getToken().getLoc(); 2322 2323 // Parse the child location. 2324 LocationAttr childLoc; 2325 if (parseLocationInstance(childLoc)) 2326 return failure(); 2327 2328 // The child must not be another NameLoc. 2329 if (childLoc.isa<NameLoc>()) 2330 return emitError(childSourceLoc, 2331 "child of NameLoc cannot be another NameLoc"); 2332 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2333 2334 // Parse the closing ')'. 2335 if (parseToken(Token::r_paren, 2336 "expected ')' after child location of NameLoc")) 2337 return failure(); 2338 } else { 2339 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2340 } 2341 2342 return success(); 2343 } 2344 2345 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2346 // Handle either name or filelinecol locations. 2347 if (getToken().is(Token::string)) 2348 return parseNameOrFileLineColLocation(loc); 2349 2350 // Bare tokens required for other cases. 2351 if (!getToken().is(Token::bare_identifier)) 2352 return emitError("expected location instance"); 2353 2354 // Check for the 'callsite' signifying a callsite location. 2355 if (getToken().getSpelling() == "callsite") 2356 return parseCallSiteLocation(loc); 2357 2358 // If the token is 'fused', then this is a fused location. 2359 if (getToken().getSpelling() == "fused") 2360 return parseFusedLocation(loc); 2361 2362 // Check for a 'unknown' for an unknown location. 2363 if (getToken().getSpelling() == "unknown") { 2364 consumeToken(Token::bare_identifier); 2365 loc = UnknownLoc::get(getContext()); 2366 return success(); 2367 } 2368 2369 return emitError("expected location instance"); 2370 } 2371 2372 //===----------------------------------------------------------------------===// 2373 // Affine parsing. 2374 //===----------------------------------------------------------------------===// 2375 2376 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 2377 /// the boolean sense. 2378 enum AffineLowPrecOp { 2379 /// Null value. 2380 LNoOp, 2381 Add, 2382 Sub 2383 }; 2384 2385 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2386 /// in the boolean sense. 2387 enum AffineHighPrecOp { 2388 /// Null value. 2389 HNoOp, 2390 Mul, 2391 FloorDiv, 2392 CeilDiv, 2393 Mod 2394 }; 2395 2396 namespace { 2397 /// This is a specialized parser for affine structures (affine maps, affine 2398 /// expressions, and integer sets), maintaining the state transient to their 2399 /// bodies. 2400 class AffineParser : public Parser { 2401 public: 2402 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2403 function_ref<ParseResult(bool)> parseElement = nullptr) 2404 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2405 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2406 2407 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2408 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2409 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2410 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 2411 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2412 unsigned &numDims); 2413 2414 private: 2415 // Binary affine op parsing. 2416 AffineLowPrecOp consumeIfLowPrecOp(); 2417 AffineHighPrecOp consumeIfHighPrecOp(); 2418 2419 // Identifier lists for polyhedral structures. 
2420 ParseResult parseDimIdList(unsigned &numDims); 2421 ParseResult parseSymbolIdList(unsigned &numSymbols); 2422 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2423 unsigned &numSymbols); 2424 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2425 2426 AffineExpr parseAffineExpr(); 2427 AffineExpr parseParentheticalExpr(); 2428 AffineExpr parseNegateExpression(AffineExpr lhs); 2429 AffineExpr parseIntegerExpr(); 2430 AffineExpr parseBareIdExpr(); 2431 AffineExpr parseSSAIdExpr(bool isSymbol); 2432 AffineExpr parseSymbolSSAIdExpr(); 2433 2434 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2435 AffineExpr rhs, SMLoc opLoc); 2436 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2437 AffineExpr rhs); 2438 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2439 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2440 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2441 SMLoc llhsOpLoc); 2442 AffineExpr parseAffineConstraint(bool *isEq); 2443 2444 private: 2445 bool allowParsingSSAIds; 2446 function_ref<ParseResult(bool)> parseElement; 2447 unsigned numDimOperands; 2448 unsigned numSymbolOperands; 2449 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2450 }; 2451 } // end anonymous namespace 2452 2453 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2454 /// opLoc is the location of the op token to be used to report errors 2455 /// for non-conforming expressions. 2456 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2457 AffineExpr lhs, AffineExpr rhs, 2458 SMLoc opLoc) { 2459 // TODO: make the error location info accurate. 
2460 switch (op) { 2461 case Mul: 2462 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2463 emitError(opLoc, "non-affine expression: at least one of the multiply " 2464 "operands has to be either a constant or symbolic"); 2465 return nullptr; 2466 } 2467 return lhs * rhs; 2468 case FloorDiv: 2469 if (!rhs.isSymbolicOrConstant()) { 2470 emitError(opLoc, "non-affine expression: right operand of floordiv " 2471 "has to be either a constant or symbolic"); 2472 return nullptr; 2473 } 2474 return lhs.floorDiv(rhs); 2475 case CeilDiv: 2476 if (!rhs.isSymbolicOrConstant()) { 2477 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2478 "has to be either a constant or symbolic"); 2479 return nullptr; 2480 } 2481 return lhs.ceilDiv(rhs); 2482 case Mod: 2483 if (!rhs.isSymbolicOrConstant()) { 2484 emitError(opLoc, "non-affine expression: right operand of mod " 2485 "has to be either a constant or symbolic"); 2486 return nullptr; 2487 } 2488 return lhs % rhs; 2489 case HNoOp: 2490 llvm_unreachable("can't create affine expression for null high prec op"); 2491 return nullptr; 2492 } 2493 llvm_unreachable("Unknown AffineHighPrecOp"); 2494 } 2495 2496 /// Create an affine binary low precedence op expression (add, sub). 2497 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2498 AffineExpr lhs, AffineExpr rhs) { 2499 switch (op) { 2500 case AffineLowPrecOp::Add: 2501 return lhs + rhs; 2502 case AffineLowPrecOp::Sub: 2503 return lhs - rhs; 2504 case AffineLowPrecOp::LNoOp: 2505 llvm_unreachable("can't create affine expression for null low prec op"); 2506 return nullptr; 2507 } 2508 llvm_unreachable("Unknown AffineLowPrecOp"); 2509 } 2510 2511 /// Consume this token if it is a lower precedence affine op (there are only 2512 /// two precedence levels). 
2513 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2514 switch (getToken().getKind()) { 2515 case Token::plus: 2516 consumeToken(Token::plus); 2517 return AffineLowPrecOp::Add; 2518 case Token::minus: 2519 consumeToken(Token::minus); 2520 return AffineLowPrecOp::Sub; 2521 default: 2522 return AffineLowPrecOp::LNoOp; 2523 } 2524 } 2525 2526 /// Consume this token if it is a higher precedence affine op (there are only 2527 /// two precedence levels) 2528 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2529 switch (getToken().getKind()) { 2530 case Token::star: 2531 consumeToken(Token::star); 2532 return Mul; 2533 case Token::kw_floordiv: 2534 consumeToken(Token::kw_floordiv); 2535 return FloorDiv; 2536 case Token::kw_ceildiv: 2537 consumeToken(Token::kw_ceildiv); 2538 return CeilDiv; 2539 case Token::kw_mod: 2540 consumeToken(Token::kw_mod); 2541 return Mod; 2542 default: 2543 return HNoOp; 2544 } 2545 } 2546 2547 /// Parse a high precedence op expression list: mul, div, and mod are high 2548 /// precedence binary ops, i.e., parse a 2549 /// expr_1 op_1 expr_2 op_2 ... expr_n 2550 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2551 /// All affine binary ops are left associative. 2552 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2553 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2554 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2555 /// report an error for non-conforming expressions. 2556 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2557 AffineHighPrecOp llhsOp, 2558 SMLoc llhsOpLoc) { 2559 AffineExpr lhs = parseAffineOperandExpr(llhs); 2560 if (!lhs) 2561 return nullptr; 2562 2563 // Found an LHS. Parse the remaining expression. 
2564 auto opLoc = getToken().getLoc(); 2565 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2566 if (llhs) { 2567 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2568 if (!expr) 2569 return nullptr; 2570 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2571 } 2572 // No LLHS, get RHS 2573 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2574 } 2575 2576 // This is the last operand in this expression. 2577 if (llhs) 2578 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2579 2580 // No llhs, 'lhs' itself is the expression. 2581 return lhs; 2582 } 2583 2584 /// Parse an affine expression inside parentheses. 2585 /// 2586 /// affine-expr ::= `(` affine-expr `)` 2587 AffineExpr AffineParser::parseParentheticalExpr() { 2588 if (parseToken(Token::l_paren, "expected '('")) 2589 return nullptr; 2590 if (getToken().is(Token::r_paren)) 2591 return (emitError("no expression inside parentheses"), nullptr); 2592 2593 auto expr = parseAffineExpr(); 2594 if (!expr) 2595 return nullptr; 2596 if (parseToken(Token::r_paren, "expected ')'")) 2597 return nullptr; 2598 2599 return expr; 2600 } 2601 2602 /// Parse the negation expression. 2603 /// 2604 /// affine-expr ::= `-` affine-expr 2605 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2606 if (parseToken(Token::minus, "expected '-'")) 2607 return nullptr; 2608 2609 AffineExpr operand = parseAffineOperandExpr(lhs); 2610 // Since negation has the highest precedence of all ops (including high 2611 // precedence ops) but lower than parentheses, we are only going to use 2612 // parseAffineOperandExpr instead of parseAffineExpr here. 2613 if (!operand) 2614 // Extra error message although parseAffineOperandExpr would have 2615 // complained. Leads to a better diagnostic. 2616 return (emitError("missing operand of negation"), nullptr); 2617 return (-1) * operand; 2618 } 2619 2620 /// Parse a bare id that may appear in an affine expression. 
2621 /// 2622 /// affine-expr ::= bare-id 2623 AffineExpr AffineParser::parseBareIdExpr() { 2624 if (getToken().isNot(Token::bare_identifier)) 2625 return (emitError("expected bare identifier"), nullptr); 2626 2627 StringRef sRef = getTokenSpelling(); 2628 for (auto entry : dimsAndSymbols) { 2629 if (entry.first == sRef) { 2630 consumeToken(Token::bare_identifier); 2631 return entry.second; 2632 } 2633 } 2634 2635 return (emitError("use of undeclared identifier"), nullptr); 2636 } 2637 2638 /// Parse an SSA id which may appear in an affine expression. 2639 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2640 if (!allowParsingSSAIds) 2641 return (emitError("unexpected ssa identifier"), nullptr); 2642 if (getToken().isNot(Token::percent_identifier)) 2643 return (emitError("expected ssa identifier"), nullptr); 2644 auto name = getTokenSpelling(); 2645 // Check if we already parsed this SSA id. 2646 for (auto entry : dimsAndSymbols) { 2647 if (entry.first == name) { 2648 consumeToken(Token::percent_identifier); 2649 return entry.second; 2650 } 2651 } 2652 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2653 if (parseElement(isSymbol)) 2654 return (emitError("failed to parse ssa identifier"), nullptr); 2655 auto idExpr = isSymbol 2656 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2657 : getAffineDimExpr(numDimOperands++, getContext()); 2658 dimsAndSymbols.push_back({name, idExpr}); 2659 return idExpr; 2660 } 2661 2662 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2663 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2664 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2665 return nullptr; 2666 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2667 if (!symbolExpr) 2668 return nullptr; 2669 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2670 return nullptr; 2671 return symbolExpr; 2672 } 2673 2674 /// Parse a positive integral constant appearing in an affine expression. 2675 /// 2676 /// affine-expr ::= integer-literal 2677 AffineExpr AffineParser::parseIntegerExpr() { 2678 auto val = getToken().getUInt64IntegerValue(); 2679 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2680 return (emitError("constant too large for index"), nullptr); 2681 2682 consumeToken(Token::integer); 2683 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2684 } 2685 2686 /// Parses an expression that can be a valid operand of an affine expression. 2687 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2688 /// operator, the rhs of which is being parsed. This is used to determine 2689 /// whether an error should be emitted for a missing right operand. 2690 // Eg: for an expression without parentheses (like i + j + k + l), each 2691 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2692 // operand expression, it's an op expression and will be parsed via 2693 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2694 // -l are valid operands that will be parsed by this function. 
2695 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2696 switch (getToken().getKind()) { 2697 case Token::bare_identifier: 2698 return parseBareIdExpr(); 2699 case Token::kw_symbol: 2700 return parseSymbolSSAIdExpr(); 2701 case Token::percent_identifier: 2702 return parseSSAIdExpr(/*isSymbol=*/false); 2703 case Token::integer: 2704 return parseIntegerExpr(); 2705 case Token::l_paren: 2706 return parseParentheticalExpr(); 2707 case Token::minus: 2708 return parseNegateExpression(lhs); 2709 case Token::kw_ceildiv: 2710 case Token::kw_floordiv: 2711 case Token::kw_mod: 2712 case Token::plus: 2713 case Token::star: 2714 if (lhs) 2715 emitError("missing right operand of binary operator"); 2716 else 2717 emitError("missing left operand of binary operator"); 2718 return nullptr; 2719 default: 2720 if (lhs) 2721 emitError("missing right operand of binary operator"); 2722 else 2723 emitError("expected affine expression"); 2724 return nullptr; 2725 } 2726 } 2727 2728 /// Parse affine expressions that are bare-id's, integer constants, 2729 /// parenthetical affine expressions, and affine op expressions that are a 2730 /// composition of those. 2731 /// 2732 /// All binary op's associate from left to right. 2733 /// 2734 /// {add, sub} have lower precedence than {mul, div, and mod}. 2735 /// 2736 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2737 /// ceildiv, and mod are at the same higher precedence level. Negation has 2738 /// higher precedence than any binary op. 2739 /// 2740 /// llhs: the affine expression appearing on the left of the one being parsed. 2741 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2742 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2743 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2744 /// associativity. 
///
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
///
/// Returns null if an operand or a high precedence sub-expression fails to
/// parse.
AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
                                                  AffineLowPrecOp llhsOp) {
  // Parse the next operand; llhs is forwarded only so that the operand parser
  // can pick the right diagnostic when no operand is found.
  AffineExpr lhs;
  if (!(lhs = parseAffineOperandExpr(llhs)))
    return nullptr;

  // Found an LHS. Deal with the ops.
  if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
    if (llhs) {
      // Left associativity: fold the pending (llhs llhsOp lhs) first and make
      // the result the left-hand side of the operator just consumed.
      AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
      return parseAffineLowPrecOpExpr(sum, lOp);
    }
    // No LLHS, get RHS and form the expression.
    return parseAffineLowPrecOpExpr(lhs, lOp);
  }
  // Record the operator location before it is (possibly) consumed; it is
  // passed on for error reporting.
  auto opLoc = getToken().getLoc();
  if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
    // We have a higher precedence op here. Get the rhs operand for the llhs
    // through parseAffineHighPrecOpExpr.
    AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
    if (!highRes)
      return nullptr;

    // If llhs is null, the product forms the first operand of the yet to be
    // found expression. If non-null, the op to associate with llhs is llhsOp.
    AffineExpr expr =
        llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;

    // Recurse for subsequent low prec op's after the affine high prec op
    // expression.
    if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
      return parseAffineLowPrecOpExpr(expr, nextOp);
    return expr;
  }
  // Last operand in the expression list.
  if (llhs)
    return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
  // No llhs, 'lhs' itself is the expression.
  return lhs;
}

/// Parse an affine expression.
2791 /// affine-expr ::= `(` affine-expr `)` 2792 /// | `-` affine-expr 2793 /// | affine-expr `+` affine-expr 2794 /// | affine-expr `-` affine-expr 2795 /// | affine-expr `*` affine-expr 2796 /// | affine-expr `floordiv` affine-expr 2797 /// | affine-expr `ceildiv` affine-expr 2798 /// | affine-expr `mod` affine-expr 2799 /// | bare-id 2800 /// | integer-literal 2801 /// 2802 /// Additional conditions are checked depending on the production. For eg., 2803 /// one of the operands for `*` has to be either constant/symbolic; the second 2804 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2805 AffineExpr AffineParser::parseAffineExpr() { 2806 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2807 } 2808 2809 /// Parse a dim or symbol from the lists appearing before the actual 2810 /// expressions of the affine map. Update our state to store the 2811 /// dimensional/symbolic identifier. 2812 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2813 if (getToken().isNot(Token::bare_identifier)) 2814 return emitError("expected bare identifier"); 2815 2816 auto name = getTokenSpelling(); 2817 for (auto entry : dimsAndSymbols) { 2818 if (entry.first == name) 2819 return emitError("redefinition of identifier '" + name + "'"); 2820 } 2821 consumeToken(Token::bare_identifier); 2822 2823 dimsAndSymbols.push_back({name, idExpr}); 2824 return success(); 2825 } 2826 2827 /// Parse the list of dimensional identifiers to an affine map. 
2828 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2829 if (parseToken(Token::l_paren, 2830 "expected '(' at start of dimensional identifiers list")) { 2831 return failure(); 2832 } 2833 2834 auto parseElt = [&]() -> ParseResult { 2835 auto dimension = getAffineDimExpr(numDims++, getContext()); 2836 return parseIdentifierDefinition(dimension); 2837 }; 2838 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2839 } 2840 2841 /// Parse the list of symbolic identifiers to an affine map. 2842 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2843 consumeToken(Token::l_square); 2844 auto parseElt = [&]() -> ParseResult { 2845 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2846 return parseIdentifierDefinition(symbol); 2847 }; 2848 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2849 } 2850 2851 /// Parse the list of symbolic identifiers to an affine map. 2852 ParseResult 2853 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2854 unsigned &numSymbols) { 2855 if (parseDimIdList(numDims)) { 2856 return failure(); 2857 } 2858 if (!getToken().is(Token::l_square)) { 2859 numSymbols = 0; 2860 return success(); 2861 } 2862 return parseSymbolIdList(numSymbols); 2863 } 2864 2865 /// Parses an ambiguous affine map or integer set definition inline. 2866 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2867 IntegerSet &set) { 2868 unsigned numDims = 0, numSymbols = 0; 2869 2870 // List of dimensional and optional symbol identifiers. 2871 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2872 return failure(); 2873 } 2874 2875 // This is needed for parsing attributes as we wouldn't know whether we would 2876 // be parsing an integer set attribute or an affine map attribute. 
2877 bool isArrow = getToken().is(Token::arrow); 2878 bool isColon = getToken().is(Token::colon); 2879 if (!isArrow && !isColon) { 2880 return emitError("expected '->' or ':'"); 2881 } else if (isArrow) { 2882 parseToken(Token::arrow, "expected '->' or '['"); 2883 map = parseAffineMapRange(numDims, numSymbols); 2884 return map ? success() : failure(); 2885 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2886 return failure(); 2887 } 2888 2889 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2890 return success(); 2891 2892 return failure(); 2893 } 2894 2895 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2896 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2897 if (parseToken(Token::l_square, "expected '['")) 2898 return failure(); 2899 2900 SmallVector<AffineExpr, 4> exprs; 2901 auto parseElt = [&]() -> ParseResult { 2902 auto elt = parseAffineExpr(); 2903 exprs.push_back(elt); 2904 return elt ? success() : failure(); 2905 }; 2906 2907 // Parse a multi-dimensional affine expression (a comma-separated list of 2908 // 1-d affine expressions); the list cannot be empty. Grammar: 2909 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2910 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2911 /*allowEmptyList=*/true)) 2912 return failure(); 2913 // Parsed a valid affine map. 2914 if (exprs.empty()) 2915 map = AffineMap::get(getContext()); 2916 else 2917 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 2918 exprs); 2919 return success(); 2920 } 2921 2922 /// Parse the range and sizes affine map definition inline. 
2923 /// 2924 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2925 /// 2926 /// multi-dim-affine-expr ::= `(` `)` 2927 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 2928 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2929 unsigned numSymbols) { 2930 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2931 2932 SmallVector<AffineExpr, 4> exprs; 2933 auto parseElt = [&]() -> ParseResult { 2934 auto elt = parseAffineExpr(); 2935 ParseResult res = elt ? success() : failure(); 2936 exprs.push_back(elt); 2937 return res; 2938 }; 2939 2940 // Parse a multi-dimensional affine expression (a comma-separated list of 2941 // 1-d affine expressions); the list cannot be empty. Grammar: 2942 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2943 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2944 return AffineMap(); 2945 2946 if (exprs.empty()) 2947 return AffineMap::get(getContext()); 2948 2949 // Parsed a valid affine map. 2950 return AffineMap::get(numDims, numSymbols, exprs); 2951 } 2952 2953 /// Parse an affine constraint. 2954 /// affine-constraint ::= affine-expr `>=` `0` 2955 /// | affine-expr `==` `0` 2956 /// 2957 /// isEq is set to true if the parsed constraint is an equality, false if it 2958 /// is an inequality (greater than or equal). 
2959 /// 2960 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2961 AffineExpr expr = parseAffineExpr(); 2962 if (!expr) 2963 return nullptr; 2964 2965 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2966 getToken().is(Token::integer)) { 2967 auto dim = getToken().getUnsignedIntegerValue(); 2968 if (dim.hasValue() && dim.getValue() == 0) { 2969 consumeToken(Token::integer); 2970 *isEq = false; 2971 return expr; 2972 } 2973 return (emitError("expected '0' after '>='"), nullptr); 2974 } 2975 2976 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 2977 getToken().is(Token::integer)) { 2978 auto dim = getToken().getUnsignedIntegerValue(); 2979 if (dim.hasValue() && dim.getValue() == 0) { 2980 consumeToken(Token::integer); 2981 *isEq = true; 2982 return expr; 2983 } 2984 return (emitError("expected '0' after '=='"), nullptr); 2985 } 2986 2987 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 2988 nullptr); 2989 } 2990 2991 /// Parse the constraints that are part of an integer set definition. 2992 /// integer-set-inline 2993 /// ::= dim-and-symbol-id-lists `:` 2994 /// '(' affine-constraint-conjunction? ')' 2995 /// affine-constraint-conjunction ::= affine-constraint (`,` 2996 /// affine-constraint)* 2997 /// 2998 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 2999 unsigned numSymbols) { 3000 if (parseToken(Token::l_paren, 3001 "expected '(' at start of integer set constraint list")) 3002 return IntegerSet(); 3003 3004 SmallVector<AffineExpr, 4> constraints; 3005 SmallVector<bool, 4> isEqs; 3006 auto parseElt = [&]() -> ParseResult { 3007 bool isEq; 3008 auto elt = parseAffineConstraint(&isEq); 3009 ParseResult res = elt ? success() : failure(); 3010 if (elt) { 3011 constraints.push_back(elt); 3012 isEqs.push_back(isEq); 3013 } 3014 return res; 3015 }; 3016 3017 // Parse a list of affine constraints (comma-separated). 
3018 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3019 return IntegerSet(); 3020 3021 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3022 if (constraints.empty()) { 3023 /* 0 == 0 */ 3024 auto zero = getAffineConstantExpr(0, getContext()); 3025 return IntegerSet::get(numDims, numSymbols, zero, true); 3026 } 3027 3028 // Parsed a valid integer set. 3029 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3030 } 3031 3032 /// Parse an ambiguous reference to either and affine map or an integer set. 3033 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3034 IntegerSet &set) { 3035 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3036 } 3037 3038 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3039 /// parse SSA value uses encountered while parsing affine expressions. 3040 ParseResult 3041 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3042 function_ref<ParseResult(bool)> parseElement) { 3043 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3044 .parseAffineMapOfSSAIds(map); 3045 } 3046 3047 //===----------------------------------------------------------------------===// 3048 // OperationParser 3049 //===----------------------------------------------------------------------===// 3050 3051 namespace { 3052 /// This class provides support for parsing operations and regions of 3053 /// operations. 3054 class OperationParser : public Parser { 3055 public: 3056 OperationParser(ParserState &state, ModuleOp moduleOp) 3057 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3058 } 3059 3060 ~OperationParser(); 3061 3062 /// After parsing is finished, this function must be called to see if there 3063 /// are any remaining issues. 
3064 ParseResult finalize(); 3065 3066 //===--------------------------------------------------------------------===// 3067 // SSA Value Handling 3068 //===--------------------------------------------------------------------===// 3069 3070 /// This represents a use of an SSA value in the program. The first two 3071 /// entries in the tuple are the name and result number of a reference. The 3072 /// third is the location of the reference, which is used in case this ends 3073 /// up being a use of an undefined value. 3074 struct SSAUseInfo { 3075 StringRef name; // Value name, e.g. %42 or %abc 3076 unsigned number; // Number, specified with #12 3077 SMLoc loc; // Location of first definition or use. 3078 }; 3079 3080 /// Push a new SSA name scope to the parser. 3081 void pushSSANameScope(bool isIsolated); 3082 3083 /// Pop the last SSA name scope from the parser. 3084 ParseResult popSSANameScope(); 3085 3086 /// Register a definition of a value with the symbol table. 3087 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3088 3089 /// Parse an optional list of SSA uses into 'results'. 3090 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3091 3092 /// Parse a single SSA use into 'result'. 3093 ParseResult parseSSAUse(SSAUseInfo &result); 3094 3095 /// Given a reference to an SSA value and its type, return a reference. This 3096 /// returns null on failure. 3097 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3098 3099 ParseResult parseSSADefOrUseAndType( 3100 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3101 3102 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3103 3104 /// Return the location of the value identified by its name and number if it 3105 /// has been already reference. 
3106 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3107 auto &values = isolatedNameScopes.back().values; 3108 if (!values.count(name) || number >= values[name].size()) 3109 return {}; 3110 if (values[name][number].first) 3111 return values[name][number].second; 3112 return {}; 3113 } 3114 3115 //===--------------------------------------------------------------------===// 3116 // Operation Parsing 3117 //===--------------------------------------------------------------------===// 3118 3119 /// Parse an operation instance. 3120 ParseResult parseOperation(); 3121 3122 /// Parse a single operation successor and its operand list. 3123 ParseResult parseSuccessorAndUseList(Block *&dest, 3124 SmallVectorImpl<Value> &operands); 3125 3126 /// Parse a comma-separated list of operation successors in brackets. 3127 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations, 3128 SmallVectorImpl<SmallVector<Value, 4>> &operands); 3129 3130 /// Parse an operation instance that is in the generic form. 3131 Operation *parseGenericOperation(); 3132 3133 /// Parse an operation instance that is in the generic form and insert it at 3134 /// the provided insertion point. 3135 Operation *parseGenericOperation(Block *insertBlock, 3136 Block::iterator insertPt); 3137 3138 /// Parse an operation instance that is in the op-defined custom form. 3139 Operation *parseCustomOperation(); 3140 3141 //===--------------------------------------------------------------------===// 3142 // Region Parsing 3143 //===--------------------------------------------------------------------===// 3144 3145 /// Parse a region into 'region' with the provided entry block arguments. 3146 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3147 /// isolated from those above. 3148 ParseResult parseRegion(Region ®ion, 3149 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3150 bool isIsolatedNameScope = false); 3151 3152 /// Parse a region body into 'region'. 
3153 ParseResult parseRegionBody(Region ®ion); 3154 3155 //===--------------------------------------------------------------------===// 3156 // Block Parsing 3157 //===--------------------------------------------------------------------===// 3158 3159 /// Parse a new block into 'block'. 3160 ParseResult parseBlock(Block *&block); 3161 3162 /// Parse a list of operations into 'block'. 3163 ParseResult parseBlockBody(Block *block); 3164 3165 /// Parse a (possibly empty) list of block arguments. 3166 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3167 Block *owner); 3168 3169 /// Get the block with the specified name, creating it if it doesn't 3170 /// already exist. The location specified is the point of use, which allows 3171 /// us to diagnose references to blocks that are not defined precisely. 3172 Block *getBlockNamed(StringRef name, SMLoc loc); 3173 3174 /// Define the block with the specified name. Returns the Block* or nullptr in 3175 /// the case of redefinition. 3176 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3177 3178 private: 3179 /// Returns the info for a block at the current scope for the given name. 3180 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3181 return blocksByName.back()[name]; 3182 } 3183 3184 /// Insert a new forward reference to the given block. 3185 void insertForwardRef(Block *block, SMLoc loc) { 3186 forwardRef.back().try_emplace(block, loc); 3187 } 3188 3189 /// Erase any forward reference to the given block. 3190 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3191 3192 /// Record that a definition was added at the current scope. 3193 void recordDefinition(StringRef def); 3194 3195 /// Get the value entry for the given SSA name. 3196 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3197 3198 /// Create a forward reference placeholder value with the given location and 3199 /// result type. 
3200 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3201 3202 /// Return true if this is a forward reference. 3203 bool isForwardRefPlaceholder(Value value) { 3204 return forwardRefPlaceholders.count(value); 3205 } 3206 3207 /// This struct represents an isolated SSA name scope. This scope may contain 3208 /// other nested non-isolated scopes. These scopes are used for operations 3209 /// that are known to be isolated to allow for reusing names within their 3210 /// regions, even if those names are used above. 3211 struct IsolatedSSANameScope { 3212 /// Record that a definition was added at the current scope. 3213 void recordDefinition(StringRef def) { 3214 definitionsPerScope.back().insert(def); 3215 } 3216 3217 /// Push a nested name scope. 3218 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3219 3220 /// Pop a nested name scope. 3221 void popSSANameScope() { 3222 for (auto &def : definitionsPerScope.pop_back_val()) 3223 values.erase(def.getKey()); 3224 } 3225 3226 /// This keeps track of all of the SSA values we are tracking for each name 3227 /// scope, indexed by their name. This has one entry per result number. 3228 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3229 3230 /// This keeps track of all of the values defined by a specific name scope. 3231 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3232 }; 3233 3234 /// A list of isolated name scopes. 3235 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3236 3237 /// This keeps track of the block names as well as the location of the first 3238 /// reference for each nested name scope. This is used to diagnose invalid 3239 /// block references and memorize them. 
3240 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3241 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3242 3243 /// These are all of the placeholders we've made along with the location of 3244 /// their first reference, to allow checking for use of undefined values. 3245 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3246 3247 /// The builder used when creating parsed operation instances. 3248 OpBuilder opBuilder; 3249 3250 /// The top level module operation. 3251 ModuleOp moduleOp; 3252 }; 3253 } // end anonymous namespace 3254 3255 OperationParser::~OperationParser() { 3256 for (auto &fwd : forwardRefPlaceholders) { 3257 // Drop all uses of undefined forward declared reference and destroy 3258 // defining operation. 3259 fwd.first->dropAllUses(); 3260 fwd.first->getDefiningOp()->destroy(); 3261 } 3262 } 3263 3264 /// After parsing is finished, this function must be called to see if there are 3265 /// any remaining issues. 3266 ParseResult OperationParser::finalize() { 3267 // Check for any forward references that are left. If we find any, error 3268 // out. 3269 if (!forwardRefPlaceholders.empty()) { 3270 SmallVector<std::pair<const char *, Value>, 4> errors; 3271 // Iteration over the map isn't deterministic, so sort by source location. 
3272 for (auto entry : forwardRefPlaceholders) 3273 errors.push_back({entry.second.getPointer(), entry.first}); 3274 llvm::array_pod_sort(errors.begin(), errors.end()); 3275 3276 for (auto entry : errors) { 3277 auto loc = SMLoc::getFromPointer(entry.first); 3278 emitError(loc, "use of undeclared SSA value name"); 3279 } 3280 return failure(); 3281 } 3282 3283 return success(); 3284 } 3285 3286 //===----------------------------------------------------------------------===// 3287 // SSA Value Handling 3288 //===----------------------------------------------------------------------===// 3289 3290 void OperationParser::pushSSANameScope(bool isIsolated) { 3291 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3292 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3293 3294 // Push back a new name definition scope. 3295 if (isIsolated) 3296 isolatedNameScopes.push_back({}); 3297 isolatedNameScopes.back().pushSSANameScope(); 3298 } 3299 3300 ParseResult OperationParser::popSSANameScope() { 3301 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3302 3303 // Verify that all referenced blocks were defined. 3304 if (!forwardRefInCurrentScope.empty()) { 3305 SmallVector<std::pair<const char *, Block *>, 4> errors; 3306 // Iteration over the map isn't deterministic, so sort by source location. 3307 for (auto entry : forwardRefInCurrentScope) { 3308 errors.push_back({entry.second.getPointer(), entry.first}); 3309 // Add this block to the top-level region to allow for automatic cleanup. 3310 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3311 } 3312 llvm::array_pod_sort(errors.begin(), errors.end()); 3313 3314 for (auto entry : errors) { 3315 auto loc = SMLoc::getFromPointer(entry.first); 3316 emitError(loc, "reference to an undefined block"); 3317 } 3318 return failure(); 3319 } 3320 3321 // Pop the next nested namescope. If there is only one internal namescope, 3322 // just pop the isolated scope. 
3323 auto ¤tNameScope = isolatedNameScopes.back(); 3324 if (currentNameScope.definitionsPerScope.size() == 1) 3325 isolatedNameScopes.pop_back(); 3326 else 3327 currentNameScope.popSSANameScope(); 3328 3329 blocksByName.pop_back(); 3330 return success(); 3331 } 3332 3333 /// Register a definition of a value with the symbol table. 3334 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3335 auto &entries = getSSAValueEntry(useInfo.name); 3336 3337 // Make sure there is a slot for this value. 3338 if (entries.size() <= useInfo.number) 3339 entries.resize(useInfo.number + 1); 3340 3341 // If we already have an entry for this, check to see if it was a definition 3342 // or a forward reference. 3343 if (auto existing = entries[useInfo.number].first) { 3344 if (!isForwardRefPlaceholder(existing)) { 3345 return emitError(useInfo.loc) 3346 .append("redefinition of SSA value '", useInfo.name, "'") 3347 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3348 .append("previously defined here"); 3349 } 3350 3351 // If it was a forward reference, update everything that used it to use 3352 // the actual definition instead, delete the forward ref, and remove it 3353 // from our set of forward references we track. 3354 existing->replaceAllUsesWith(value); 3355 existing->getDefiningOp()->destroy(); 3356 forwardRefPlaceholders.erase(existing); 3357 } 3358 3359 /// Record this definition for the current scope. 3360 entries[useInfo.number] = {value, useInfo.loc}; 3361 recordDefinition(useInfo.name); 3362 return success(); 3363 } 3364 3365 /// Parse a (possibly empty) list of SSA operands. 3366 /// 3367 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3368 /// ssa-use-list-opt ::= ssa-use-list? 
3369 /// 3370 ParseResult 3371 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3372 if (getToken().isNot(Token::percent_identifier)) 3373 return success(); 3374 return parseCommaSeparatedList([&]() -> ParseResult { 3375 SSAUseInfo result; 3376 if (parseSSAUse(result)) 3377 return failure(); 3378 results.push_back(result); 3379 return success(); 3380 }); 3381 } 3382 3383 /// Parse a SSA operand for an operation. 3384 /// 3385 /// ssa-use ::= ssa-id 3386 /// 3387 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3388 result.name = getTokenSpelling(); 3389 result.number = 0; 3390 result.loc = getToken().getLoc(); 3391 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3392 return failure(); 3393 3394 // If we have an attribute ID, it is a result number. 3395 if (getToken().is(Token::hash_identifier)) { 3396 if (auto value = getToken().getHashIdentifierNumber()) 3397 result.number = value.getValue(); 3398 else 3399 return emitError("invalid SSA value result number"); 3400 consumeToken(Token::hash_identifier); 3401 } 3402 3403 return success(); 3404 } 3405 3406 /// Given an unbound reference to an SSA value and its type, return the value 3407 /// it specifies. This returns null on failure. 3408 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3409 auto &entries = getSSAValueEntry(useInfo.name); 3410 3411 // If we have already seen a value of this name, return it. 3412 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3413 auto result = entries[useInfo.number].first; 3414 // Check that the type matches the other uses. 
3415 if (result->getType() == type) 3416 return result; 3417 3418 emitError(useInfo.loc, "use of value '") 3419 .append(useInfo.name, 3420 "' expects different type than prior uses: ", type, " vs ", 3421 result->getType()) 3422 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3423 .append("prior use here"); 3424 return nullptr; 3425 } 3426 3427 // Make sure we have enough slots for this. 3428 if (entries.size() <= useInfo.number) 3429 entries.resize(useInfo.number + 1); 3430 3431 // If the value has already been defined and this is an overly large result 3432 // number, diagnose that. 3433 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3434 return (emitError(useInfo.loc, "reference to invalid result number"), 3435 nullptr); 3436 3437 // Otherwise, this is a forward reference. Create a placeholder and remember 3438 // that we did so. 3439 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3440 entries[useInfo.number].first = result; 3441 entries[useInfo.number].second = useInfo.loc; 3442 return result; 3443 } 3444 3445 /// Parse an SSA use with an associated type. 3446 /// 3447 /// ssa-use-and-type ::= ssa-use `:` type 3448 ParseResult OperationParser::parseSSADefOrUseAndType( 3449 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3450 SSAUseInfo useInfo; 3451 if (parseSSAUse(useInfo) || 3452 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3453 return failure(); 3454 3455 auto type = parseType(); 3456 if (!type) 3457 return failure(); 3458 3459 return action(useInfo, type); 3460 } 3461 3462 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3463 /// followed by a type list. 
3464 /// 3465 /// ssa-use-and-type-list 3466 /// ::= ssa-use-list ':' type-list-no-parens 3467 /// 3468 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3469 SmallVectorImpl<Value> &results) { 3470 SmallVector<SSAUseInfo, 4> valueIDs; 3471 if (parseOptionalSSAUseList(valueIDs)) 3472 return failure(); 3473 3474 // If there were no operands, then there is no colon or type lists. 3475 if (valueIDs.empty()) 3476 return success(); 3477 3478 SmallVector<Type, 4> types; 3479 if (parseToken(Token::colon, "expected ':' in operand list") || 3480 parseTypeListNoParens(types)) 3481 return failure(); 3482 3483 if (valueIDs.size() != types.size()) 3484 return emitError("expected ") 3485 << valueIDs.size() << " types to match operand list"; 3486 3487 results.reserve(valueIDs.size()); 3488 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3489 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3490 results.push_back(value); 3491 else 3492 return failure(); 3493 } 3494 3495 return success(); 3496 } 3497 3498 /// Record that a definition was added at the current scope. 3499 void OperationParser::recordDefinition(StringRef def) { 3500 isolatedNameScopes.back().recordDefinition(def); 3501 } 3502 3503 /// Get the value entry for the given SSA name. 3504 SmallVectorImpl<std::pair<Value, SMLoc>> & 3505 OperationParser::getSSAValueEntry(StringRef name) { 3506 return isolatedNameScopes.back().values[name]; 3507 } 3508 3509 /// Create and remember a new placeholder for a forward reference. 3510 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3511 // Forward references are always created as operations, because we just need 3512 // something with a def/use chain. 3513 // 3514 // We create these placeholders as having an empty name, which we know 3515 // cannot be created through normal user input, allowing us to distinguish 3516 // them. 
3517 auto name = OperationName("placeholder", getContext()); 3518 auto *op = Operation::create( 3519 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3520 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3521 /*resizableOperandList=*/false); 3522 forwardRefPlaceholders[op->getResult(0)] = loc; 3523 return op->getResult(0); 3524 } 3525 3526 //===----------------------------------------------------------------------===// 3527 // Operation Parsing 3528 //===----------------------------------------------------------------------===// 3529 3530 /// Parse an operation. 3531 /// 3532 /// operation ::= op-result-list? 3533 /// (generic-operation | custom-operation) 3534 /// trailing-location? 3535 /// generic-operation ::= string-literal '(' ssa-use-list? ')' attribute-dict? 3536 /// `:` function-type 3537 /// custom-operation ::= bare-id custom-operation-format 3538 /// op-result-list ::= op-result (`,` op-result)* `=` 3539 /// op-result ::= ssa-id (`:` integer-literal) 3540 /// 3541 ParseResult OperationParser::parseOperation() { 3542 auto loc = getToken().getLoc(); 3543 SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; 3544 size_t numExpectedResults = 0; 3545 if (getToken().is(Token::percent_identifier)) { 3546 // Parse the group of result ids. 3547 auto parseNextResult = [&]() -> ParseResult { 3548 // Parse the next result id. 3549 if (!getToken().is(Token::percent_identifier)) 3550 return emitError("expected valid ssa identifier"); 3551 3552 Token nameTok = getToken(); 3553 consumeToken(Token::percent_identifier); 3554 3555 // If the next token is a ':', we parse the expected result count. 3556 size_t expectedSubResults = 1; 3557 if (consumeIf(Token::colon)) { 3558 // Check that the next token is an integer. 3559 if (!getToken().is(Token::integer)) 3560 return emitError("expected integer number of results"); 3561 3562 // Check that number of results is > 0. 
3563 auto val = getToken().getUInt64IntegerValue(); 3564 if (!val.hasValue() || val.getValue() < 1) 3565 return emitError("expected named operation to have atleast 1 result"); 3566 consumeToken(Token::integer); 3567 expectedSubResults = *val; 3568 } 3569 3570 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3571 nameTok.getLoc()); 3572 numExpectedResults += expectedSubResults; 3573 return success(); 3574 }; 3575 if (parseCommaSeparatedList(parseNextResult)) 3576 return failure(); 3577 3578 if (parseToken(Token::equal, "expected '=' after SSA name")) 3579 return failure(); 3580 } 3581 3582 Operation *op; 3583 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3584 op = parseCustomOperation(); 3585 else if (getToken().is(Token::string)) 3586 op = parseGenericOperation(); 3587 else 3588 return emitError("expected operation name in quotes"); 3589 3590 // If parsing of the basic operation failed, then this whole thing fails. 3591 if (!op) 3592 return failure(); 3593 3594 // If the operation had a name, register it. 3595 if (!resultIDs.empty()) { 3596 if (op->getNumResults() == 0) 3597 return emitError(loc, "cannot name an operation with no results"); 3598 if (numExpectedResults != op->getNumResults()) 3599 return emitError(loc, "operation defines ") 3600 << op->getNumResults() << " results but was provided " 3601 << numExpectedResults << " to bind"; 3602 3603 // Add definitions for each of the result groups. 3604 unsigned opResI = 0; 3605 for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { 3606 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3607 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3608 op->getResult(opResI++))) 3609 return failure(); 3610 } 3611 } 3612 } 3613 3614 return success(); 3615 } 3616 3617 /// Parse a single operation successor and its operand list. 3618 /// 3619 /// successor ::= block-id branch-use-list? 
3620 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3621 /// 3622 ParseResult 3623 OperationParser::parseSuccessorAndUseList(Block *&dest, 3624 SmallVectorImpl<Value> &operands) { 3625 // Verify branch is identifier and get the matching block. 3626 if (!getToken().is(Token::caret_identifier)) 3627 return emitError("expected block name"); 3628 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3629 consumeToken(); 3630 3631 // Handle optional arguments. 3632 if (consumeIf(Token::l_paren) && 3633 (parseOptionalSSAUseAndTypeList(operands) || 3634 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3635 return failure(); 3636 } 3637 3638 return success(); 3639 } 3640 3641 /// Parse a comma-separated list of operation successors in brackets. 3642 /// 3643 /// successor-list ::= `[` successor (`,` successor )* `]` 3644 /// 3645 ParseResult OperationParser::parseSuccessors( 3646 SmallVectorImpl<Block *> &destinations, 3647 SmallVectorImpl<SmallVector<Value, 4>> &operands) { 3648 if (parseToken(Token::l_square, "expected '['")) 3649 return failure(); 3650 3651 auto parseElt = [this, &destinations, &operands]() { 3652 Block *dest; 3653 SmallVector<Value, 4> destOperands; 3654 auto res = parseSuccessorAndUseList(dest, destOperands); 3655 destinations.push_back(dest); 3656 operands.push_back(destOperands); 3657 return res; 3658 }; 3659 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3660 /*allowEmptyList=*/false); 3661 } 3662 3663 namespace { 3664 // RAII-style guard for cleaning up the regions in the operation state before 3665 // deleting them. Within the parser, regions may get deleted if parsing failed, 3666 // and other errors may be present, in particular undominated uses. This makes 3667 // sure such uses are deleted. 
3668 struct CleanupOpStateRegions { 3669 ~CleanupOpStateRegions() { 3670 SmallVector<Region *, 4> regionsToClean; 3671 regionsToClean.reserve(state.regions.size()); 3672 for (auto ®ion : state.regions) 3673 if (region) 3674 for (auto &block : *region) 3675 block.dropAllDefinedValueUses(); 3676 } 3677 OperationState &state; 3678 }; 3679 } // namespace 3680 3681 Operation *OperationParser::parseGenericOperation() { 3682 // Get location information for the operation. 3683 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3684 3685 auto name = getToken().getStringValue(); 3686 if (name.empty()) 3687 return (emitError("empty operation name is invalid"), nullptr); 3688 if (name.find('\0') != StringRef::npos) 3689 return (emitError("null character not allowed in operation name"), nullptr); 3690 3691 consumeToken(Token::string); 3692 3693 OperationState result(srcLocation, name); 3694 3695 // Generic operations have a resizable operation list. 3696 result.setOperandListToResizable(); 3697 3698 // Parse the operand list. 3699 SmallVector<SSAUseInfo, 8> operandInfos; 3700 3701 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3702 parseOptionalSSAUseList(operandInfos) || 3703 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3704 return nullptr; 3705 } 3706 3707 // Parse the successor list but don't add successors to the result yet to 3708 // avoid messing up with the argument order. 3709 SmallVector<Block *, 2> successors; 3710 SmallVector<SmallVector<Value, 4>, 2> successorOperands; 3711 if (getToken().is(Token::l_square)) { 3712 // Check if the operation is a known terminator. 3713 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3714 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3715 return emitError("successors in non-terminator"), nullptr; 3716 if (parseSuccessors(successors, successorOperands)) 3717 return nullptr; 3718 } 3719 3720 // Parse the region list. 
3721 CleanupOpStateRegions guard{result}; 3722 if (consumeIf(Token::l_paren)) { 3723 do { 3724 // Create temporary regions with the top level region as parent. 3725 result.regions.emplace_back(new Region(moduleOp)); 3726 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3727 return nullptr; 3728 } while (consumeIf(Token::comma)); 3729 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3730 return nullptr; 3731 } 3732 3733 if (getToken().is(Token::l_brace)) { 3734 if (parseAttributeDict(result.attributes)) 3735 return nullptr; 3736 } 3737 3738 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3739 return nullptr; 3740 3741 auto typeLoc = getToken().getLoc(); 3742 auto type = parseType(); 3743 if (!type) 3744 return nullptr; 3745 auto fnType = type.dyn_cast<FunctionType>(); 3746 if (!fnType) 3747 return (emitError(typeLoc, "expected function type"), nullptr); 3748 3749 result.addTypes(fnType.getResults()); 3750 3751 // Check that we have the right number of types for the operands. 3752 auto operandTypes = fnType.getInputs(); 3753 if (operandTypes.size() != operandInfos.size()) { 3754 auto plural = "s"[operandInfos.size() == 1]; 3755 return (emitError(typeLoc, "expected ") 3756 << operandInfos.size() << " operand type" << plural 3757 << " but had " << operandTypes.size(), 3758 nullptr); 3759 } 3760 3761 // Resolve all of the operands. 3762 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3763 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3764 if (!result.operands.back()) 3765 return nullptr; 3766 } 3767 3768 // Add the successors, and their operands after the proper operands. 3769 for (auto succ : llvm::zip(successors, successorOperands)) { 3770 Block *successor = std::get<0>(succ); 3771 const SmallVector<Value, 4> &operands = std::get<1>(succ); 3772 result.addSuccessor(successor, operands); 3773 } 3774 3775 // Parse a location if one is present. 
3776 if (parseOptionalTrailingLocation(result.location)) 3777 return nullptr; 3778 3779 return opBuilder.createOperation(result); 3780 } 3781 3782 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3783 Block::iterator insertPt) { 3784 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3785 opBuilder.setInsertionPoint(insertBlock, insertPt); 3786 return parseGenericOperation(); 3787 } 3788 3789 namespace { 3790 class CustomOpAsmParser : public OpAsmParser { 3791 public: 3792 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3793 OperationParser &parser) 3794 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3795 3796 /// Parse an instance of the operation described by 'opDefinition' into the 3797 /// provided operation state. 3798 ParseResult parseOperation(OperationState &opState) { 3799 if (opDefinition->parseAssembly(*this, opState)) 3800 return failure(); 3801 return success(); 3802 } 3803 3804 Operation *parseGenericOperation(Block *insertBlock, 3805 Block::iterator insertPt) final { 3806 return parser.parseGenericOperation(insertBlock, insertPt); 3807 } 3808 3809 //===--------------------------------------------------------------------===// 3810 // Utilities 3811 //===--------------------------------------------------------------------===// 3812 3813 /// Return if any errors were emitted during parsing. 3814 bool didEmitError() const { return emittedError; } 3815 3816 /// Emit a diagnostic at the specified location and return failure. 
3817 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3818 emittedError = true; 3819 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3820 message); 3821 } 3822 3823 llvm::SMLoc getCurrentLocation() override { 3824 return parser.getToken().getLoc(); 3825 } 3826 3827 Builder &getBuilder() const override { return parser.builder; } 3828 3829 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3830 3831 //===--------------------------------------------------------------------===// 3832 // Token Parsing 3833 //===--------------------------------------------------------------------===// 3834 3835 /// Parse a `->` token. 3836 ParseResult parseArrow() override { 3837 return parser.parseToken(Token::arrow, "expected '->'"); 3838 } 3839 3840 /// Parses a `->` if present. 3841 ParseResult parseOptionalArrow() override { 3842 return success(parser.consumeIf(Token::arrow)); 3843 } 3844 3845 /// Parse a `:` token. 3846 ParseResult parseColon() override { 3847 return parser.parseToken(Token::colon, "expected ':'"); 3848 } 3849 3850 /// Parse a `:` token if present. 3851 ParseResult parseOptionalColon() override { 3852 return success(parser.consumeIf(Token::colon)); 3853 } 3854 3855 /// Parse a `,` token. 3856 ParseResult parseComma() override { 3857 return parser.parseToken(Token::comma, "expected ','"); 3858 } 3859 3860 /// Parse a `,` token if present. 3861 ParseResult parseOptionalComma() override { 3862 return success(parser.consumeIf(Token::comma)); 3863 } 3864 3865 /// Parses a `...` if present. 3866 ParseResult parseOptionalEllipsis() override { 3867 return success(parser.consumeIf(Token::ellipsis)); 3868 } 3869 3870 /// Parse a `=` token. 3871 ParseResult parseEqual() override { 3872 return parser.parseToken(Token::equal, "expected '='"); 3873 } 3874 3875 /// Parse a '<' token. 
3876 ParseResult parseLess() override { 3877 return parser.parseToken(Token::less, "expected '<'"); 3878 } 3879 3880 /// Parse a '>' token. 3881 ParseResult parseGreater() override { 3882 return parser.parseToken(Token::greater, "expected '>'"); 3883 } 3884 3885 /// Parse a `(` token. 3886 ParseResult parseLParen() override { 3887 return parser.parseToken(Token::l_paren, "expected '('"); 3888 } 3889 3890 /// Parses a '(' if present. 3891 ParseResult parseOptionalLParen() override { 3892 return success(parser.consumeIf(Token::l_paren)); 3893 } 3894 3895 /// Parse a `)` token. 3896 ParseResult parseRParen() override { 3897 return parser.parseToken(Token::r_paren, "expected ')'"); 3898 } 3899 3900 /// Parses a ')' if present. 3901 ParseResult parseOptionalRParen() override { 3902 return success(parser.consumeIf(Token::r_paren)); 3903 } 3904 3905 /// Parse a `[` token. 3906 ParseResult parseLSquare() override { 3907 return parser.parseToken(Token::l_square, "expected '['"); 3908 } 3909 3910 /// Parses a '[' if present. 3911 ParseResult parseOptionalLSquare() override { 3912 return success(parser.consumeIf(Token::l_square)); 3913 } 3914 3915 /// Parse a `]` token. 3916 ParseResult parseRSquare() override { 3917 return parser.parseToken(Token::r_square, "expected ']'"); 3918 } 3919 3920 /// Parses a ']' if present. 3921 ParseResult parseOptionalRSquare() override { 3922 return success(parser.consumeIf(Token::r_square)); 3923 } 3924 3925 //===--------------------------------------------------------------------===// 3926 // Attribute Parsing 3927 //===--------------------------------------------------------------------===// 3928 3929 /// Parse an arbitrary attribute of a given type and return it in result. This 3930 /// also adds the attribute to the specified attribute list with the specified 3931 /// name. 
3932 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3933 SmallVectorImpl<NamedAttribute> &attrs) override { 3934 result = parser.parseAttribute(type); 3935 if (!result) 3936 return failure(); 3937 3938 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3939 return success(); 3940 } 3941 3942 /// Parse a named dictionary into 'result' if it is present. 3943 ParseResult 3944 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 3945 if (parser.getToken().isNot(Token::l_brace)) 3946 return success(); 3947 return parser.parseAttributeDict(result); 3948 } 3949 3950 /// Parse a named dictionary into 'result' if the `attributes` keyword is 3951 /// present. 3952 ParseResult parseOptionalAttrDictWithKeyword( 3953 SmallVectorImpl<NamedAttribute> &result) override { 3954 if (failed(parseOptionalKeyword("attributes"))) 3955 return success(); 3956 return parser.parseAttributeDict(result); 3957 } 3958 3959 //===--------------------------------------------------------------------===// 3960 // Identifier Parsing 3961 //===--------------------------------------------------------------------===// 3962 3963 /// Returns if the current token corresponds to a keyword. 3964 bool isCurrentTokenAKeyword() const { 3965 return parser.getToken().is(Token::bare_identifier) || 3966 parser.getToken().isKeyword(); 3967 } 3968 3969 /// Parse the given keyword if present. 3970 ParseResult parseOptionalKeyword(StringRef keyword) override { 3971 // Check that the current token has the same spelling. 3972 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 3973 return failure(); 3974 parser.consumeToken(); 3975 return success(); 3976 } 3977 3978 /// Parse a keyword, if present, into 'keyword'. 3979 ParseResult parseOptionalKeyword(StringRef *keyword) override { 3980 // Check that the current token is a keyword. 
3981 if (!isCurrentTokenAKeyword()) 3982 return failure(); 3983 3984 *keyword = parser.getTokenSpelling(); 3985 parser.consumeToken(); 3986 return success(); 3987 } 3988 3989 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 3990 /// string attribute named 'attrName'. 3991 ParseResult 3992 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 3993 SmallVectorImpl<NamedAttribute> &attrs) override { 3994 Token atToken = parser.getToken(); 3995 if (atToken.isNot(Token::at_identifier)) 3996 return failure(); 3997 3998 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 3999 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4000 parser.consumeToken(); 4001 return success(); 4002 } 4003 4004 //===--------------------------------------------------------------------===// 4005 // Operand Parsing 4006 //===--------------------------------------------------------------------===// 4007 4008 /// Parse a single operand. 4009 ParseResult parseOperand(OperandType &result) override { 4010 OperationParser::SSAUseInfo useInfo; 4011 if (parser.parseSSAUse(useInfo)) 4012 return failure(); 4013 4014 result = {useInfo.loc, useInfo.name, useInfo.number}; 4015 return success(); 4016 } 4017 4018 /// Parse zero or more SSA comma-separated operand references with a specified 4019 /// surrounding delimiter, and an optional required operand count. 4020 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4021 int requiredOperandCount = -1, 4022 Delimiter delimiter = Delimiter::None) override { 4023 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4024 requiredOperandCount, delimiter); 4025 } 4026 4027 /// Parse zero or more SSA comma-separated operand or region arguments with 4028 /// optional surrounding delimiter and required operand count. 
  ///
  /// `isOperandList` selects whether each element is parsed with
  /// `parseOperand` or `parseRegionArgument`. A `requiredOperandCount` of -1
  /// means "any number"; otherwise the number of parsed elements must match
  /// exactly. Parsed elements are appended to `result`.
  ParseResult
  parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
                              bool isOperandList, int requiredOperandCount = -1,
                              Delimiter delimiter = Delimiter::None) {
    // Remembered up front so count/delimiter errors point at the list start.
    auto startLoc = parser.getToken().getLoc();

    // Handle delimiters.
    switch (delimiter) {
    case Delimiter::None:
      // Don't check for the absence of a delimiter if the number of operands
      // is unknown (and hence the operand list could be empty).
      if (requiredOperandCount == -1)
        break;
      // Token already matches an identifier and so can't be a delimiter.
      if (parser.getToken().is(Token::percent_identifier))
        break;
      // Test against known delimiters.
      if (parser.getToken().is(Token::l_paren) ||
          parser.getToken().is(Token::l_square))
        return emitError(startLoc, "unexpected delimiter");
      return emitError(startLoc, "invalid operand");
    case Delimiter::OptionalParen:
      // An absent optional delimiter means an absent list: succeed without
      // consuming anything and without checking the required count.
      if (parser.getToken().isNot(Token::l_paren))
        return success();
      LLVM_FALLTHROUGH;
    case Delimiter::Paren:
      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
        return failure();
      break;
    case Delimiter::OptionalSquare:
      // Same early exit as for the optional parenthesis above.
      if (parser.getToken().isNot(Token::l_square))
        return success();
      LLVM_FALLTHROUGH;
    case Delimiter::Square:
      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
        return failure();
      break;
    }

    // Check for zero operands.
    if (parser.getToken().is(Token::percent_identifier)) {
      do {
        OperandType operandOrArg;
        if (isOperandList ? parseOperand(operandOrArg)
                          : parseRegionArgument(operandOrArg))
          return failure();
        result.push_back(operandOrArg);
      } while (parser.consumeIf(Token::comma));
    }

    // Handle delimiters.  If we reach here, the optional delimiters were
    // present, so we need to parse their closing one.
    switch (delimiter) {
    case Delimiter::None:
      break;
    case Delimiter::OptionalParen:
    case Delimiter::Paren:
      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
        return failure();
      break;
    case Delimiter::OptionalSquare:
    case Delimiter::Square:
      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
        return failure();
      break;
    }

    // The count is compared against the full size of `result`, i.e. it
    // includes anything the caller had already placed in the vector.
    if (requiredOperandCount != -1 &&
        result.size() != static_cast<size_t>(requiredOperandCount))
      return emitError(startLoc, "expected ")
             << requiredOperandCount << " operands";
    return success();
  }

  /// Parse zero or more trailing SSA comma-separated trailing operand
  /// references with a specified surrounding delimiter, and an optional
  /// required operand count. A leading comma is expected before the operands.
  ///
  /// When no leading comma is present nothing is consumed; that is an error
  /// only if a specific operand count was requested.
  ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
                                       int requiredOperandCount,
                                       Delimiter delimiter) override {
    if (parser.getToken().is(Token::comma)) {
      // The current token was just checked to be a comma, so the result of
      // parseComma() is not inspected.
      parseComma();
      return parseOperandList(result, requiredOperandCount, delimiter);
    }
    if (requiredOperandCount != -1)
      return emitError(parser.getToken().getLoc(), "expected ")
             << requiredOperandCount << " operands";
    return success();
  }

  /// Resolve an operand to an SSA value, emitting an error on failure.
  /// On success the resolved value is appended to `result`.
  ParseResult resolveOperand(const OperandType &operand, Type type,
                             SmallVectorImpl<Value> &result) override {
    // Repackage the operand in the form the underlying parser expects.
    OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                               operand.location};
    if (auto value = parser.resolveSSAUse(operandInfo, type)) {
      result.push_back(value);
      return success();
    }
    return failure();
  }

  /// Parse an AffineMap of SSA ids.
4132 ParseResult 4133 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4134 Attribute &mapAttr, StringRef attrName, 4135 SmallVectorImpl<NamedAttribute> &attrs) override { 4136 SmallVector<OperandType, 2> dimOperands; 4137 SmallVector<OperandType, 1> symOperands; 4138 4139 auto parseElement = [&](bool isSymbol) -> ParseResult { 4140 OperandType operand; 4141 if (parseOperand(operand)) 4142 return failure(); 4143 if (isSymbol) 4144 symOperands.push_back(operand); 4145 else 4146 dimOperands.push_back(operand); 4147 return success(); 4148 }; 4149 4150 AffineMap map; 4151 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 4152 return failure(); 4153 // Add AffineMap attribute. 4154 if (map) { 4155 mapAttr = AffineMapAttr::get(map); 4156 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4157 } 4158 4159 // Add dim operands before symbol operands in 'operands'. 4160 operands.assign(dimOperands.begin(), dimOperands.end()); 4161 operands.append(symOperands.begin(), symOperands.end()); 4162 return success(); 4163 } 4164 4165 //===--------------------------------------------------------------------===// 4166 // Region Parsing 4167 //===--------------------------------------------------------------------===// 4168 4169 /// Parse a region that takes `arguments` of `argTypes` types. This 4170 /// effectively defines the SSA values of `arguments` and assigns their type. 
4171 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4172 ArrayRef<Type> argTypes, 4173 bool enableNameShadowing) override { 4174 assert(arguments.size() == argTypes.size() && 4175 "mismatching number of arguments and types"); 4176 4177 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4178 regionArguments; 4179 for (auto pair : llvm::zip(arguments, argTypes)) { 4180 const OperandType &operand = std::get<0>(pair); 4181 Type type = std::get<1>(pair); 4182 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4183 operand.location}; 4184 regionArguments.emplace_back(operandInfo, type); 4185 } 4186 4187 // Try to parse the region. 4188 assert((!enableNameShadowing || 4189 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4190 "name shadowing is only allowed on isolated regions"); 4191 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4192 return failure(); 4193 return success(); 4194 } 4195 4196 /// Parses a region if present. 4197 ParseResult parseOptionalRegion(Region ®ion, 4198 ArrayRef<OperandType> arguments, 4199 ArrayRef<Type> argTypes, 4200 bool enableNameShadowing) override { 4201 if (parser.getToken().isNot(Token::l_brace)) 4202 return success(); 4203 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4204 } 4205 4206 /// Parse a region argument. The type of the argument will be resolved later 4207 /// by a call to `parseRegion`. 4208 ParseResult parseRegionArgument(OperandType &argument) override { 4209 return parseOperand(argument); 4210 } 4211 4212 /// Parse a region argument if present. 
4213 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4214 if (parser.getToken().isNot(Token::percent_identifier)) 4215 return success(); 4216 return parseRegionArgument(argument); 4217 } 4218 4219 ParseResult 4220 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4221 int requiredOperandCount = -1, 4222 Delimiter delimiter = Delimiter::None) override { 4223 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4224 requiredOperandCount, delimiter); 4225 } 4226 4227 //===--------------------------------------------------------------------===// 4228 // Successor Parsing 4229 //===--------------------------------------------------------------------===// 4230 4231 /// Parse a single operation successor and its operand list. 4232 ParseResult 4233 parseSuccessorAndUseList(Block *&dest, 4234 SmallVectorImpl<Value> &operands) override { 4235 return parser.parseSuccessorAndUseList(dest, operands); 4236 } 4237 4238 //===--------------------------------------------------------------------===// 4239 // Type Parsing 4240 //===--------------------------------------------------------------------===// 4241 4242 /// Parse a type. 4243 ParseResult parseType(Type &result) override { 4244 return failure(!(result = parser.parseType())); 4245 } 4246 4247 /// Parse an optional arrow followed by a type list. 4248 ParseResult 4249 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4250 if (!parser.consumeIf(Token::arrow)) 4251 return success(); 4252 return parser.parseFunctionResultTypes(result); 4253 } 4254 4255 /// Parse a colon followed by a type. 4256 ParseResult parseColonType(Type &result) override { 4257 return failure(parser.parseToken(Token::colon, "expected ':'") || 4258 !(result = parser.parseType())); 4259 } 4260 4261 /// Parse a colon followed by a type list, which must have at least one type. 
4262 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4263 if (parser.parseToken(Token::colon, "expected ':'")) 4264 return failure(); 4265 return parser.parseTypeListNoParens(result); 4266 } 4267 4268 /// Parse an optional colon followed by a type list, which if present must 4269 /// have at least one type. 4270 ParseResult 4271 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4272 if (!parser.consumeIf(Token::colon)) 4273 return success(); 4274 return parser.parseTypeListNoParens(result); 4275 } 4276 4277 private: 4278 /// The source location of the operation name. 4279 SMLoc nameLoc; 4280 4281 /// The abstract information of the operation. 4282 const AbstractOperation *opDefinition; 4283 4284 /// The main operation parser. 4285 OperationParser &parser; 4286 4287 /// A flag that indicates if any errors were emitted during parsing. 4288 bool emittedError = false; 4289 }; 4290 } // end anonymous namespace. 4291 4292 Operation *OperationParser::parseCustomOperation() { 4293 auto opLoc = getToken().getLoc(); 4294 auto opName = getTokenSpelling(); 4295 4296 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4297 if (!opDefinition && !opName.contains('.')) { 4298 // If the operation name has no namespace prefix we treat it as a standard 4299 // operation and prefix it with "std". 4300 // TODO: Would it be better to just build a mapping of the registered 4301 // operations in the standard dialect? 4302 opDefinition = 4303 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 4304 } 4305 4306 if (!opDefinition) { 4307 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4308 return nullptr; 4309 } 4310 4311 consumeToken(); 4312 4313 // If the custom op parser crashes, produce some indication to help 4314 // debugging. 
4315 std::string opNameStr = opName.str(); 4316 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4317 opNameStr.c_str()); 4318 4319 // Get location information for the operation. 4320 auto srcLocation = getEncodedSourceLocation(opLoc); 4321 4322 // Have the op implementation take a crack and parsing this. 4323 OperationState opState(srcLocation, opDefinition->name); 4324 CleanupOpStateRegions guard{opState}; 4325 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 4326 if (opAsmParser.parseOperation(opState)) 4327 return nullptr; 4328 4329 // If it emitted an error, we failed. 4330 if (opAsmParser.didEmitError()) 4331 return nullptr; 4332 4333 // Parse a location if one is present. 4334 if (parseOptionalTrailingLocation(opState.location)) 4335 return nullptr; 4336 4337 // Otherwise, we succeeded. Use the state it parsed as our op information. 4338 return opBuilder.createOperation(opState); 4339 } 4340 4341 //===----------------------------------------------------------------------===// 4342 // Region Parsing 4343 //===----------------------------------------------------------------------===// 4344 4345 /// Region. 4346 /// 4347 /// region ::= '{' region-body 4348 /// 4349 ParseResult OperationParser::parseRegion( 4350 Region ®ion, 4351 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4352 bool isIsolatedNameScope) { 4353 // Parse the '{'. 4354 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4355 return failure(); 4356 4357 // Check for an empty region. 4358 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4359 return success(); 4360 auto currentPt = opBuilder.saveInsertionPoint(); 4361 4362 // Push a new named value scope. 4363 pushSSANameScope(isIsolatedNameScope); 4364 4365 // Parse the first block directly to allow for it to be unnamed. 4366 Block *block = new Block(); 4367 4368 // Add arguments to the entry block. 
4369 if (!entryArguments.empty()) { 4370 for (auto &placeholderArgPair : entryArguments) { 4371 auto &argInfo = placeholderArgPair.first; 4372 // Ensure that the argument was not already defined. 4373 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4374 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4375 "' is already in use") 4376 .attachNote(getEncodedSourceLocation(*defLoc)) 4377 << "previously referenced here"; 4378 } 4379 if (addDefinition(placeholderArgPair.first, 4380 block->addArgument(placeholderArgPair.second))) { 4381 delete block; 4382 return failure(); 4383 } 4384 } 4385 4386 // If we had named arguments, then don't allow a block name. 4387 if (getToken().is(Token::caret_identifier)) 4388 return emitError("invalid block name in region with named arguments"); 4389 } 4390 4391 if (parseBlock(block)) { 4392 delete block; 4393 return failure(); 4394 } 4395 4396 // Verify that no other arguments were parsed. 4397 if (!entryArguments.empty() && 4398 block->getNumArguments() > entryArguments.size()) { 4399 delete block; 4400 return emitError("entry block arguments were already defined"); 4401 } 4402 4403 // Parse the rest of the region. 4404 region.push_back(block); 4405 if (parseRegionBody(region)) 4406 return failure(); 4407 4408 // Pop the SSA value scope for this region. 4409 if (popSSANameScope()) 4410 return failure(); 4411 4412 // Reset the original insertion point. 4413 opBuilder.restoreInsertionPoint(currentPt); 4414 return success(); 4415 } 4416 4417 /// Region. 4418 /// 4419 /// region-body ::= block* '}' 4420 /// 4421 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4422 // Parse the list of blocks. 
4423 while (!consumeIf(Token::r_brace)) { 4424 Block *newBlock = nullptr; 4425 if (parseBlock(newBlock)) 4426 return failure(); 4427 region.push_back(newBlock); 4428 } 4429 return success(); 4430 } 4431 4432 //===----------------------------------------------------------------------===// 4433 // Block Parsing 4434 //===----------------------------------------------------------------------===// 4435 4436 /// Block declaration. 4437 /// 4438 /// block ::= block-label? operation* 4439 /// block-label ::= block-id block-arg-list? `:` 4440 /// block-id ::= caret-id 4441 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4442 /// 4443 ParseResult OperationParser::parseBlock(Block *&block) { 4444 // The first block of a region may already exist, if it does the caret 4445 // identifier is optional. 4446 if (block && getToken().isNot(Token::caret_identifier)) 4447 return parseBlockBody(block); 4448 4449 SMLoc nameLoc = getToken().getLoc(); 4450 auto name = getTokenSpelling(); 4451 if (parseToken(Token::caret_identifier, "expected block name")) 4452 return failure(); 4453 4454 block = defineBlockNamed(name, nameLoc, block); 4455 4456 // Fail if the block was already defined. 4457 if (!block) 4458 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4459 4460 // If an argument list is present, parse it. 4461 if (consumeIf(Token::l_paren)) { 4462 SmallVector<BlockArgument, 8> bbArgs; 4463 if (parseOptionalBlockArgList(bbArgs, block) || 4464 parseToken(Token::r_paren, "expected ')' to end argument list")) 4465 return failure(); 4466 } 4467 4468 if (parseToken(Token::colon, "expected ':' after block name")) 4469 return failure(); 4470 4471 return parseBlockBody(block); 4472 } 4473 4474 ParseResult OperationParser::parseBlockBody(Block *block) { 4475 // Set the insertion point to the end of the block to parse. 4476 opBuilder.setInsertionPointToEnd(block); 4477 4478 // Parse the list of operations that make up the body of the block. 
4479 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4480 if (parseOperation()) 4481 return failure(); 4482 4483 return success(); 4484 } 4485 4486 /// Get the block with the specified name, creating it if it doesn't already 4487 /// exist. The location specified is the point of use, which allows 4488 /// us to diagnose references to blocks that are not defined precisely. 4489 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4490 auto &blockAndLoc = getBlockInfoByName(name); 4491 if (!blockAndLoc.first) { 4492 blockAndLoc = {new Block(), loc}; 4493 insertForwardRef(blockAndLoc.first, loc); 4494 } 4495 4496 return blockAndLoc.first; 4497 } 4498 4499 /// Define the block with the specified name. Returns the Block* or nullptr in 4500 /// the case of redefinition. 4501 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4502 Block *existing) { 4503 auto &blockAndLoc = getBlockInfoByName(name); 4504 if (!blockAndLoc.first) { 4505 // If the caller provided a block, use it. Otherwise create a new one. 4506 if (!existing) 4507 existing = new Block(); 4508 blockAndLoc.first = existing; 4509 blockAndLoc.second = loc; 4510 return blockAndLoc.first; 4511 } 4512 4513 // Forward declarations are removed once defined, so if we are defining a 4514 // existing block and it is not a forward declaration, then it is a 4515 // redeclaration. 4516 if (!eraseForwardRef(blockAndLoc.first)) 4517 return nullptr; 4518 return blockAndLoc.first; 4519 } 4520 4521 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4522 /// 4523 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4524 /// 4525 ParseResult OperationParser::parseOptionalBlockArgList( 4526 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4527 if (getToken().is(Token::r_brace)) 4528 return success(); 4529 4530 // If the block already has arguments, then we're handling the entry block. 
4531 // Parse and register the names for the arguments, but do not add them. 4532 bool definingExistingArgs = owner->getNumArguments() != 0; 4533 unsigned nextArgument = 0; 4534 4535 return parseCommaSeparatedList([&]() -> ParseResult { 4536 return parseSSADefOrUseAndType( 4537 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4538 // If this block did not have existing arguments, define a new one. 4539 if (!definingExistingArgs) 4540 return addDefinition(useInfo, owner->addArgument(type)); 4541 4542 // Otherwise, ensure that this argument has already been created. 4543 if (nextArgument >= owner->getNumArguments()) 4544 return emitError("too many arguments specified in argument list"); 4545 4546 // Finally, make sure the existing argument has the correct type. 4547 auto arg = owner->getArgument(nextArgument++); 4548 if (arg->getType() != type) 4549 return emitError("argument and block argument type mismatch"); 4550 return addDefinition(useInfo, arg); 4551 }); 4552 }); 4553 } 4554 4555 //===----------------------------------------------------------------------===// 4556 // Top-level entity parsing. 4557 //===----------------------------------------------------------------------===// 4558 4559 namespace { 4560 /// This parser handles entities that are only valid at the top level of the 4561 /// file. 4562 class ModuleParser : public Parser { 4563 public: 4564 explicit ModuleParser(ParserState &state) : Parser(state) {} 4565 4566 ParseResult parseModule(ModuleOp module); 4567 4568 private: 4569 /// Parse an attribute alias declaration. 4570 ParseResult parseAttributeAliasDef(); 4571 4572 /// Parse an attribute alias declaration. 4573 ParseResult parseTypeAliasDef(); 4574 }; 4575 } // end anonymous namespace 4576 4577 /// Parses an attribute alias declaration. 
4578 /// 4579 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4580 /// 4581 ParseResult ModuleParser::parseAttributeAliasDef() { 4582 assert(getToken().is(Token::hash_identifier)); 4583 StringRef aliasName = getTokenSpelling().drop_front(); 4584 4585 // Check for redefinitions. 4586 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4587 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4588 4589 // Make sure this isn't invading the dialect attribute namespace. 4590 if (aliasName.contains('.')) 4591 return emitError("attribute names with a '.' are reserved for " 4592 "dialect-defined names"); 4593 4594 consumeToken(Token::hash_identifier); 4595 4596 // Parse the '='. 4597 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4598 return failure(); 4599 4600 // Parse the attribute value. 4601 Attribute attr = parseAttribute(); 4602 if (!attr) 4603 return failure(); 4604 4605 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4606 return success(); 4607 } 4608 4609 /// Parse a type alias declaration. 4610 /// 4611 /// type-alias-def ::= '!' alias-name `=` 'type' type 4612 /// 4613 ParseResult ModuleParser::parseTypeAliasDef() { 4614 assert(getToken().is(Token::exclamation_identifier)); 4615 StringRef aliasName = getTokenSpelling().drop_front(); 4616 4617 // Check for redefinitions. 4618 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4619 return emitError("redefinition of type alias id '" + aliasName + "'"); 4620 4621 // Make sure this isn't invading the dialect type namespace. 4622 if (aliasName.contains('.')) 4623 return emitError("type names with a '.' are reserved for " 4624 "dialect-defined names"); 4625 4626 consumeToken(Token::exclamation_identifier); 4627 4628 // Parse the '=' and 'type'. 
4629 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4630 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4631 return failure(); 4632 4633 // Parse the type. 4634 Type aliasedType = parseType(); 4635 if (!aliasedType) 4636 return failure(); 4637 4638 // Register this alias with the parser state. 4639 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4640 return success(); 4641 } 4642 4643 /// This is the top-level module parser. 4644 ParseResult ModuleParser::parseModule(ModuleOp module) { 4645 OperationParser opParser(getState(), module); 4646 4647 // Module itself is a name scope. 4648 opParser.pushSSANameScope(/*isIsolated=*/true); 4649 4650 while (true) { 4651 switch (getToken().getKind()) { 4652 default: 4653 // Parse a top-level operation. 4654 if (opParser.parseOperation()) 4655 return failure(); 4656 break; 4657 4658 // If we got to the end of the file, then we're done. 4659 case Token::eof: { 4660 if (opParser.finalize()) 4661 return failure(); 4662 4663 // Handle the case where the top level module was explicitly defined. 4664 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4665 auto &operations = bodyBlocks.front().getOperations(); 4666 assert(!operations.empty() && "expected a valid module terminator"); 4667 4668 // Check that the first operation is a module, and it is the only 4669 // non-terminator operation. 4670 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4671 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4672 // Merge the data of the nested module operation into 'module'. 4673 module.setLoc(nested.getLoc()); 4674 module.setAttrs(nested.getOperation()->getAttrList()); 4675 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4676 4677 // Erase the original module body. 
4678 bodyBlocks.pop_front(); 4679 } 4680 4681 return opParser.popSSANameScope(); 4682 } 4683 4684 // If we got an error token, then the lexer already emitted an error, just 4685 // stop. Someday we could introduce error recovery if there was demand 4686 // for it. 4687 case Token::error: 4688 return failure(); 4689 4690 // Parse an attribute alias. 4691 case Token::hash_identifier: 4692 if (parseAttributeAliasDef()) 4693 return failure(); 4694 break; 4695 4696 // Parse a type alias. 4697 case Token::exclamation_identifier: 4698 if (parseTypeAliasDef()) 4699 return failure(); 4700 break; 4701 } 4702 } 4703 } 4704 4705 //===----------------------------------------------------------------------===// 4706 4707 /// This parses the file specified by the indicated SourceMgr and returns an 4708 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4709 /// null. 4710 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4711 MLIRContext *context) { 4712 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4713 4714 // This is the result module we are parsing into. 4715 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4716 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4717 4718 SymbolState aliasState; 4719 ParserState state(sourceMgr, context, aliasState); 4720 if (ModuleParser(state).parseModule(*module)) 4721 return nullptr; 4722 4723 // Make sure the parse module has no other structural problems detected by 4724 // the verifier. 4725 if (failed(verify(*module))) 4726 return nullptr; 4727 4728 return module; 4729 } 4730 4731 /// This parses the file specified by the indicated filename and returns an 4732 /// MLIR module if it was valid. If not, the error message is emitted through 4733 /// the error handler registered in the context, and a null pointer is returned. 
4734 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4735 MLIRContext *context) { 4736 llvm::SourceMgr sourceMgr; 4737 return parseSourceFile(filename, sourceMgr, context); 4738 } 4739 4740 /// This parses the file specified by the indicated filename using the provided 4741 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4742 /// message is emitted through the error handler registered in the context, and 4743 /// a null pointer is returned. 4744 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4745 llvm::SourceMgr &sourceMgr, 4746 MLIRContext *context) { 4747 if (sourceMgr.getNumBuffers() != 0) { 4748 // TODO(b/136086478): Extend to support multiple buffers. 4749 emitError(mlir::UnknownLoc::get(context), 4750 "only main buffer parsed at the moment"); 4751 return nullptr; 4752 } 4753 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4754 if (std::error_code error = file_or_err.getError()) { 4755 emitError(mlir::UnknownLoc::get(context), 4756 "could not open input file " + filename); 4757 return nullptr; 4758 } 4759 4760 // Load the MLIR module. 4761 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4762 return parseSourceFile(sourceMgr, context); 4763 } 4764 4765 /// This parses the program string to a MLIR module if it was valid. If not, 4766 /// it emits diagnostics and returns null. 4767 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4768 MLIRContext *context) { 4769 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4770 if (!memBuffer) 4771 return nullptr; 4772 4773 SourceMgr sourceMgr; 4774 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4775 return parseSourceFile(sourceMgr, context); 4776 } 4777 4778 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 4779 /// parsing failed, nullptr is returned. The number of bytes read from the input 4780 /// string is returned in 'numRead'. 
4781 template <typename T, typename ParserFn> 4782 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 4783 ParserFn &&parserFn) { 4784 SymbolState aliasState; 4785 return parseSymbol<T>( 4786 inputStr, context, aliasState, 4787 [&](Parser &parser) { 4788 SourceMgrDiagnosticHandler handler( 4789 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 4790 parser.getContext()); 4791 return parserFn(parser); 4792 }, 4793 &numRead); 4794 } 4795 4796 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 4797 size_t numRead = 0; 4798 return parseAttribute(attrStr, context, numRead); 4799 } 4800 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 4801 size_t numRead = 0; 4802 return parseAttribute(attrStr, type, numRead); 4803 } 4804 4805 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 4806 size_t &numRead) { 4807 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 4808 return parser.parseAttribute(); 4809 }); 4810 } 4811 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 4812 return parseSymbol<Attribute>( 4813 attrStr, type.getContext(), numRead, 4814 [type](Parser &parser) { return parser.parseAttribute(type); }); 4815 } 4816 4817 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 4818 size_t numRead = 0; 4819 return parseType(typeStr, context, numRead); 4820 } 4821 4822 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 4823 return parseSymbol<Type>(typeStr, context, numRead, 4824 [](Parser &parser) { return parser.parseType(); }); 4825 } 4826