//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR textual form.
//
//===----------------------------------------------------------------------===//

#include "mlir/Parser.h"
#include "Lexer.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include <algorithm>
using namespace mlir;
using llvm::MemoryBuffer;
using llvm::SMLoc;
using llvm::SourceMgr;

namespace {
class Parser;

//===----------------------------------------------------------------------===//
// SymbolState
//===----------------------------------------------------------------------===//

/// This class contains a record of the parsed top-level symbols (attribute and
/// type alias definitions), together with the bookkeeping that nested parsers
/// use to map their source locations back into the top-level buffer.
/// ParserState holds it by reference, and nested parsers (see parseSymbol) are
/// constructed over the same instance.
struct SymbolState {
  // A map from attribute alias identifier to Attribute.
  llvm::StringMap<Attribute> attributeAliasDefinitions;

  // A map from type alias identifier to Type.
  llvm::StringMap<Type> typeAliasDefinitions;

  /// A set of locations into the main parser memory buffer for each of the
  /// active nested parsers. Given that some nested parsers, i.e. custom dialect
  /// parsers, operate on a temporary memory buffer, this provides an anchor
  /// point for emitting diagnostics. Entries are pushed/popped around the
  /// dialect callback in parseExtendedSymbol.
  SmallVector<llvm::SMLoc, 1> nestedParserLocs;

  /// The top-level lexer that contains the original memory buffer provided by
  /// the user. This is used by nested parsers to get a properly encoded source
  /// location. Set by the first ParserState constructed over this state and
  /// cleared again by that ParserState's destructor.
  Lexer *topLevelLexer = nullptr;
};

//===----------------------------------------------------------------------===//
// ParserState
//===----------------------------------------------------------------------===//

/// This class refers to all of the state maintained globally by the parser,
/// such as the current lexer position etc.
struct ParserState {
  /// Create a parser state over `sourceMgr`. The lexer is primed with its first
  /// token immediately. The nesting depth of this parser is the number of
  /// nested-parser anchors already recorded in `symbols`, and if no top-level
  /// lexer has been registered yet, this state's lexer becomes the top level.
  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
              SymbolState &symbols)
      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
    // Set the top level lexer for the symbol state if one doesn't exist.
    if (!symbols.topLevelLexer)
      symbols.topLevelLexer = &lex;
  }
  ~ParserState() {
    // Reset the top level lexer if it refers to the lexer in our state, so that
    // the shared symbol state never keeps a pointer to our destroyed lexer.
    if (symbols.topLevelLexer == &lex)
      symbols.topLevelLexer = nullptr;
  }
  ParserState(const ParserState &) = delete;
  void operator=(const ParserState &) = delete;

  /// The context we're parsing into.
  MLIRContext *const context;

  /// The lexer for the source file we're parsing.
  /// NOTE: `curToken` is initialized from `lex` in the constructor, so `lex`
  /// must remain declared before `curToken`.
  Lexer lex;

  /// This is the next token that hasn't been consumed yet.
  Token curToken;

  /// The current state for symbol parsing.
  SymbolState &symbols;

  /// The depth of this parser in the nested parsing stack. Zero means this is
  /// the top-level parser.
  size_t parserDepth;
};

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// This class implements support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
/// specialized subparsers that carry additional state, e.g. a local symbol
/// table.
class Parser {
public:
  Builder builder;

  Parser(ParserState &state) : builder(state.context), state(state) {}

  // Helper methods to get stuff from the parser-global state.
  ParserState &getState() const { return state; }
  MLIRContext *getContext() const { return state.context; }
  const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }

  /// Parse a comma-separated list of elements up until the specified end token.
  ParseResult
  parseCommaSeparatedListUntil(Token::Kind rightToken,
                               const std::function<ParseResult()> &parseElement,
                               bool allowEmptyList = true);

  /// Parse a comma separated list of elements that must have at least one entry
  /// in it.
  ParseResult
  parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);

  /// Scan the balanced `<...>` body of a pretty dialect symbol starting at the
  /// current '<' token and extend `prettyName` to cover it.
  ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);

  // We have two forms of parsing methods - those that return a non-null
  // pointer on success, and those that return a ParseResult to indicate whether
  // they returned a failure.  The second class fills in by-reference arguments
  // as the results of their action.

  //===--------------------------------------------------------------------===//
  // Error Handling
  //===--------------------------------------------------------------------===//

  /// Emit an error at the location of the current token. The returned
  /// diagnostic converts to a failed ParseResult, so callers can simply
  /// `return emitError(...)`.
  InFlightDiagnostic emitError(const Twine &message = {}) {
    return emitError(state.curToken.getLoc(), message);
  }
  InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});

  /// Encode the specified source location information into an attribute for
  /// attachment to the IR.
  Location getEncodedSourceLocation(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can get the encoded source
    // location directly.
    if (state.parserDepth == 0)
      return state.lex.getEncodedSourceLocation(loc);
    // Otherwise, we need to re-encode it to point to the top level buffer.
    return state.symbols.topLevelLexer->getEncodedSourceLocation(
        remapLocationToTopLevelBuffer(loc));
  }

  /// Remaps the given SMLoc to the top level lexer of the parser. This is used
  /// to adjust locations of potentially nested parsers to ensure that they can
  /// be emitted properly as diagnostics.
  llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
    // If there are no active nested parsers, we can return location directly.
    SymbolState &symbols = state.symbols;
    if (state.parserDepth == 0)
      return loc;
    assert(symbols.topLevelLexer && "expected valid top-level lexer");

    // Otherwise, we need to remap the location to the main parser. This is
    // simply offsetting the location onto the location of the last nested
    // parser. The anchor in `nestedParserLocs` was itself already remapped to
    // the top-level buffer when it was recorded, so one hop is sufficient.
    // NOTE(review): this is exact only when the nested buffer is a
    // byte-for-byte copy of the top-level text starting at the anchor.
    size_t offset = loc.getPointer() - state.lex.getBufferBegin();
    auto *rawLoc =
        symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
    return llvm::SMLoc::getFromPointer(rawLoc);
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//

  /// Return the current token the parser is inspecting.
  const Token &getToken() const { return state.curToken; }
  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }

  /// If the current token has the specified kind, consume it and return true.
  /// If not, return false.
  bool consumeIf(Token::Kind kind) {
    if (state.curToken.isNot(kind))
      return false;
    consumeToken(kind);
    return true;
  }

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(state.curToken.isNot(Token::eof, Token::error) &&
           "shouldn't advance past EOF or errors");
    state.curToken = state.lex.lexToken();
  }

  /// Advance the current lexer onto the next token, asserting what the expected
  /// current token is.  This is preferred to the above method because it leads
  /// to more self-documenting code with better checking.
  void consumeToken(Token::Kind kind) {
    assert(state.curToken.is(kind) && "consumed an unexpected token");
    consumeToken();
  }

  /// Consume the specified token if present and return success.  On failure,
  /// output a diagnostic and return failure.
  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
  ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);

  /// Optionally parse a type.
  OptionalParseResult parseOptionalType(Type &type);

  /// Parse an arbitrary type.
  Type parseType();

  /// Parse a complex type.
  Type parseComplexType();

  /// Parse an extended type.
  Type parseExtendedType();

  /// Parse a function type.
  Type parseFunctionType();

  /// Parse a memref type.
  Type parseMemRefType();

  /// Parse a non function type.
  Type parseNonFunctionType();

  /// Parse a tensor type.
  Type parseTensorType();

  /// Parse a tuple type.
  Type parseTupleType();

  /// Parse a vector type.
  VectorType parseVectorType();
  ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true);
  ParseResult parseXInDimensionList();

  /// Parse strided layout specification.
  ParseResult parseStridedLayout(int64_t &offset,
                                 SmallVectorImpl<int64_t> &strides);

  // Parse a brace-delimited list of comma-separated integers with `?` as an
  // unknown marker.
  ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute with an optional type.
  Attribute parseAttribute(Type type = {});

  /// Parse an attribute dictionary.
  ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);

  /// Parse an extended attribute.
  Attribute parseExtendedAttr(Type type);

  /// Parse a float attribute.
  Attribute parseFloatAttr(Type type, bool isNegative);

  /// Parse a decimal or a hexadecimal literal, which can be either an integer
  /// or a float attribute.
  Attribute parseDecOrHexAttr(Type type, bool isNegative);

  /// Parse an opaque elements attribute.
  Attribute parseOpaqueElementsAttr(Type attrType);

  /// Parse a dense elements attribute.
  Attribute parseDenseElementsAttr(Type attrType);
  ShapedType parseElementsLiteralType(Type type);

  /// Parse a sparse elements attribute.
  Attribute parseSparseElementsAttr(Type attrType);

  //===--------------------------------------------------------------------===//
  // Location Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an inline location.
  ParseResult parseLocation(LocationAttr &loc);

  /// Parse a raw location instance.
  ParseResult parseLocationInstance(LocationAttr &loc);

  /// Parse a callsite location instance.
  ParseResult parseCallSiteLocation(LocationAttr &loc);

  /// Parse a fused location instance.
  ParseResult parseFusedLocation(LocationAttr &loc);

  /// Parse a name or FileLineCol location instance.
  ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);

  /// Parse an optional trailing location. If the current token is not `loc`,
  /// this succeeds without consuming anything and leaves `loc` unchanged.
  ///
  ///   trailing-location     ::= (`loc` `(` location `)`)?
  ///
  ParseResult parseOptionalTrailingLocation(Location &loc) {
    // If there is a 'loc' we parse a trailing location.
    if (!getToken().is(Token::kw_loc))
      return success();

    // Parse the location.
    LocationAttr directLoc;
    if (parseLocation(directLoc))
      return failure();
    loc = directLoc;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Affine Parsing
  //===--------------------------------------------------------------------===//

  /// Parse a reference to either an affine map, or an integer set.
  ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
                                                  IntegerSet &set);
  ParseResult parseAffineMapReference(AffineMap &map);
  ParseResult parseIntegerSetReference(IntegerSet &set);

  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
  ParseResult
  parseAffineMapOfSSAIds(AffineMap &map,
                         function_ref<ParseResult(bool)> parseElement,
                         OpAsmParser::Delimiter delimiter);

private:
  /// The Parser is subclassed and reinstantiated.  Do not add additional
  /// non-trivial state here, add it to the ParserState class.
  ParserState &state;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Parse a comma separated list of elements that must have at least one entry
/// in it. Stops at the first element that fails to parse.
ParseResult Parser::parseCommaSeparatedList(
    const std::function<ParseResult()> &parseElement) {
  // Non-empty case starts with an element.
  if (parseElement())
    return failure();

  // Otherwise we have a list of comma separated elements.
  while (consumeIf(Token::comma)) {
    if (parseElement())
      return failure();
  }
  return success();
}

/// Parse a comma-separated list of elements, terminated with an arbitrary
/// token.  This allows empty lists if allowEmptyList is true. The terminating
/// token is consumed on success.
///
///   abstract-list ::= rightToken                  // if allowEmptyList == true
///   abstract-list ::= element (',' element)* rightToken
///
ParseResult Parser::parseCommaSeparatedListUntil(
    Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
    bool allowEmptyList) {
  // Handle the empty case.
  if (getToken().is(rightToken)) {
    if (!allowEmptyList)
      return emitError("expected list element");
    consumeToken(rightToken);
    return success();
  }

  if (parseCommaSeparatedList(parseElement) ||
      parseToken(rightToken, "expected ',' or '" +
                                 Token::getTokenSpelling(rightToken) + "'"))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// DialectAsmParser
//===----------------------------------------------------------------------===//

namespace {
/// This class provides the main implementation of the DialectAsmParser that
/// allows for dialects to parse attributes and types. This allows for dialect
/// hooking into the main MLIR parsing logic. Every hook forwards to the wrapped
/// `Parser`, so diagnostics go through Parser::emitError and its location
/// remapping.
class CustomDialectAsmParser : public DialectAsmParser {
public:
  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
        parser(parser) {}
  ~CustomDialectAsmParser() override {}

  /// Emit a diagnostic at the specified location and return failure.
  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
    return parser.emitError(loc, message);
  }

  /// Return a builder which provides useful access to MLIRContext, global
  /// objects like types and attributes.
  Builder &getBuilder() const override { return parser.builder; }

  /// Get the location of the next token and store it into the argument.  This
  /// always succeeds.
  llvm::SMLoc getCurrentLocation() override {
    return parser.getToken().getLoc();
  }

  /// Return the location of the original name token.
  llvm::SMLoc getNameLoc() const override { return nameLoc; }

  /// Re-encode the given source location as an MLIR location and return it.
  Location getEncodedSourceLoc(llvm::SMLoc loc) override {
    return parser.getEncodedSourceLocation(loc);
  }

  /// Returns the full specification of the symbol being parsed. This allows
  /// for using a separate parser if necessary.
  StringRef getFullSymbolSpec() const override { return fullSpec; }

  /// Parse a floating point value from the stream. An optional leading '-'
  /// negates the result; note that the '-' is consumed even if no float
  /// literal follows it.
  ParseResult parseFloat(double &result) override {
    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();

    // Check for a floating point value.
    if (curTok.is(Token::floatliteral)) {
      auto val = curTok.getFloatingPointValue();
      if (!val.hasValue())
        return emitError(curTok.getLoc(), "floating point value too large");
      parser.consumeToken(Token::floatliteral);
      result = negative ? -*val : *val;
      return success();
    }

    // TODO(riverriddle) support hex floating point values.
    return emitError(getCurrentLocation(), "expected floating point literal");
  }

  /// Parse an optional integer value from the stream. Returns None (consuming
  /// nothing) unless the current token is an integer or '-'. A negated value
  /// is produced by unsigned negation, i.e. it wraps modulo 2^64.
  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
    Token curToken = parser.getToken();
    if (curToken.isNot(Token::integer, Token::minus))
      return llvm::None;

    bool negative = parser.consumeIf(Token::minus);
    Token curTok = parser.getToken();
    if (parser.parseToken(Token::integer, "expected integer value"))
      return failure();

    auto val = curTok.getUInt64IntegerValue();
    if (!val)
      return emitError(curTok.getLoc(), "integer value too large");
    result = negative ? -*val : *val;
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Token Parsing
  //===--------------------------------------------------------------------===//
  // The `parseX` hooks emit a diagnostic when the token is missing; the
  // `parseOptionalX` hooks consume the token if present and otherwise return
  // failure without emitting anything.

  /// Parse a `->` token.
  ParseResult parseArrow() override {
    return parser.parseToken(Token::arrow, "expected '->'");
  }

  /// Parses a `->` if present.
  ParseResult parseOptionalArrow() override {
    return success(parser.consumeIf(Token::arrow));
  }

  /// Parse a '{' token.
  ParseResult parseLBrace() override {
    return parser.parseToken(Token::l_brace, "expected '{'");
  }

  /// Parse a '{' token if present
  ParseResult parseOptionalLBrace() override {
    return success(parser.consumeIf(Token::l_brace));
  }

  /// Parse a `}` token.
  ParseResult parseRBrace() override {
    return parser.parseToken(Token::r_brace, "expected '}'");
  }

  /// Parse a `}` token if present
  ParseResult parseOptionalRBrace() override {
    return success(parser.consumeIf(Token::r_brace));
  }

  /// Parse a `:` token.
  ParseResult parseColon() override {
    return parser.parseToken(Token::colon, "expected ':'");
  }

  /// Parse a `:` token if present.
  ParseResult parseOptionalColon() override {
    return success(parser.consumeIf(Token::colon));
  }

  /// Parse a `,` token.
  ParseResult parseComma() override {
    return parser.parseToken(Token::comma, "expected ','");
  }

  /// Parse a `,` token if present.
  ParseResult parseOptionalComma() override {
    return success(parser.consumeIf(Token::comma));
  }

  /// Parses a `...` if present.
  ParseResult parseOptionalEllipsis() override {
    return success(parser.consumeIf(Token::ellipsis));
  }

  /// Parse a `=` token.
  ParseResult parseEqual() override {
    return parser.parseToken(Token::equal, "expected '='");
  }

  /// Parse a '<' token.
  ParseResult parseLess() override {
    return parser.parseToken(Token::less, "expected '<'");
  }

  /// Parse a `<` token if present.
  ParseResult parseOptionalLess() override {
    return success(parser.consumeIf(Token::less));
  }

  /// Parse a '>' token.
  ParseResult parseGreater() override {
    return parser.parseToken(Token::greater, "expected '>'");
  }

  /// Parse a `>` token if present.
  ParseResult parseOptionalGreater() override {
    return success(parser.consumeIf(Token::greater));
  }

  /// Parse a `(` token.
  ParseResult parseLParen() override {
    return parser.parseToken(Token::l_paren, "expected '('");
  }

  /// Parses a '(' if present.
  ParseResult parseOptionalLParen() override {
    return success(parser.consumeIf(Token::l_paren));
  }

  /// Parse a `)` token.
  ParseResult parseRParen() override {
    return parser.parseToken(Token::r_paren, "expected ')'");
  }

  /// Parses a ')' if present.
  ParseResult parseOptionalRParen() override {
    return success(parser.consumeIf(Token::r_paren));
  }

  /// Parse a `[` token.
  ParseResult parseLSquare() override {
    return parser.parseToken(Token::l_square, "expected '['");
  }

  /// Parses a '[' if present.
  ParseResult parseOptionalLSquare() override {
    return success(parser.consumeIf(Token::l_square));
  }

  /// Parse a `]` token.
  ParseResult parseRSquare() override {
    return parser.parseToken(Token::r_square, "expected ']'");
  }

  /// Parses a ']' if present.
  ParseResult parseOptionalRSquare() override {
    return success(parser.consumeIf(Token::r_square));
  }

  /// Parses a '?' if present.
  ParseResult parseOptionalQuestion() override {
    return success(parser.consumeIf(Token::question));
  }

  /// Parses a '*' if present.
  ParseResult parseOptionalStar() override {
    return success(parser.consumeIf(Token::star));
  }

  /// Returns if the current token corresponds to a keyword, i.e. it is either
  /// a bare identifier or one of the lexer's reserved keywords.
  bool isCurrentTokenAKeyword() const {
    return parser.getToken().is(Token::bare_identifier) ||
           parser.getToken().isKeyword();
  }

  /// Parse the given keyword if present. Nothing is consumed on failure.
  ParseResult parseOptionalKeyword(StringRef keyword) override {
    // Check that the current token has the same spelling.
    if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
      return failure();
    parser.consumeToken();
    return success();
  }

  /// Parse a keyword, if present, into 'keyword'. Nothing is consumed, and
  /// 'keyword' is left untouched, on failure.
  ParseResult parseOptionalKeyword(StringRef *keyword) override {
    // Check that the current token is a keyword.
    if (!isCurrentTokenAKeyword())
      return failure();

    *keyword = parser.getTokenSpelling();
    parser.consumeToken();
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Attribute Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary attribute and return it in result.
  ParseResult parseAttribute(Attribute &result, Type type) override {
    result = parser.parseAttribute(type);
    return success(static_cast<bool>(result));
  }

  /// Parse an affine map instance into 'map'.
  ParseResult parseAffineMap(AffineMap &map) override {
    return parser.parseAffineMapReference(map);
  }

  /// Parse an integer set instance into 'set'.
  /// NOTE(review): despite the `print` prefix this *parses*; the name is
  /// dictated by the DialectAsmParser virtual it overrides, so renaming it
  /// requires changing the base class as well.
  ParseResult printIntegerSet(IntegerSet &set) override {
    return parser.parseIntegerSetReference(set);
  }

  //===--------------------------------------------------------------------===//
  // Type Parsing
  //===--------------------------------------------------------------------===//

  /// Parse an arbitrary type and return it in result.
  ParseResult parseType(Type &result) override {
    result = parser.parseType();
    return success(static_cast<bool>(result));
  }

  /// Parse a ranked dimension list, optionally allowing dynamic dimensions.
  ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic) override {
    return parser.parseDimensionListRanked(dimensions, allowDynamic);
  }

private:
  /// The full symbol specification.
  StringRef fullSpec;

  /// The source location of the dialect symbol.
  SMLoc nameLoc;

  /// The main parser.
  Parser &parser;
};
} // namespace

/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
/// and may be recursive.  Return with the 'prettyName' StringRef encompassing
/// the entire pretty name.
///
///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
///                                  | '(' pretty-dialect-sym-contents+ ')'
///                                  | '[' pretty-dialect-sym-contents+ ']'
///                                  | '{' pretty-dialect-sym-contents+ '}'
///                                  | '[^[<({>\])}\0]+'
///
/// On entry the current token must be the '<' and 'prettyName' must begin at
/// the start of the symbol name (its begin() becomes the start of the result).
/// The body is scanned character-by-character, bypassing the lexer; on success
/// the lexer is reset to just past the matching '>' and the next token after
/// the body becomes the current token. Errors are reported at the location of
/// the opening '<' token.
ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
  // Pretty symbol names are a relatively unstructured format that contains a
  // series of properly nested punctuation, with anything else in the middle.
  // Scan ahead to find it and consume it if successful, otherwise emit an
  // error.
  auto *curPtr = getTokenSpelling().data();

  SmallVector<char, 8> nestedPunctuation;

  // Scan over the nested punctuation, bailing out on error and consuming until
  // we find the end.  We know that we're currently looking at the '<', so we
  // can go until we find the matching '>' character.
  assert(*curPtr == '<');
  do {
    char c = *curPtr++;
    switch (c) {
    case '\0':
      // This also handles the EOF case.
      return emitError("unexpected nul or EOF in pretty dialect name");
    case '<':
    case '[':
    case '(':
    case '{':
      nestedPunctuation.push_back(c);
      continue;

    case '-':
      // The sequence `->` is treated as special token, so its '>' must not be
      // taken as a closing bracket.
      if (*curPtr == '>')
        ++curPtr;
      continue;

    // The stack is never empty here: the loop only continues while it is
    // non-empty, so pop_back_val() is safe.
    case '>':
      if (nestedPunctuation.pop_back_val() != '<')
        return emitError("unbalanced '>' character in pretty dialect name");
      break;
    case ']':
      if (nestedPunctuation.pop_back_val() != '[')
        return emitError("unbalanced ']' character in pretty dialect name");
      break;
    case ')':
      if (nestedPunctuation.pop_back_val() != '(')
        return emitError("unbalanced ')' character in pretty dialect name");
      break;
    case '}':
      if (nestedPunctuation.pop_back_val() != '{')
        return emitError("unbalanced '}' character in pretty dialect name");
      break;

    default:
      continue;
    }
  } while (!nestedPunctuation.empty());

  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
  // consuming all this stuff, and return.
  state.lex.resetPointer(curPtr);

  unsigned length = curPtr - prettyName.begin();
  prettyName = StringRef(prettyName.begin(), length);
  consumeToken();
  return success();
}

/// Parse an extended dialect symbol. The current token must be of kind
/// `identifierTok`; its spelling minus the leading sigil is the identifier.
/// Three forms are handled:
///   - alias:   no '.' in the identifier and no following '<' -> look the
///              identifier up in `aliases`;
///   - verbose: no '.' in the identifier, followed by `<"data">`;
///   - pretty:  `dialect.name` optionally followed, with no intervening
///              whitespace, by a pretty-dialect-sym-body.
/// For the latter two, `createSymbol(dialectName, symbolData, loc)` builds the
/// result while an anchor for nested-parser locations is active.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
                                  SymbolAliasMap &aliases,
                                  CreateFn &&createSymbol) {
  // Parse the dialect namespace.
  StringRef identifier = p.getTokenSpelling().drop_front();
  auto loc = p.getToken().getLoc();
  p.consumeToken(identifierTok);

  // If there is no '<' token following this, and if the typename contains no
  // dot, then we are parsing a symbol alias.
  if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
    // Check for an alias for this type.
    auto aliasIt = aliases.find(identifier);
    if (aliasIt == aliases.end())
      return (p.emitError("undefined symbol alias id '" + identifier + "'"),
              nullptr);
    return aliasIt->second;
  }

  // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
  // a dot, then this is the "pretty" form.  If not, it is the verbose form that
  // looks like <"...">.
  std::string symbolData;
  auto dialectName = identifier;

  // Handle the verbose form, where "identifier" is a simple dialect name.
  if (!identifier.contains('.')) {
    // Consume the '<'.
    if (p.parseToken(Token::less, "expected '<' in dialect type"))
      return nullptr;

    // Parse the symbol specific data.
    if (p.getToken().isNot(Token::string))
      return (p.emitError("expected string literal data in dialect symbol"),
              nullptr);
    symbolData = p.getToken().getStringValue();
    // Point just past the opening quote of the string literal.
    // NOTE(review): if getStringValue() unescapes the literal, offsets in the
    // nested buffer past an escape sequence will not line up with the
    // top-level text when remapped — TODO confirm.
    loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
    p.consumeToken(Token::string);

    // Consume the '>'.
    if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
      return nullptr;
  } else {
    // Ok, the dialect name is the part of the identifier before the dot, the
    // part after the dot is the dialect's symbol, or the start thereof.
    auto dotHalves = identifier.split('.');
    dialectName = dotHalves.first;
    auto prettyName = dotHalves.second;
    loc = llvm::SMLoc::getFromPointer(prettyName.data());

    // If the dialect's symbol is followed immediately by a <, then lex the body
    // of it into prettyName.
    if (p.getToken().is(Token::less) &&
        prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
      if (p.parsePrettyDialectSymbolName(prettyName))
        return nullptr;
    }

    symbolData = prettyName.str();
  }

  // Record the name location of the type remapped to the top level buffer.
  // Any ParserState constructed inside `createSymbol` takes its depth from the
  // size of `nestedParserLocs`, so it will remap its locations onto this
  // anchor.
  llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
  p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);

  // Call into the provided symbol construction function.
  Symbol sym = createSymbol(dialectName, symbolData, loc);

  // Pop the last parser location.
  p.getState().symbols.nestedParserLocs.pop_back();
  return sym;
}

/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
/// parsing failed, nullptr is returned. The number of bytes read from the input
/// string is returned in 'numRead'.
///
/// A fresh SourceMgr/ParserState is created over `inputStr` (which need not be
/// null-terminated), sharing `symbolState` with the caller. When 'numRead' is
/// null, leftover input is diagnosed only if the parser advanced at all and
/// did not reach EOF.
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
                     SymbolState &symbolState, ParserFn &&parserFn,
                     size_t *numRead = nullptr) {
  SourceMgr sourceMgr;
  auto memBuffer = MemoryBuffer::getMemBuffer(
      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
      /*RequiresNullTerminator=*/false);
  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  ParserState state(sourceMgr, context, symbolState);
  Parser parser(state);

  Token startTok = parser.getToken();
  T symbol = parserFn(parser);
  if (!symbol)
    return T();

  // If 'numRead' is valid, then provide the number of bytes that were read.
  Token endTok = parser.getToken();
  if (numRead) {
    *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
                                   startTok.getLoc().getPointer());

    // Otherwise, ensure that all of the tokens were parsed.
  } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
    parser.emitError(endTok.getLoc(), "encountered unexpected token");
    return T();
  }
  return symbol;
}

//===----------------------------------------------------------------------===//
// Error Handling
//===----------------------------------------------------------------------===//

/// Emit an error at `loc`, encoded (and, for nested parsers, remapped) via
/// getEncodedSourceLocation.
InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
  auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);

  // If we hit a parse error in response to a lexer error, then the lexer
  // already reported the error.
  if (getToken().is(Token::error))
    diag.abandon();
  return diag;
}

//===----------------------------------------------------------------------===//
// Token Parsing
//===----------------------------------------------------------------------===//

/// Consume the specified token if present and return success.  On failure,
/// output a diagnostic and return failure.
ParseResult Parser::parseToken(Token::Kind expectedToken,
                               const Twine &message) {
  if (consumeIf(expectedToken))
    return success();
  return emitError(message);
}

//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//

/// Optionally parse a type. Returns None without consuming anything if the
/// current token cannot start a type; otherwise parses a full type and
/// reports whether that succeeded.
OptionalParseResult Parser::parseOptionalType(Type &type) {
  // There are many different starting tokens for a type, check them here.
  switch (getToken().getKind()) {
  case Token::l_paren:
  case Token::kw_memref:
  case Token::kw_tensor:
  case Token::kw_complex:
  case Token::kw_tuple:
  case Token::kw_vector:
  case Token::inttype:
  case Token::kw_bf16:
  case Token::kw_f16:
  case Token::kw_f32:
  case Token::kw_f64:
  case Token::kw_index:
  case Token::kw_none:
  case Token::exclamation_identifier:
    return failure(!(type = parseType()));

  default:
    return llvm::None;
  }
}

/// Parse an arbitrary type.
///
///   type ::= function-type
///          | non-function-type
///
Type Parser::parseType() {
  if (getToken().is(Token::l_paren))
    return parseFunctionType();
  return parseNonFunctionType();
}

/// Parse a function result type.
///
///   function-result-type ::= type-list-parens
///                          | non-function-type
///
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
  if (getToken().is(Token::l_paren))
    return parseTypeListParens(elements);

  Type t = parseNonFunctionType();
  if (!t)
    return failure();
  elements.push_back(t);
  return success();
}

/// Parse a list of types without an enclosing parenthesis.  The list must have
/// at least one member.
///
///   type-list-no-parens ::=  type (`,` type)*
///
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
  auto parseElt = [&]() -> ParseResult {
    auto elt = parseType();
    // NOTE(review): the element is appended before it is checked, so on
    // failure `elements` ends with a null Type.
    elements.push_back(elt);
    return elt ? success() : failure();
  };

  return parseCommaSeparatedList(parseElt);
}

/// Parse a parenthesized list of types.
///
///   type-list-parens ::= `(` `)`
///                      | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
  if (parseToken(Token::l_paren, "expected '('"))
    return failure();

  // Handle empty lists.
  if (getToken().is(Token::r_paren))
    return consumeToken(), success();

  if (parseTypeListNoParens(elements) ||
      parseToken(Token::r_paren, "expected ')'"))
    return failure();
  return success();
}

/// Parse a complex type. The element type must be a FloatType or an
/// IntegerType; this is checked after the closing '>' has been parsed and the
/// error is reported at the element type's location.
///
///   complex-type ::= `complex` `<` type `>`
///
Type Parser::parseComplexType() {
  consumeToken(Token::kw_complex);

  // Parse the '<'.
  if (parseToken(Token::less, "expected '<' in complex type"))
    return nullptr;

  llvm::SMLoc elementTypeLoc = getToken().getLoc();
  auto elementType = parseType();
  if (!elementType ||
      parseToken(Token::greater, "expected '>' in complex type"))
    return nullptr;
  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
    return emitError(elementTypeLoc, "invalid element type for complex"),
           nullptr;

  return ComplexType::get(elementType);
}

/// Parse an extended type.
///
///   extended-type ::= (dialect-type | type-alias)
///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
///   dialect-type  ::= `!` alias-name pretty-dialect-sym-body?
///   type-alias    ::= `!` alias-name
///
/// For a registered dialect, the symbol data is re-parsed by a nested parser
/// and handed to the dialect's parseType hook.
Type Parser::parseExtendedType() {
  return parseExtendedSymbol<Type>(
      *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
      [&](StringRef dialectName, StringRef symbolData,
          llvm::SMLoc loc) -> Type {
        // If we found a registered dialect, then ask it to parse the type.
        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
          return parseSymbol<Type>(
              symbolData, state.context, state.symbols, [&](Parser &parser) {
                CustomDialectAsmParser customParser(symbolData, parser);
                return dialect->parseType(customParser);
              });
        }

        // Otherwise, form a new opaque type.
1036 return OpaqueType::getChecked( 1037 Identifier::get(dialectName, state.context), symbolData, 1038 state.context, getEncodedSourceLocation(loc)); 1039 }); 1040 } 1041 1042 /// Parse a function type. 1043 /// 1044 /// function-type ::= type-list-parens `->` function-result-type 1045 /// 1046 Type Parser::parseFunctionType() { 1047 assert(getToken().is(Token::l_paren)); 1048 1049 SmallVector<Type, 4> arguments, results; 1050 if (parseTypeListParens(arguments) || 1051 parseToken(Token::arrow, "expected '->' in function type") || 1052 parseFunctionResultTypes(results)) 1053 return nullptr; 1054 1055 return builder.getFunctionType(arguments, results); 1056 } 1057 1058 /// Parse the offset and strides from a strided layout specification. 1059 /// 1060 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 1061 /// 1062 ParseResult Parser::parseStridedLayout(int64_t &offset, 1063 SmallVectorImpl<int64_t> &strides) { 1064 // Parse offset. 1065 consumeToken(Token::kw_offset); 1066 if (!consumeIf(Token::colon)) 1067 return emitError("expected colon after `offset` keyword"); 1068 auto maybeOffset = getToken().getUnsignedIntegerValue(); 1069 bool question = getToken().is(Token::question); 1070 if (!maybeOffset && !question) 1071 return emitError("invalid offset"); 1072 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 1073 : MemRefType::getDynamicStrideOrOffset(); 1074 consumeToken(); 1075 1076 if (!consumeIf(Token::comma)) 1077 return emitError("expected comma after offset value"); 1078 1079 // Parse stride list. 
1080 if (!consumeIf(Token::kw_strides)) 1081 return emitError("expected `strides` keyword after offset specification"); 1082 if (!consumeIf(Token::colon)) 1083 return emitError("expected colon after `strides` keyword"); 1084 if (failed(parseStrideList(strides))) 1085 return emitError("invalid braces-enclosed stride list"); 1086 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 1087 return emitError("invalid memref stride"); 1088 1089 return success(); 1090 } 1091 1092 /// Parse a memref type. 1093 /// 1094 /// memref-type ::= ranked-memref-type | unranked-memref-type 1095 /// 1096 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 1097 /// (`,` semi-affine-map-composition)? (`,` 1098 /// memory-space)? `>` 1099 /// 1100 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 1101 /// 1102 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 1103 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 1104 /// 1105 Type Parser::parseMemRefType() { 1106 consumeToken(Token::kw_memref); 1107 1108 if (parseToken(Token::less, "expected '<' in memref type")) 1109 return nullptr; 1110 1111 bool isUnranked; 1112 SmallVector<int64_t, 4> dimensions; 1113 1114 if (consumeIf(Token::star)) { 1115 // This is an unranked memref type. 1116 isUnranked = true; 1117 if (parseXInDimensionList()) 1118 return nullptr; 1119 1120 } else { 1121 isUnranked = false; 1122 if (parseDimensionListRanked(dimensions)) 1123 return nullptr; 1124 } 1125 1126 // Parse the element type. 1127 auto typeLoc = getToken().getLoc(); 1128 auto elementType = parseType(); 1129 if (!elementType) 1130 return nullptr; 1131 1132 // Check that memref is formed from allowed types. 1133 if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && 1134 !elementType.isa<ComplexType>()) 1135 return emitError(typeLoc, "invalid memref element type"), nullptr; 1136 1137 // Parse semi-affine-map-composition. 
1138 SmallVector<AffineMap, 2> affineMapComposition; 1139 Optional<unsigned> memorySpace; 1140 unsigned numDims = dimensions.size(); 1141 1142 auto parseElt = [&]() -> ParseResult { 1143 // Check for the memory space. 1144 if (getToken().is(Token::integer)) { 1145 if (memorySpace) 1146 return emitError("multiple memory spaces specified in memref type"); 1147 memorySpace = getToken().getUnsignedIntegerValue(); 1148 if (!memorySpace.hasValue()) 1149 return emitError("invalid memory space in memref type"); 1150 consumeToken(Token::integer); 1151 return success(); 1152 } 1153 if (isUnranked) 1154 return emitError("cannot have affine map for unranked memref type"); 1155 if (memorySpace) 1156 return emitError("expected memory space to be last in memref type"); 1157 1158 AffineMap map; 1159 llvm::SMLoc mapLoc = getToken().getLoc(); 1160 if (getToken().is(Token::kw_offset)) { 1161 int64_t offset; 1162 SmallVector<int64_t, 4> strides; 1163 if (failed(parseStridedLayout(offset, strides))) 1164 return failure(); 1165 // Construct strided affine map. 1166 map = makeStridedLinearLayoutMap(strides, offset, state.context); 1167 } else { 1168 // Parse an affine map attribute. 1169 auto affineMap = parseAttribute(); 1170 if (!affineMap) 1171 return failure(); 1172 auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>(); 1173 if (!affineMapAttr) 1174 return emitError("expected affine map in memref type"); 1175 map = affineMapAttr.getValue(); 1176 } 1177 1178 if (map.getNumDims() != numDims) { 1179 size_t i = affineMapComposition.size(); 1180 return emitError(mapLoc, "memref affine map dimension mismatch between ") 1181 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 1182 << " and affine map" << i + 1 << ": " << numDims 1183 << " != " << map.getNumDims(); 1184 } 1185 numDims = map.getNumResults(); 1186 affineMapComposition.push_back(map); 1187 return success(); 1188 }; 1189 1190 // Parse a list of mappings and address space if present. 
1191 if (!consumeIf(Token::greater)) { 1192 // Parse comma separated list of affine maps, followed by memory space. 1193 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 1194 parseCommaSeparatedListUntil(Token::greater, parseElt, 1195 /*allowEmptyList=*/false)) { 1196 return nullptr; 1197 } 1198 } 1199 1200 if (isUnranked) 1201 return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); 1202 1203 return MemRefType::get(dimensions, elementType, affineMapComposition, 1204 memorySpace.getValueOr(0)); 1205 } 1206 1207 /// Parse any type except the function type. 1208 /// 1209 /// non-function-type ::= integer-type 1210 /// | index-type 1211 /// | float-type 1212 /// | extended-type 1213 /// | vector-type 1214 /// | tensor-type 1215 /// | memref-type 1216 /// | complex-type 1217 /// | tuple-type 1218 /// | none-type 1219 /// 1220 /// index-type ::= `index` 1221 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 1222 /// none-type ::= `none` 1223 /// 1224 Type Parser::parseNonFunctionType() { 1225 switch (getToken().getKind()) { 1226 default: 1227 return (emitError("expected non-function type"), nullptr); 1228 case Token::kw_memref: 1229 return parseMemRefType(); 1230 case Token::kw_tensor: 1231 return parseTensorType(); 1232 case Token::kw_complex: 1233 return parseComplexType(); 1234 case Token::kw_tuple: 1235 return parseTupleType(); 1236 case Token::kw_vector: 1237 return parseVectorType(); 1238 // integer-type 1239 case Token::inttype: { 1240 auto width = getToken().getIntTypeBitwidth(); 1241 if (!width.hasValue()) 1242 return (emitError("invalid integer width"), nullptr); 1243 if (width.getValue() > IntegerType::kMaxWidth) { 1244 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 1245 << IntegerType::kMaxWidth << " bits"; 1246 return nullptr; 1247 } 1248 1249 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 1250 if (Optional<bool> signedness = getToken().getIntTypeSignedness()) 1251 signSemantics 
= *signedness ? IntegerType::Signed : IntegerType::Unsigned; 1252 1253 auto loc = getEncodedSourceLocation(getToken().getLoc()); 1254 consumeToken(Token::inttype); 1255 return IntegerType::getChecked(width.getValue(), signSemantics, loc); 1256 } 1257 1258 // float-type 1259 case Token::kw_bf16: 1260 consumeToken(Token::kw_bf16); 1261 return builder.getBF16Type(); 1262 case Token::kw_f16: 1263 consumeToken(Token::kw_f16); 1264 return builder.getF16Type(); 1265 case Token::kw_f32: 1266 consumeToken(Token::kw_f32); 1267 return builder.getF32Type(); 1268 case Token::kw_f64: 1269 consumeToken(Token::kw_f64); 1270 return builder.getF64Type(); 1271 1272 // index-type 1273 case Token::kw_index: 1274 consumeToken(Token::kw_index); 1275 return builder.getIndexType(); 1276 1277 // none-type 1278 case Token::kw_none: 1279 consumeToken(Token::kw_none); 1280 return builder.getNoneType(); 1281 1282 // extended type 1283 case Token::exclamation_identifier: 1284 return parseExtendedType(); 1285 } 1286 } 1287 1288 /// Parse a tensor type. 1289 /// 1290 /// tensor-type ::= `tensor` `<` dimension-list type `>` 1291 /// dimension-list ::= dimension-list-ranked | `*x` 1292 /// 1293 Type Parser::parseTensorType() { 1294 consumeToken(Token::kw_tensor); 1295 1296 if (parseToken(Token::less, "expected '<' in tensor type")) 1297 return nullptr; 1298 1299 bool isUnranked; 1300 SmallVector<int64_t, 4> dimensions; 1301 1302 if (consumeIf(Token::star)) { 1303 // This is an unranked tensor type. 1304 isUnranked = true; 1305 1306 if (parseXInDimensionList()) 1307 return nullptr; 1308 1309 } else { 1310 isUnranked = false; 1311 if (parseDimensionListRanked(dimensions)) 1312 return nullptr; 1313 } 1314 1315 // Parse the element type. 
1316 auto elementTypeLoc = getToken().getLoc(); 1317 auto elementType = parseType(); 1318 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 1319 return nullptr; 1320 if (!TensorType::isValidElementType(elementType)) 1321 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 1322 1323 if (isUnranked) 1324 return UnrankedTensorType::get(elementType); 1325 return RankedTensorType::get(dimensions, elementType); 1326 } 1327 1328 /// Parse a tuple type. 1329 /// 1330 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 1331 /// 1332 Type Parser::parseTupleType() { 1333 consumeToken(Token::kw_tuple); 1334 1335 // Parse the '<'. 1336 if (parseToken(Token::less, "expected '<' in tuple type")) 1337 return nullptr; 1338 1339 // Check for an empty tuple by directly parsing '>'. 1340 if (consumeIf(Token::greater)) 1341 return TupleType::get(getContext()); 1342 1343 // Parse the element types and the '>'. 1344 SmallVector<Type, 4> types; 1345 if (parseTypeListNoParens(types) || 1346 parseToken(Token::greater, "expected '>' in tuple type")) 1347 return nullptr; 1348 1349 return TupleType::get(types, getContext()); 1350 } 1351 1352 /// Parse a vector type. 
1353 /// 1354 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 1355 /// non-empty-static-dimension-list ::= decimal-literal `x` 1356 /// static-dimension-list 1357 /// static-dimension-list ::= (decimal-literal `x`)* 1358 /// 1359 VectorType Parser::parseVectorType() { 1360 consumeToken(Token::kw_vector); 1361 1362 if (parseToken(Token::less, "expected '<' in vector type")) 1363 return nullptr; 1364 1365 SmallVector<int64_t, 4> dimensions; 1366 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 1367 return nullptr; 1368 if (dimensions.empty()) 1369 return (emitError("expected dimension size in vector type"), nullptr); 1370 if (any_of(dimensions, [](int64_t i) { return i <= 0; })) 1371 return emitError(getToken().getLoc(), 1372 "vector types must have positive constant sizes"), 1373 nullptr; 1374 1375 // Parse the element type. 1376 auto typeLoc = getToken().getLoc(); 1377 auto elementType = parseType(); 1378 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 1379 return nullptr; 1380 if (!VectorType::isValidElementType(elementType)) 1381 return emitError(typeLoc, "vector elements must be int or float type"), 1382 nullptr; 1383 1384 return VectorType::get(dimensions, elementType); 1385 } 1386 1387 /// Parse a dimension list of a tensor or memref type. This populates the 1388 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 1389 /// errors out on `?` otherwise. 
1390 /// 1391 /// dimension-list-ranked ::= (dimension `x`)* 1392 /// dimension ::= `?` | decimal-literal 1393 /// 1394 /// When `allowDynamic` is not set, this is used to parse: 1395 /// 1396 /// static-dimension-list ::= (decimal-literal `x`)* 1397 ParseResult 1398 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 1399 bool allowDynamic) { 1400 while (getToken().isAny(Token::integer, Token::question)) { 1401 if (consumeIf(Token::question)) { 1402 if (!allowDynamic) 1403 return emitError("expected static shape"); 1404 dimensions.push_back(-1); 1405 } else { 1406 // Hexadecimal integer literals (starting with `0x`) are not allowed in 1407 // aggregate type declarations. Therefore, `0xf32` should be processed as 1408 // a sequence of separate elements `0`, `x`, `f32`. 1409 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 1410 // We can get here only if the token is an integer literal. Hexadecimal 1411 // integer literals can only start with `0x` (`1x` wouldn't lex as a 1412 // literal, just `1` would, at which point we don't get into this 1413 // branch). 1414 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 1415 dimensions.push_back(0); 1416 state.lex.resetPointer(getTokenSpelling().data() + 1); 1417 consumeToken(); 1418 } else { 1419 // Make sure this integer value is in bound and valid. 1420 auto dimension = getToken().getUnsignedIntegerValue(); 1421 if (!dimension.hasValue()) 1422 return emitError("invalid dimension"); 1423 dimensions.push_back((int64_t)dimension.getValue()); 1424 consumeToken(Token::integer); 1425 } 1426 } 1427 1428 // Make sure we have an 'x' or something like 'xbf32'. 1429 if (parseXInDimensionList()) 1430 return failure(); 1431 } 1432 1433 return success(); 1434 } 1435 1436 /// Parse an 'x' token in a dimension list, handling the case where the x is 1437 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 1438 /// token. 
1439 ParseResult Parser::parseXInDimensionList() { 1440 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 1441 return emitError("expected 'x' in dimension list"); 1442 1443 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 1444 if (getTokenSpelling().size() != 1) 1445 state.lex.resetPointer(getTokenSpelling().data() + 1); 1446 1447 // Consume the 'x'. 1448 consumeToken(Token::bare_identifier); 1449 1450 return success(); 1451 } 1452 1453 // Parse a comma-separated list of dimensions, possibly empty: 1454 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 1455 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 1456 if (!consumeIf(Token::l_square)) 1457 return failure(); 1458 // Empty list early exit. 1459 if (consumeIf(Token::r_square)) 1460 return success(); 1461 while (true) { 1462 if (consumeIf(Token::question)) { 1463 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1464 } else { 1465 // This must be an integer value. 1466 int64_t val; 1467 if (getToken().getSpelling().getAsInteger(10, val)) 1468 return emitError("invalid integer value: ") << getToken().getSpelling(); 1469 // Make sure it is not the one value for `?`. 1470 if (ShapedType::isDynamic(val)) 1471 return emitError("invalid integer value: ") 1472 << getToken().getSpelling() 1473 << ", use `?` to specify a dynamic dimension"; 1474 dimensions.push_back(val); 1475 consumeToken(Token::integer); 1476 } 1477 if (!consumeIf(Token::comma)) 1478 break; 1479 } 1480 if (!consumeIf(Token::r_square)) 1481 return failure(); 1482 return success(); 1483 } 1484 1485 //===----------------------------------------------------------------------===// 1486 // Attribute parsing. 1487 //===----------------------------------------------------------------------===// 1488 1489 /// Return the symbol reference referred to by the given token, that is known to 1490 /// be an @-identifier. 
1491 static std::string extractSymbolReference(Token tok) { 1492 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1493 StringRef nameStr = tok.getSpelling().drop_front(); 1494 1495 // Check to see if the reference is a string literal, or a bare identifier. 1496 if (nameStr.front() == '"') 1497 return tok.getStringValue(); 1498 return std::string(nameStr); 1499 } 1500 1501 /// Parse an arbitrary attribute. 1502 /// 1503 /// attribute-value ::= `unit` 1504 /// | bool-literal 1505 /// | integer-literal (`:` (index-type | integer-type))? 1506 /// | float-literal (`:` float-type)? 1507 /// | string-literal (`:` type)? 1508 /// | type 1509 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1510 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1511 /// | symbol-ref-id (`::` symbol-ref-id)* 1512 /// | `dense` `<` attribute-value `>` `:` 1513 /// (tensor-type | vector-type) 1514 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1515 /// `:` (tensor-type | vector-type) 1516 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1517 /// `>` `:` (tensor-type | vector-type) 1518 /// | extended-attribute 1519 /// 1520 Attribute Parser::parseAttribute(Type type) { 1521 switch (getToken().getKind()) { 1522 // Parse an AffineMap or IntegerSet attribute. 
1523 case Token::kw_affine_map: { 1524 consumeToken(Token::kw_affine_map); 1525 1526 AffineMap map; 1527 if (parseToken(Token::less, "expected '<' in affine map") || 1528 parseAffineMapReference(map) || 1529 parseToken(Token::greater, "expected '>' in affine map")) 1530 return Attribute(); 1531 return AffineMapAttr::get(map); 1532 } 1533 case Token::kw_affine_set: { 1534 consumeToken(Token::kw_affine_set); 1535 1536 IntegerSet set; 1537 if (parseToken(Token::less, "expected '<' in integer set") || 1538 parseIntegerSetReference(set) || 1539 parseToken(Token::greater, "expected '>' in integer set")) 1540 return Attribute(); 1541 return IntegerSetAttr::get(set); 1542 } 1543 1544 // Parse an array attribute. 1545 case Token::l_square: { 1546 consumeToken(Token::l_square); 1547 1548 SmallVector<Attribute, 4> elements; 1549 auto parseElt = [&]() -> ParseResult { 1550 elements.push_back(parseAttribute()); 1551 return elements.back() ? success() : failure(); 1552 }; 1553 1554 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1555 return nullptr; 1556 return builder.getArrayAttr(elements); 1557 } 1558 1559 // Parse a boolean attribute. 1560 case Token::kw_false: 1561 consumeToken(Token::kw_false); 1562 return builder.getBoolAttr(false); 1563 case Token::kw_true: 1564 consumeToken(Token::kw_true); 1565 return builder.getBoolAttr(true); 1566 1567 // Parse a dense elements attribute. 1568 case Token::kw_dense: 1569 return parseDenseElementsAttr(type); 1570 1571 // Parse a dictionary attribute. 1572 case Token::l_brace: { 1573 SmallVector<NamedAttribute, 4> elements; 1574 if (parseAttributeDict(elements)) 1575 return nullptr; 1576 return builder.getDictionaryAttr(elements); 1577 } 1578 1579 // Parse an extended attribute, i.e. alias or dialect attribute. 1580 case Token::hash_identifier: 1581 return parseExtendedAttr(type); 1582 1583 // Parse floating point and integer attributes. 
1584 case Token::floatliteral: 1585 return parseFloatAttr(type, /*isNegative=*/false); 1586 case Token::integer: 1587 return parseDecOrHexAttr(type, /*isNegative=*/false); 1588 case Token::minus: { 1589 consumeToken(Token::minus); 1590 if (getToken().is(Token::integer)) 1591 return parseDecOrHexAttr(type, /*isNegative=*/true); 1592 if (getToken().is(Token::floatliteral)) 1593 return parseFloatAttr(type, /*isNegative=*/true); 1594 1595 return (emitError("expected constant integer or floating point value"), 1596 nullptr); 1597 } 1598 1599 // Parse a location attribute. 1600 case Token::kw_loc: { 1601 LocationAttr attr; 1602 return failed(parseLocation(attr)) ? Attribute() : attr; 1603 } 1604 1605 // Parse an opaque elements attribute. 1606 case Token::kw_opaque: 1607 return parseOpaqueElementsAttr(type); 1608 1609 // Parse a sparse elements attribute. 1610 case Token::kw_sparse: 1611 return parseSparseElementsAttr(type); 1612 1613 // Parse a string attribute. 1614 case Token::string: { 1615 auto val = getToken().getStringValue(); 1616 consumeToken(Token::string); 1617 // Parse the optional trailing colon type if one wasn't explicitly provided. 1618 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1619 return Attribute(); 1620 1621 return type ? StringAttr::get(val, type) 1622 : StringAttr::get(val, getContext()); 1623 } 1624 1625 // Parse a symbol reference attribute. 1626 case Token::at_identifier: { 1627 std::string nameStr = extractSymbolReference(getToken()); 1628 consumeToken(Token::at_identifier); 1629 1630 // Parse any nested references. 1631 std::vector<FlatSymbolRefAttr> nestedRefs; 1632 while (getToken().is(Token::colon)) { 1633 // Check for the '::' prefix. 1634 const char *curPointer = getToken().getLoc().getPointer(); 1635 consumeToken(Token::colon); 1636 if (!consumeIf(Token::colon)) { 1637 state.lex.resetPointer(curPointer); 1638 consumeToken(); 1639 break; 1640 } 1641 // Parse the reference itself. 
1642 auto curLoc = getToken().getLoc(); 1643 if (getToken().isNot(Token::at_identifier)) { 1644 emitError(curLoc, "expected nested symbol reference identifier"); 1645 return Attribute(); 1646 } 1647 1648 std::string nameStr = extractSymbolReference(getToken()); 1649 consumeToken(Token::at_identifier); 1650 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); 1651 } 1652 1653 return builder.getSymbolRefAttr(nameStr, nestedRefs); 1654 } 1655 1656 // Parse a 'unit' attribute. 1657 case Token::kw_unit: 1658 consumeToken(Token::kw_unit); 1659 return builder.getUnitAttr(); 1660 1661 default: 1662 // Parse a type attribute. 1663 if (Type type = parseType()) 1664 return TypeAttr::get(type); 1665 return nullptr; 1666 } 1667 } 1668 1669 /// Attribute dictionary. 1670 /// 1671 /// attribute-dict ::= `{` `}` 1672 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1673 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 1674 /// 1675 ParseResult 1676 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1677 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) 1678 return failure(); 1679 1680 auto parseElt = [&]() -> ParseResult { 1681 // The name of an attribute can either be a bare identifier, or a string. 1682 Optional<Identifier> nameId; 1683 if (getToken().is(Token::string)) 1684 nameId = builder.getIdentifier(getToken().getStringValue()); 1685 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 1686 getToken().isKeyword()) 1687 nameId = builder.getIdentifier(getTokenSpelling()); 1688 else 1689 return emitError("expected attribute name"); 1690 consumeToken(); 1691 1692 // Try to parse the '=' for the attribute value. 1693 if (!consumeIf(Token::equal)) { 1694 // If there is no '=', we treat this as a unit attribute. 
1695 attributes.push_back({*nameId, builder.getUnitAttr()}); 1696 return success(); 1697 } 1698 1699 auto attr = parseAttribute(); 1700 if (!attr) 1701 return failure(); 1702 1703 attributes.push_back({*nameId, attr}); 1704 return success(); 1705 }; 1706 1707 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1708 return failure(); 1709 1710 return success(); 1711 } 1712 1713 /// Parse an extended attribute. 1714 /// 1715 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1716 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1717 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1718 /// attribute-alias ::= `#` alias-name 1719 /// 1720 Attribute Parser::parseExtendedAttr(Type type) { 1721 Attribute attr = parseExtendedSymbol<Attribute>( 1722 *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, 1723 [&](StringRef dialectName, StringRef symbolData, 1724 llvm::SMLoc loc) -> Attribute { 1725 // Parse an optional trailing colon type. 1726 Type attrType = type; 1727 if (consumeIf(Token::colon) && !(attrType = parseType())) 1728 return Attribute(); 1729 1730 // If we found a registered dialect, then ask it to parse the attribute. 1731 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { 1732 return parseSymbol<Attribute>( 1733 symbolData, state.context, state.symbols, [&](Parser &parser) { 1734 CustomDialectAsmParser customParser(symbolData, parser); 1735 return dialect->parseAttribute(customParser, attrType); 1736 }); 1737 } 1738 1739 // Otherwise, form a new opaque attribute. 1740 return OpaqueAttr::getChecked( 1741 Identifier::get(dialectName, state.context), symbolData, 1742 attrType ? attrType : NoneType::get(state.context), 1743 getEncodedSourceLocation(loc)); 1744 }); 1745 1746 // Ensure that the attribute has the same type as requested. 
1747 if (attr && type && attr.getType() != type) { 1748 emitError("attribute type different than expected: expected ") 1749 << type << ", but got " << attr.getType(); 1750 return nullptr; 1751 } 1752 return attr; 1753 } 1754 1755 /// Parse a float attribute. 1756 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1757 auto val = getToken().getFloatingPointValue(); 1758 if (!val.hasValue()) 1759 return (emitError("floating point value too large for attribute"), nullptr); 1760 consumeToken(Token::floatliteral); 1761 if (!type) { 1762 // Default to F64 when no type is specified. 1763 if (!consumeIf(Token::colon)) 1764 type = builder.getF64Type(); 1765 else if (!(type = parseType())) 1766 return nullptr; 1767 } 1768 if (!type.isa<FloatType>()) 1769 return (emitError("floating point value not valid for specified type"), 1770 nullptr); 1771 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1772 } 1773 1774 /// Construct a float attribute bitwise equivalent to the integer literal. 1775 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1776 uint64_t value) { 1777 // FIXME: bfloat is currently stored as a double internally because it doesn't 1778 // have valid APFloat semantics. 1779 if (type.isF64() || type.isBF16()) 1780 return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value)); 1781 1782 APInt apInt(type.getWidth(), value); 1783 if (apInt != value) { 1784 p->emitError("hexadecimal float constant out of range for type"); 1785 return llvm::None; 1786 } 1787 return APFloat(type.getFloatSemantics(), apInt); 1788 } 1789 1790 /// Construct an APint from a parsed value, a known attribute type and 1791 /// sign. 1792 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 1793 StringRef spelling) { 1794 // Parse the integer value into an APInt that is big enough to hold the value. 
1795 APInt result; 1796 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1797 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 1798 return llvm::None; 1799 1800 // Extend or truncate the bitwidth to the right size. 1801 unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1802 if (width > result.getBitWidth()) { 1803 result = result.zext(width); 1804 } else if (width < result.getBitWidth()) { 1805 // The parser can return an unnecessarily wide result with leading zeros. 1806 // This isn't a problem, but truncating off bits is bad. 1807 if (result.countLeadingZeros() < result.getBitWidth() - width) 1808 return llvm::None; 1809 1810 result = result.trunc(width); 1811 } 1812 1813 if (isNegative) { 1814 // The value is negative, we have an overflow if the sign bit is not set 1815 // in the negated apInt. 1816 result.negate(); 1817 if (!result.isSignBitSet()) 1818 return llvm::None; 1819 } else if ((type.isSignedInteger() || type.isIndex()) && 1820 result.isSignBitSet()) { 1821 // The value is a positive signed integer or index, 1822 // we have an overflow if the sign bit is set. 1823 return llvm::None; 1824 } 1825 1826 return result; 1827 } 1828 1829 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1830 /// or a float attribute. 1831 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1832 // Remember if the literal is hexadecimal. 1833 StringRef spelling = getToken().getSpelling(); 1834 auto loc = state.curToken.getLoc(); 1835 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1836 1837 consumeToken(Token::integer); 1838 if (!type) { 1839 // Default to i64 if not type is specified. 
1840 if (!consumeIf(Token::colon)) 1841 type = builder.getIntegerType(64); 1842 else if (!(type = parseType())) 1843 return nullptr; 1844 } 1845 1846 if (auto floatType = type.dyn_cast<FloatType>()) { 1847 if (isNegative) 1848 return emitError( 1849 loc, 1850 "hexadecimal float literal should not have a leading minus"), 1851 nullptr; 1852 if (!isHex) { 1853 emitError(loc, "unexpected decimal integer literal for a float attribute") 1854 .attachNote() 1855 << "add a trailing dot to make the literal a float"; 1856 return nullptr; 1857 } 1858 1859 auto val = Token::getUInt64IntegerValue(spelling); 1860 if (!val.hasValue()) 1861 return emitError("integer constant out of range for attribute"), nullptr; 1862 1863 // Construct a float attribute bitwise equivalent to the integer literal. 1864 Optional<APFloat> apVal = 1865 buildHexadecimalFloatLiteral(this, floatType, *val); 1866 return apVal ? FloatAttr::get(floatType, *apVal) : Attribute(); 1867 } 1868 1869 if (!type.isa<IntegerType>() && !type.isa<IndexType>()) 1870 return emitError(loc, "integer literal not valid for specified type"), 1871 nullptr; 1872 1873 if (isNegative && type.isUnsignedInteger()) { 1874 emitError(loc, 1875 "negative integer literal not valid for unsigned integer type"); 1876 return nullptr; 1877 } 1878 1879 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 1880 if (!apInt) 1881 return emitError(loc, "integer constant out of range for attribute"), 1882 nullptr; 1883 return builder.getIntegerAttr(type, *apInt); 1884 } 1885 1886 /// Parse elements values stored within a hex etring. On success, the values are 1887 /// stored into 'result'. 
1888 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 1889 std::string &result) { 1890 std::string val = tok.getStringValue(); 1891 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1892 return parser.emitError(tok.getLoc(), 1893 "elements hex string should start with '0x'"); 1894 1895 StringRef hexValues = StringRef(val).drop_front(2); 1896 if (!llvm::all_of(hexValues, llvm::isHexDigit)) 1897 return parser.emitError(tok.getLoc(), 1898 "elements hex string only contains hex digits"); 1899 1900 result = llvm::fromHex(hexValues); 1901 return success(); 1902 } 1903 1904 /// Parse an opaque elements attribute. 1905 Attribute Parser::parseOpaqueElementsAttr(Type attrType) { 1906 consumeToken(Token::kw_opaque); 1907 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1908 return nullptr; 1909 1910 if (getToken().isNot(Token::string)) 1911 return (emitError("expected dialect namespace"), nullptr); 1912 1913 auto name = getToken().getStringValue(); 1914 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1915 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1916 // attribute. Otherwise, it can't be roundtripped without having the dialect 1917 // registered. 
1918 if (!dialect) 1919 return (emitError("no registered dialect with namespace '" + name + "'"), 1920 nullptr); 1921 consumeToken(Token::string); 1922 1923 if (parseToken(Token::comma, "expected ','")) 1924 return nullptr; 1925 1926 Token hexTok = getToken(); 1927 if (parseToken(Token::string, "elements hex string should start with '0x'") || 1928 parseToken(Token::greater, "expected '>'")) 1929 return nullptr; 1930 auto type = parseElementsLiteralType(attrType); 1931 if (!type) 1932 return nullptr; 1933 1934 std::string data; 1935 if (parseElementAttrHexValues(*this, hexTok, data)) 1936 return nullptr; 1937 return OpaqueElementsAttr::get(dialect, type, data); 1938 } 1939 1940 namespace { 1941 class TensorLiteralParser { 1942 public: 1943 TensorLiteralParser(Parser &p) : p(p) {} 1944 1945 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 1946 /// may also parse a tensor literal that is store as a hex string. 1947 ParseResult parse(bool allowHex); 1948 1949 /// Build a dense attribute instance with the parsed elements and the given 1950 /// shaped type. 1951 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1952 1953 ArrayRef<int64_t> getShape() const { return shape; } 1954 1955 private: 1956 enum class ElementKind { Boolean, Integer, Float }; 1957 1958 /// Return a string to represent the given element kind. 1959 const char *getElementKindStr(ElementKind kind) { 1960 switch (kind) { 1961 case ElementKind::Boolean: 1962 return "'boolean'"; 1963 case ElementKind::Integer: 1964 return "'integer'"; 1965 case ElementKind::Float: 1966 return "'float'"; 1967 } 1968 llvm_unreachable("unknown element kind"); 1969 } 1970 1971 /// Build a Dense Integer attribute for the given type. 1972 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1973 IntegerType eltTy); 1974 1975 /// Build a Dense Float attribute for the given type. 
1976 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1977 FloatType eltTy); 1978 1979 /// Build a Dense attribute with hex data for the given type. 1980 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); 1981 1982 /// Parse a single element, returning failure if it isn't a valid element 1983 /// literal. For example: 1984 /// parseElement(1) -> Success, 1 1985 /// parseElement([1]) -> Failure 1986 ParseResult parseElement(); 1987 1988 /// Parse a list of either lists or elements, returning the dimensions of the 1989 /// parsed sub-tensors in dims. For example: 1990 /// parseList([1, 2, 3]) -> Success, [3] 1991 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1992 /// parseList([[1, 2], 3]) -> Failure 1993 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1994 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 1995 1996 /// Parse a literal that was printed as a hex string. 1997 ParseResult parseHexElements(); 1998 1999 Parser &p; 2000 2001 /// The shape inferred from the parsed elements. 2002 SmallVector<int64_t, 4> shape; 2003 2004 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 2005 std::vector<std::pair<bool, Token>> storage; 2006 2007 /// A flag that indicates the type of elements that have been parsed. 2008 Optional<ElementKind> knownEltKind; 2009 2010 /// Storage used when parsing elements that were stored as hex values. 2011 Optional<Token> hexStorage; 2012 }; 2013 } // namespace 2014 2015 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 2016 /// may also parse a tensor literal that is store as a hex string. 2017 ParseResult TensorLiteralParser::parse(bool allowHex) { 2018 // If hex is allowed, check for a string literal. 2019 if (allowHex && p.getToken().is(Token::string)) { 2020 hexStorage = p.getToken(); 2021 p.consumeToken(Token::string); 2022 return success(); 2023 } 2024 // Otherwise, parse a list or an individual element. 
2025 if (p.getToken().is(Token::l_square)) 2026 return parseList(shape); 2027 return parseElement(); 2028 } 2029 2030 /// Build a dense attribute instance with the parsed elements and the given 2031 /// shaped type. 2032 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 2033 ShapedType type) { 2034 // Check to see if we parsed the literal from a hex string. 2035 if (hexStorage.hasValue()) 2036 return getHexAttr(loc, type); 2037 2038 // Check that the parsed storage size has the same number of elements to the 2039 // type, or is a known splat. 2040 if (!shape.empty() && getShape() != type.getShape()) { 2041 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 2042 << "]) does not match type ([" << type.getShape() << "])"; 2043 return nullptr; 2044 } 2045 2046 // If the type is an integer, build a set of APInt values from the storage 2047 // with the correct bitwidth. 2048 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 2049 return getIntAttr(loc, type, intTy); 2050 2051 // Otherwise, this must be a floating point type. 2052 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 2053 if (!floatTy) { 2054 p.emitError(loc) << "expected floating-point or integer element type, got " 2055 << type.getElementType(); 2056 return nullptr; 2057 } 2058 return getFloatAttr(loc, type, floatTy); 2059 } 2060 2061 /// Build a Dense Integer attribute for the given type. 
2062 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 2063 ShapedType type, 2064 IntegerType eltTy) { 2065 std::vector<APInt> intElements; 2066 intElements.reserve(storage.size()); 2067 auto isUintType = type.getElementType().isUnsignedInteger(); 2068 for (const auto &signAndToken : storage) { 2069 bool isNegative = signAndToken.first; 2070 const Token &token = signAndToken.second; 2071 auto tokenLoc = token.getLoc(); 2072 2073 if (isNegative && isUintType) { 2074 p.emitError(tokenLoc) 2075 << "expected unsigned integer elements, but parsed negative value"; 2076 return nullptr; 2077 } 2078 2079 // Check to see if floating point values were parsed. 2080 if (token.is(Token::floatliteral)) { 2081 p.emitError(tokenLoc) 2082 << "expected integer elements, but parsed floating-point"; 2083 return nullptr; 2084 } 2085 2086 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 2087 "unexpected token type"); 2088 if (token.isAny(Token::kw_true, Token::kw_false)) { 2089 if (!eltTy.isInteger(1)) 2090 p.emitError(tokenLoc) 2091 << "expected i1 type for 'true' or 'false' values"; 2092 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 2093 /*isSigned=*/false); 2094 intElements.push_back(apInt); 2095 continue; 2096 } 2097 2098 // Create APInt values for each element with the correct bitwidth. 2099 Optional<APInt> apInt = 2100 buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 2101 if (!apInt) 2102 return (p.emitError(tokenLoc, "integer constant out of range for type"), 2103 nullptr); 2104 intElements.push_back(*apInt); 2105 } 2106 2107 return DenseElementsAttr::get(type, intElements); 2108 } 2109 2110 /// Build a Dense Float attribute for the given type. 
2111 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 2112 ShapedType type, 2113 FloatType eltTy) { 2114 std::vector<APFloat> floatValues; 2115 floatValues.reserve(storage.size()); 2116 for (const auto &signAndToken : storage) { 2117 bool isNegative = signAndToken.first; 2118 const Token &token = signAndToken.second; 2119 2120 // Handle hexadecimal float literals. 2121 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 2122 if (isNegative) { 2123 p.emitError(token.getLoc()) 2124 << "hexadecimal float literal should not have a leading minus"; 2125 return nullptr; 2126 } 2127 auto val = token.getUInt64IntegerValue(); 2128 if (!val.hasValue()) { 2129 p.emitError("hexadecimal float constant out of range for attribute"); 2130 return nullptr; 2131 } 2132 Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); 2133 if (!apVal) 2134 return nullptr; 2135 floatValues.push_back(*apVal); 2136 continue; 2137 } 2138 2139 // Check to see if any decimal integers or booleans were parsed. 2140 if (!token.is(Token::floatliteral)) { 2141 p.emitError() << "expected floating-point elements, but parsed integer"; 2142 return nullptr; 2143 } 2144 2145 // Build the float values from tokens. 2146 auto val = token.getFloatingPointValue(); 2147 if (!val.hasValue()) { 2148 p.emitError("floating point value too large for attribute"); 2149 return nullptr; 2150 } 2151 // Treat BF16 as double because it is not supported in LLVM's APFloat. 2152 APFloat apVal(isNegative ? -*val : *val); 2153 if (!eltTy.isBF16() && !eltTy.isF64()) { 2154 bool unused; 2155 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, 2156 &unused); 2157 } 2158 floatValues.push_back(apVal); 2159 } 2160 2161 return DenseElementsAttr::get(type, floatValues); 2162 } 2163 2164 /// Build a Dense attribute with hex data for the given type. 
2165 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, 2166 ShapedType type) { 2167 Type elementType = type.getElementType(); 2168 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { 2169 p.emitError(loc) << "expected floating-point or integer element type, got " 2170 << elementType; 2171 return nullptr; 2172 } 2173 2174 std::string data; 2175 if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) 2176 return nullptr; 2177 2178 // Check that the size of the hex data corresponds to the size of the type, or 2179 // a splat of the type. 2180 // TODO: bf16 is currently stored as a double, this should be removed when 2181 // APFloat properly supports it. 2182 int64_t elementWidth = 2183 elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth(); 2184 if (static_cast<int64_t>(data.size() * CHAR_BIT) != 2185 (type.getNumElements() * elementWidth)) { 2186 p.emitError(loc) << "elements hex data size is invalid for provided type: " 2187 << type; 2188 return nullptr; 2189 } 2190 2191 return DenseElementsAttr::getFromRawBuffer( 2192 type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false); 2193 } 2194 2195 ParseResult TensorLiteralParser::parseElement() { 2196 switch (p.getToken().getKind()) { 2197 // Parse a boolean element. 2198 case Token::kw_true: 2199 case Token::kw_false: 2200 case Token::floatliteral: 2201 case Token::integer: 2202 storage.emplace_back(/*isNegative=*/false, p.getToken()); 2203 p.consumeToken(); 2204 break; 2205 2206 // Parse a signed integer or a negative floating-point element. 
2207 case Token::minus: 2208 p.consumeToken(Token::minus); 2209 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 2210 return p.emitError("expected integer or floating point literal"); 2211 storage.emplace_back(/*isNegative=*/true, p.getToken()); 2212 p.consumeToken(); 2213 break; 2214 2215 default: 2216 return p.emitError("expected element literal of primitive type"); 2217 } 2218 2219 return success(); 2220 } 2221 2222 /// Parse a list of either lists or elements, returning the dimensions of the 2223 /// parsed sub-tensors in dims. For example: 2224 /// parseList([1, 2, 3]) -> Success, [3] 2225 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 2226 /// parseList([[1, 2], 3]) -> Failure 2227 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 2228 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 2229 p.consumeToken(Token::l_square); 2230 2231 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 2232 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 2233 if (prevDims == newDims) 2234 return success(); 2235 return p.emitError("tensor literal is invalid; ranks are not consistent " 2236 "between elements"); 2237 }; 2238 2239 bool first = true; 2240 SmallVector<int64_t, 4> newDims; 2241 unsigned size = 0; 2242 auto parseCommaSeparatedList = [&]() -> ParseResult { 2243 SmallVector<int64_t, 4> thisDims; 2244 if (p.getToken().getKind() == Token::l_square) { 2245 if (parseList(thisDims)) 2246 return failure(); 2247 } else if (parseElement()) { 2248 return failure(); 2249 } 2250 ++size; 2251 if (!first) 2252 return checkDims(newDims, thisDims); 2253 newDims = thisDims; 2254 first = false; 2255 return success(); 2256 }; 2257 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 2258 return failure(); 2259 2260 // Return the sublists' dimensions with 'size' prepended. 
2261 dims.clear(); 2262 dims.push_back(size); 2263 dims.append(newDims.begin(), newDims.end()); 2264 return success(); 2265 } 2266 2267 /// Parse a dense elements attribute. 2268 Attribute Parser::parseDenseElementsAttr(Type attrType) { 2269 consumeToken(Token::kw_dense); 2270 if (parseToken(Token::less, "expected '<' after 'dense'")) 2271 return nullptr; 2272 2273 // Parse the literal data. 2274 TensorLiteralParser literalParser(*this); 2275 if (literalParser.parse(/*allowHex=*/true)) 2276 return nullptr; 2277 2278 if (parseToken(Token::greater, "expected '>'")) 2279 return nullptr; 2280 2281 auto typeLoc = getToken().getLoc(); 2282 auto type = parseElementsLiteralType(attrType); 2283 if (!type) 2284 return nullptr; 2285 return literalParser.getAttr(typeLoc, type); 2286 } 2287 2288 /// Shaped type for elements attribute. 2289 /// 2290 /// elements-literal-type ::= vector-type | ranked-tensor-type 2291 /// 2292 /// This method also checks the type has static shape. 2293 ShapedType Parser::parseElementsLiteralType(Type type) { 2294 // If the user didn't provide a type, parse the colon type for the literal. 2295 if (!type) { 2296 if (parseToken(Token::colon, "expected ':'")) 2297 return nullptr; 2298 if (!(type = parseType())) 2299 return nullptr; 2300 } 2301 2302 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 2303 emitError("elements literal must be a ranked tensor or vector type"); 2304 return nullptr; 2305 } 2306 2307 auto sType = type.cast<ShapedType>(); 2308 if (!sType.hasStaticShape()) 2309 return (emitError("elements literal type must have static shape"), nullptr); 2310 2311 return sType; 2312 } 2313 2314 /// Parse a sparse elements attribute. 2315 Attribute Parser::parseSparseElementsAttr(Type attrType) { 2316 consumeToken(Token::kw_sparse); 2317 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 2318 return nullptr; 2319 2320 /// Parse the indices. We don't allow hex values here as we may need to use 2321 /// the inferred shape. 
2322 auto indicesLoc = getToken().getLoc(); 2323 TensorLiteralParser indiceParser(*this); 2324 if (indiceParser.parse(/*allowHex=*/false)) 2325 return nullptr; 2326 2327 if (parseToken(Token::comma, "expected ','")) 2328 return nullptr; 2329 2330 /// Parse the values. 2331 auto valuesLoc = getToken().getLoc(); 2332 TensorLiteralParser valuesParser(*this); 2333 if (valuesParser.parse(/*allowHex=*/true)) 2334 return nullptr; 2335 2336 if (parseToken(Token::greater, "expected '>'")) 2337 return nullptr; 2338 2339 auto type = parseElementsLiteralType(attrType); 2340 if (!type) 2341 return nullptr; 2342 2343 // If the indices are a splat, i.e. the literal parser parsed an element and 2344 // not a list, we set the shape explicitly. The indices are represented by a 2345 // 2-dimensional shape where the second dimension is the rank of the type. 2346 // Given that the parsed indices is a splat, we know that we only have one 2347 // indice and thus one for the first dimension. 2348 auto indiceEltType = builder.getIntegerType(64); 2349 ShapedType indicesType; 2350 if (indiceParser.getShape().empty()) { 2351 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 2352 } else { 2353 // Otherwise, set the shape to the one parsed by the literal parser. 2354 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 2355 } 2356 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 2357 2358 // If the values are a splat, set the shape explicitly based on the number of 2359 // indices. The number of indices is encoded in the first dimension of the 2360 // indice shape type. 2361 auto valuesEltType = type.getElementType(); 2362 ShapedType valuesType = 2363 valuesParser.getShape().empty() 2364 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 2365 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 2366 auto values = valuesParser.getAttr(valuesLoc, valuesType); 2367 2368 /// Sanity check. 
2369 if (valuesType.getRank() != 1) 2370 return (emitError("expected 1-d tensor for values"), nullptr); 2371 2372 auto sameShape = (indicesType.getRank() == 1) || 2373 (type.getRank() == indicesType.getDimSize(1)); 2374 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 2375 if (!sameShape || !sameElementNum) { 2376 emitError() << "expected shape ([" << type.getShape() 2377 << "]); inferred shape of indices literal ([" 2378 << indicesType.getShape() 2379 << "]); inferred shape of values literal ([" 2380 << valuesType.getShape() << "])"; 2381 return nullptr; 2382 } 2383 2384 // Build the sparse elements attribute by the indices and values. 2385 return SparseElementsAttr::get(type, indices, values); 2386 } 2387 2388 //===----------------------------------------------------------------------===// 2389 // Location parsing. 2390 //===----------------------------------------------------------------------===// 2391 2392 /// Parse a location. 2393 /// 2394 /// location ::= `loc` inline-location 2395 /// inline-location ::= '(' location-inst ')' 2396 /// 2397 ParseResult Parser::parseLocation(LocationAttr &loc) { 2398 // Check for 'loc' identifier. 2399 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 2400 return emitError(); 2401 2402 // Parse the inline-location. 2403 if (parseToken(Token::l_paren, "expected '(' in inline location") || 2404 parseLocationInstance(loc) || 2405 parseToken(Token::r_paren, "expected ')' in inline location")) 2406 return failure(); 2407 return success(); 2408 } 2409 2410 /// Specific location instances. 
2411 /// 2412 /// location-inst ::= filelinecol-location | 2413 /// name-location | 2414 /// callsite-location | 2415 /// fused-location | 2416 /// unknown-location 2417 /// filelinecol-location ::= string-literal ':' integer-literal 2418 /// ':' integer-literal 2419 /// name-location ::= string-literal 2420 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 2421 /// fused-location ::= fused ('<' attribute-value '>')? 2422 /// '[' location-inst (location-inst ',')* ']' 2423 /// unknown-location ::= 'unknown' 2424 /// 2425 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 2426 consumeToken(Token::bare_identifier); 2427 2428 // Parse the '('. 2429 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 2430 return failure(); 2431 2432 // Parse the callee location. 2433 LocationAttr calleeLoc; 2434 if (parseLocationInstance(calleeLoc)) 2435 return failure(); 2436 2437 // Parse the 'at'. 2438 if (getToken().isNot(Token::bare_identifier) || 2439 getToken().getSpelling() != "at") 2440 return emitError("expected 'at' in callsite location"); 2441 consumeToken(Token::bare_identifier); 2442 2443 // Parse the caller location. 2444 LocationAttr callerLoc; 2445 if (parseLocationInstance(callerLoc)) 2446 return failure(); 2447 2448 // Parse the ')'. 2449 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 2450 return failure(); 2451 2452 // Return the callsite location. 2453 loc = CallSiteLoc::get(calleeLoc, callerLoc); 2454 return success(); 2455 } 2456 2457 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 2458 consumeToken(Token::bare_identifier); 2459 2460 // Try to parse the optional metadata. 2461 Attribute metadata; 2462 if (consumeIf(Token::less)) { 2463 metadata = parseAttribute(); 2464 if (!metadata) 2465 return emitError("expected valid attribute metadata"); 2466 // Parse the '>' token. 
2467 if (parseToken(Token::greater, 2468 "expected '>' after fused location metadata")) 2469 return failure(); 2470 } 2471 2472 SmallVector<Location, 4> locations; 2473 auto parseElt = [&] { 2474 LocationAttr newLoc; 2475 if (parseLocationInstance(newLoc)) 2476 return failure(); 2477 locations.push_back(newLoc); 2478 return success(); 2479 }; 2480 2481 if (parseToken(Token::l_square, "expected '[' in fused location") || 2482 parseCommaSeparatedList(parseElt) || 2483 parseToken(Token::r_square, "expected ']' in fused location")) 2484 return failure(); 2485 2486 // Return the fused location. 2487 loc = FusedLoc::get(locations, metadata, getContext()); 2488 return success(); 2489 } 2490 2491 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 2492 auto *ctx = getContext(); 2493 auto str = getToken().getStringValue(); 2494 consumeToken(Token::string); 2495 2496 // If the next token is ':' this is a filelinecol location. 2497 if (consumeIf(Token::colon)) { 2498 // Parse the line number. 2499 if (getToken().isNot(Token::integer)) 2500 return emitError("expected integer line number in FileLineColLoc"); 2501 auto line = getToken().getUnsignedIntegerValue(); 2502 if (!line.hasValue()) 2503 return emitError("expected integer line number in FileLineColLoc"); 2504 consumeToken(Token::integer); 2505 2506 // Parse the ':'. 2507 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 2508 return failure(); 2509 2510 // Parse the column number. 2511 if (getToken().isNot(Token::integer)) 2512 return emitError("expected integer column number in FileLineColLoc"); 2513 auto column = getToken().getUnsignedIntegerValue(); 2514 if (!column.hasValue()) 2515 return emitError("expected integer column number in FileLineColLoc"); 2516 consumeToken(Token::integer); 2517 2518 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 2519 return success(); 2520 } 2521 2522 // Otherwise, this is a NameLoc. 2523 2524 // Check for a child location. 
2525 if (consumeIf(Token::l_paren)) { 2526 auto childSourceLoc = getToken().getLoc(); 2527 2528 // Parse the child location. 2529 LocationAttr childLoc; 2530 if (parseLocationInstance(childLoc)) 2531 return failure(); 2532 2533 // The child must not be another NameLoc. 2534 if (childLoc.isa<NameLoc>()) 2535 return emitError(childSourceLoc, 2536 "child of NameLoc cannot be another NameLoc"); 2537 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 2538 2539 // Parse the closing ')'. 2540 if (parseToken(Token::r_paren, 2541 "expected ')' after child location of NameLoc")) 2542 return failure(); 2543 } else { 2544 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 2545 } 2546 2547 return success(); 2548 } 2549 2550 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 2551 // Handle either name or filelinecol locations. 2552 if (getToken().is(Token::string)) 2553 return parseNameOrFileLineColLocation(loc); 2554 2555 // Bare tokens required for other cases. 2556 if (!getToken().is(Token::bare_identifier)) 2557 return emitError("expected location instance"); 2558 2559 // Check for the 'callsite' signifying a callsite location. 2560 if (getToken().getSpelling() == "callsite") 2561 return parseCallSiteLocation(loc); 2562 2563 // If the token is 'fused', then this is a fused location. 2564 if (getToken().getSpelling() == "fused") 2565 return parseFusedLocation(loc); 2566 2567 // Check for a 'unknown' for an unknown location. 2568 if (getToken().getSpelling() == "unknown") { 2569 consumeToken(Token::bare_identifier); 2570 loc = UnknownLoc::get(getContext()); 2571 return success(); 2572 } 2573 2574 return emitError("expected location instance"); 2575 } 2576 2577 //===----------------------------------------------------------------------===// 2578 // Affine parsing. 2579 //===----------------------------------------------------------------------===// 2580 2581 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 2582 /// the boolean sense. 2583 enum AffineLowPrecOp { 2584 /// Null value. 2585 LNoOp, 2586 Add, 2587 Sub 2588 }; 2589 2590 /// Higher precedence ops - all at the same precedence level. HNoOp is false 2591 /// in the boolean sense. 2592 enum AffineHighPrecOp { 2593 /// Null value. 2594 HNoOp, 2595 Mul, 2596 FloorDiv, 2597 CeilDiv, 2598 Mod 2599 }; 2600 2601 namespace { 2602 /// This is a specialized parser for affine structures (affine maps, affine 2603 /// expressions, and integer sets), maintaining the state transient to their 2604 /// bodies. 2605 class AffineParser : public Parser { 2606 public: 2607 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 2608 function_ref<ParseResult(bool)> parseElement = nullptr) 2609 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 2610 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 2611 2612 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 2613 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 2614 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 2615 ParseResult parseAffineMapOfSSAIds(AffineMap &map, 2616 OpAsmParser::Delimiter delimiter); 2617 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 2618 unsigned &numDims); 2619 2620 private: 2621 // Binary affine op parsing. 2622 AffineLowPrecOp consumeIfLowPrecOp(); 2623 AffineHighPrecOp consumeIfHighPrecOp(); 2624 2625 // Identifier lists for polyhedral structures. 
2626 ParseResult parseDimIdList(unsigned &numDims); 2627 ParseResult parseSymbolIdList(unsigned &numSymbols); 2628 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2629 unsigned &numSymbols); 2630 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2631 2632 AffineExpr parseAffineExpr(); 2633 AffineExpr parseParentheticalExpr(); 2634 AffineExpr parseNegateExpression(AffineExpr lhs); 2635 AffineExpr parseIntegerExpr(); 2636 AffineExpr parseBareIdExpr(); 2637 AffineExpr parseSSAIdExpr(bool isSymbol); 2638 AffineExpr parseSymbolSSAIdExpr(); 2639 2640 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2641 AffineExpr rhs, SMLoc opLoc); 2642 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2643 AffineExpr rhs); 2644 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2645 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2646 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2647 SMLoc llhsOpLoc); 2648 AffineExpr parseAffineConstraint(bool *isEq); 2649 2650 private: 2651 bool allowParsingSSAIds; 2652 function_ref<ParseResult(bool)> parseElement; 2653 unsigned numDimOperands; 2654 unsigned numSymbolOperands; 2655 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2656 }; 2657 } // end anonymous namespace 2658 2659 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2660 /// opLoc is the location of the op token to be used to report errors 2661 /// for non-conforming expressions. 2662 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2663 AffineExpr lhs, AffineExpr rhs, 2664 SMLoc opLoc) { 2665 // TODO: make the error location info accurate. 
2666 switch (op) { 2667 case Mul: 2668 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2669 emitError(opLoc, "non-affine expression: at least one of the multiply " 2670 "operands has to be either a constant or symbolic"); 2671 return nullptr; 2672 } 2673 return lhs * rhs; 2674 case FloorDiv: 2675 if (!rhs.isSymbolicOrConstant()) { 2676 emitError(opLoc, "non-affine expression: right operand of floordiv " 2677 "has to be either a constant or symbolic"); 2678 return nullptr; 2679 } 2680 return lhs.floorDiv(rhs); 2681 case CeilDiv: 2682 if (!rhs.isSymbolicOrConstant()) { 2683 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2684 "has to be either a constant or symbolic"); 2685 return nullptr; 2686 } 2687 return lhs.ceilDiv(rhs); 2688 case Mod: 2689 if (!rhs.isSymbolicOrConstant()) { 2690 emitError(opLoc, "non-affine expression: right operand of mod " 2691 "has to be either a constant or symbolic"); 2692 return nullptr; 2693 } 2694 return lhs % rhs; 2695 case HNoOp: 2696 llvm_unreachable("can't create affine expression for null high prec op"); 2697 return nullptr; 2698 } 2699 llvm_unreachable("Unknown AffineHighPrecOp"); 2700 } 2701 2702 /// Create an affine binary low precedence op expression (add, sub). 2703 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2704 AffineExpr lhs, AffineExpr rhs) { 2705 switch (op) { 2706 case AffineLowPrecOp::Add: 2707 return lhs + rhs; 2708 case AffineLowPrecOp::Sub: 2709 return lhs - rhs; 2710 case AffineLowPrecOp::LNoOp: 2711 llvm_unreachable("can't create affine expression for null low prec op"); 2712 return nullptr; 2713 } 2714 llvm_unreachable("Unknown AffineLowPrecOp"); 2715 } 2716 2717 /// Consume this token if it is a lower precedence affine op (there are only 2718 /// two precedence levels). 
2719 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2720 switch (getToken().getKind()) { 2721 case Token::plus: 2722 consumeToken(Token::plus); 2723 return AffineLowPrecOp::Add; 2724 case Token::minus: 2725 consumeToken(Token::minus); 2726 return AffineLowPrecOp::Sub; 2727 default: 2728 return AffineLowPrecOp::LNoOp; 2729 } 2730 } 2731 2732 /// Consume this token if it is a higher precedence affine op (there are only 2733 /// two precedence levels) 2734 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2735 switch (getToken().getKind()) { 2736 case Token::star: 2737 consumeToken(Token::star); 2738 return Mul; 2739 case Token::kw_floordiv: 2740 consumeToken(Token::kw_floordiv); 2741 return FloorDiv; 2742 case Token::kw_ceildiv: 2743 consumeToken(Token::kw_ceildiv); 2744 return CeilDiv; 2745 case Token::kw_mod: 2746 consumeToken(Token::kw_mod); 2747 return Mod; 2748 default: 2749 return HNoOp; 2750 } 2751 } 2752 2753 /// Parse a high precedence op expression list: mul, div, and mod are high 2754 /// precedence binary ops, i.e., parse a 2755 /// expr_1 op_1 expr_2 op_2 ... expr_n 2756 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2757 /// All affine binary ops are left associative. 2758 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2759 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2760 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2761 /// report an error for non-conforming expressions. 2762 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2763 AffineHighPrecOp llhsOp, 2764 SMLoc llhsOpLoc) { 2765 AffineExpr lhs = parseAffineOperandExpr(llhs); 2766 if (!lhs) 2767 return nullptr; 2768 2769 // Found an LHS. Parse the remaining expression. 
2770 auto opLoc = getToken().getLoc(); 2771 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2772 if (llhs) { 2773 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2774 if (!expr) 2775 return nullptr; 2776 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2777 } 2778 // No LLHS, get RHS 2779 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2780 } 2781 2782 // This is the last operand in this expression. 2783 if (llhs) 2784 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2785 2786 // No llhs, 'lhs' itself is the expression. 2787 return lhs; 2788 } 2789 2790 /// Parse an affine expression inside parentheses. 2791 /// 2792 /// affine-expr ::= `(` affine-expr `)` 2793 AffineExpr AffineParser::parseParentheticalExpr() { 2794 if (parseToken(Token::l_paren, "expected '('")) 2795 return nullptr; 2796 if (getToken().is(Token::r_paren)) 2797 return (emitError("no expression inside parentheses"), nullptr); 2798 2799 auto expr = parseAffineExpr(); 2800 if (!expr) 2801 return nullptr; 2802 if (parseToken(Token::r_paren, "expected ')'")) 2803 return nullptr; 2804 2805 return expr; 2806 } 2807 2808 /// Parse the negation expression. 2809 /// 2810 /// affine-expr ::= `-` affine-expr 2811 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2812 if (parseToken(Token::minus, "expected '-'")) 2813 return nullptr; 2814 2815 AffineExpr operand = parseAffineOperandExpr(lhs); 2816 // Since negation has the highest precedence of all ops (including high 2817 // precedence ops) but lower than parentheses, we are only going to use 2818 // parseAffineOperandExpr instead of parseAffineExpr here. 2819 if (!operand) 2820 // Extra error message although parseAffineOperandExpr would have 2821 // complained. Leads to a better diagnostic. 2822 return (emitError("missing operand of negation"), nullptr); 2823 return (-1) * operand; 2824 } 2825 2826 /// Parse a bare id that may appear in an affine expression. 
2827 /// 2828 /// affine-expr ::= bare-id 2829 AffineExpr AffineParser::parseBareIdExpr() { 2830 if (getToken().isNot(Token::bare_identifier)) 2831 return (emitError("expected bare identifier"), nullptr); 2832 2833 StringRef sRef = getTokenSpelling(); 2834 for (auto entry : dimsAndSymbols) { 2835 if (entry.first == sRef) { 2836 consumeToken(Token::bare_identifier); 2837 return entry.second; 2838 } 2839 } 2840 2841 return (emitError("use of undeclared identifier"), nullptr); 2842 } 2843 2844 /// Parse an SSA id which may appear in an affine expression. 2845 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2846 if (!allowParsingSSAIds) 2847 return (emitError("unexpected ssa identifier"), nullptr); 2848 if (getToken().isNot(Token::percent_identifier)) 2849 return (emitError("expected ssa identifier"), nullptr); 2850 auto name = getTokenSpelling(); 2851 // Check if we already parsed this SSA id. 2852 for (auto entry : dimsAndSymbols) { 2853 if (entry.first == name) { 2854 consumeToken(Token::percent_identifier); 2855 return entry.second; 2856 } 2857 } 2858 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2859 if (parseElement(isSymbol)) 2860 return (emitError("failed to parse ssa identifier"), nullptr); 2861 auto idExpr = isSymbol 2862 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2863 : getAffineDimExpr(numDimOperands++, getContext()); 2864 dimsAndSymbols.push_back({name, idExpr}); 2865 return idExpr; 2866 } 2867 2868 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2869 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2870 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2871 return nullptr; 2872 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2873 if (!symbolExpr) 2874 return nullptr; 2875 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2876 return nullptr; 2877 return symbolExpr; 2878 } 2879 2880 /// Parse a positive integral constant appearing in an affine expression. 2881 /// 2882 /// affine-expr ::= integer-literal 2883 AffineExpr AffineParser::parseIntegerExpr() { 2884 auto val = getToken().getUInt64IntegerValue(); 2885 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2886 return (emitError("constant too large for index"), nullptr); 2887 2888 consumeToken(Token::integer); 2889 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2890 } 2891 2892 /// Parses an expression that can be a valid operand of an affine expression. 2893 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2894 /// operator, the rhs of which is being parsed. This is used to determine 2895 /// whether an error should be emitted for a missing right operand. 2896 // Eg: for an expression without parentheses (like i + j + k + l), each 2897 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2898 // operand expression, it's an op expression and will be parsed via 2899 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2900 // -l are valid operands that will be parsed by this function. 
2901 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2902 switch (getToken().getKind()) { 2903 case Token::bare_identifier: 2904 return parseBareIdExpr(); 2905 case Token::kw_symbol: 2906 return parseSymbolSSAIdExpr(); 2907 case Token::percent_identifier: 2908 return parseSSAIdExpr(/*isSymbol=*/false); 2909 case Token::integer: 2910 return parseIntegerExpr(); 2911 case Token::l_paren: 2912 return parseParentheticalExpr(); 2913 case Token::minus: 2914 return parseNegateExpression(lhs); 2915 case Token::kw_ceildiv: 2916 case Token::kw_floordiv: 2917 case Token::kw_mod: 2918 case Token::plus: 2919 case Token::star: 2920 if (lhs) 2921 emitError("missing right operand of binary operator"); 2922 else 2923 emitError("missing left operand of binary operator"); 2924 return nullptr; 2925 default: 2926 if (lhs) 2927 emitError("missing right operand of binary operator"); 2928 else 2929 emitError("expected affine expression"); 2930 return nullptr; 2931 } 2932 } 2933 2934 /// Parse affine expressions that are bare-id's, integer constants, 2935 /// parenthetical affine expressions, and affine op expressions that are a 2936 /// composition of those. 2937 /// 2938 /// All binary op's associate from left to right. 2939 /// 2940 /// {add, sub} have lower precedence than {mul, div, and mod}. 2941 /// 2942 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2943 /// ceildiv, and mod are at the same higher precedence level. Negation has 2944 /// higher precedence than any binary op. 2945 /// 2946 /// llhs: the affine expression appearing on the left of the one being parsed. 2947 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2948 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2949 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2950 /// associativity. 
2951 /// 2952 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2953 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2954 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2955 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2956 AffineLowPrecOp llhsOp) { 2957 AffineExpr lhs; 2958 if (!(lhs = parseAffineOperandExpr(llhs))) 2959 return nullptr; 2960 2961 // Found an LHS. Deal with the ops. 2962 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2963 if (llhs) { 2964 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2965 return parseAffineLowPrecOpExpr(sum, lOp); 2966 } 2967 // No LLHS, get RHS and form the expression. 2968 return parseAffineLowPrecOpExpr(lhs, lOp); 2969 } 2970 auto opLoc = getToken().getLoc(); 2971 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2972 // We have a higher precedence op here. Get the rhs operand for the llhs 2973 // through parseAffineHighPrecOpExpr. 2974 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2975 if (!highRes) 2976 return nullptr; 2977 2978 // If llhs is null, the product forms the first operand of the yet to be 2979 // found expression. If non-null, the op to associate with llhs is llhsOp. 2980 AffineExpr expr = 2981 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2982 2983 // Recurse for subsequent low prec op's after the affine high prec op 2984 // expression. 2985 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2986 return parseAffineLowPrecOpExpr(expr, nextOp); 2987 return expr; 2988 } 2989 // Last operand in the expression list. 2990 if (llhs) 2991 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2992 // No llhs, 'lhs' itself is the expression. 2993 return lhs; 2994 } 2995 2996 /// Parse an affine expression. 
2997 /// affine-expr ::= `(` affine-expr `)` 2998 /// | `-` affine-expr 2999 /// | affine-expr `+` affine-expr 3000 /// | affine-expr `-` affine-expr 3001 /// | affine-expr `*` affine-expr 3002 /// | affine-expr `floordiv` affine-expr 3003 /// | affine-expr `ceildiv` affine-expr 3004 /// | affine-expr `mod` affine-expr 3005 /// | bare-id 3006 /// | integer-literal 3007 /// 3008 /// Additional conditions are checked depending on the production. For eg., 3009 /// one of the operands for `*` has to be either constant/symbolic; the second 3010 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 3011 AffineExpr AffineParser::parseAffineExpr() { 3012 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 3013 } 3014 3015 /// Parse a dim or symbol from the lists appearing before the actual 3016 /// expressions of the affine map. Update our state to store the 3017 /// dimensional/symbolic identifier. 3018 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 3019 if (getToken().isNot(Token::bare_identifier)) 3020 return emitError("expected bare identifier"); 3021 3022 auto name = getTokenSpelling(); 3023 for (auto entry : dimsAndSymbols) { 3024 if (entry.first == name) 3025 return emitError("redefinition of identifier '" + name + "'"); 3026 } 3027 consumeToken(Token::bare_identifier); 3028 3029 dimsAndSymbols.push_back({name, idExpr}); 3030 return success(); 3031 } 3032 3033 /// Parse the list of dimensional identifiers to an affine map. 
3034 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 3035 if (parseToken(Token::l_paren, 3036 "expected '(' at start of dimensional identifiers list")) { 3037 return failure(); 3038 } 3039 3040 auto parseElt = [&]() -> ParseResult { 3041 auto dimension = getAffineDimExpr(numDims++, getContext()); 3042 return parseIdentifierDefinition(dimension); 3043 }; 3044 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 3045 } 3046 3047 /// Parse the list of symbolic identifiers to an affine map. 3048 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 3049 consumeToken(Token::l_square); 3050 auto parseElt = [&]() -> ParseResult { 3051 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 3052 return parseIdentifierDefinition(symbol); 3053 }; 3054 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 3055 } 3056 3057 /// Parse the list of symbolic identifiers to an affine map. 3058 ParseResult 3059 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 3060 unsigned &numSymbols) { 3061 if (parseDimIdList(numDims)) { 3062 return failure(); 3063 } 3064 if (!getToken().is(Token::l_square)) { 3065 numSymbols = 0; 3066 return success(); 3067 } 3068 return parseSymbolIdList(numSymbols); 3069 } 3070 3071 /// Parses an ambiguous affine map or integer set definition inline. 3072 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 3073 IntegerSet &set) { 3074 unsigned numDims = 0, numSymbols = 0; 3075 3076 // List of dimensional and optional symbol identifiers. 3077 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 3078 return failure(); 3079 } 3080 3081 // This is needed for parsing attributes as we wouldn't know whether we would 3082 // be parsing an integer set attribute or an affine map attribute. 
3083 bool isArrow = getToken().is(Token::arrow); 3084 bool isColon = getToken().is(Token::colon); 3085 if (!isArrow && !isColon) { 3086 return emitError("expected '->' or ':'"); 3087 } else if (isArrow) { 3088 parseToken(Token::arrow, "expected '->' or '['"); 3089 map = parseAffineMapRange(numDims, numSymbols); 3090 return map ? success() : failure(); 3091 } else if (parseToken(Token::colon, "expected ':' or '['")) { 3092 return failure(); 3093 } 3094 3095 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 3096 return success(); 3097 3098 return failure(); 3099 } 3100 3101 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 3102 ParseResult 3103 AffineParser::parseAffineMapOfSSAIds(AffineMap &map, 3104 OpAsmParser::Delimiter delimiter) { 3105 Token::Kind rightToken; 3106 switch (delimiter) { 3107 case OpAsmParser::Delimiter::Square: 3108 if (parseToken(Token::l_square, "expected '['")) 3109 return failure(); 3110 rightToken = Token::r_square; 3111 break; 3112 case OpAsmParser::Delimiter::Paren: 3113 if (parseToken(Token::l_paren, "expected '('")) 3114 return failure(); 3115 rightToken = Token::r_paren; 3116 break; 3117 default: 3118 return emitError("unexpected delimiter"); 3119 } 3120 3121 SmallVector<AffineExpr, 4> exprs; 3122 auto parseElt = [&]() -> ParseResult { 3123 auto elt = parseAffineExpr(); 3124 exprs.push_back(elt); 3125 return elt ? success() : failure(); 3126 }; 3127 3128 // Parse a multi-dimensional affine expression (a comma-separated list of 3129 // 1-d affine expressions); the list can be empty. Grammar: 3130 // multi-dim-affine-expr ::= `(` `)` 3131 // | `(` affine-expr (`,` affine-expr)* `)` 3132 if (parseCommaSeparatedListUntil(rightToken, parseElt, 3133 /*allowEmptyList=*/true)) 3134 return failure(); 3135 // Parsed a valid affine map. 
3136 if (exprs.empty()) 3137 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3138 getContext()); 3139 else 3140 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 3141 exprs); 3142 return success(); 3143 } 3144 3145 /// Parse the range and sizes affine map definition inline. 3146 /// 3147 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 3148 /// 3149 /// multi-dim-affine-expr ::= `(` `)` 3150 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` 3151 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 3152 unsigned numSymbols) { 3153 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 3154 3155 SmallVector<AffineExpr, 4> exprs; 3156 auto parseElt = [&]() -> ParseResult { 3157 auto elt = parseAffineExpr(); 3158 ParseResult res = elt ? success() : failure(); 3159 exprs.push_back(elt); 3160 return res; 3161 }; 3162 3163 // Parse a multi-dimensional affine expression (a comma-separated list of 3164 // 1-d affine expressions). Grammar: 3165 // multi-dim-affine-expr ::= `(` `)` 3166 // | `(` affine-expr (`,` affine-expr)* `)` 3167 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3168 return AffineMap(); 3169 3170 if (exprs.empty()) 3171 return AffineMap::get(numDims, numSymbols, getContext()); 3172 3173 // Parsed a valid affine map. 3174 return AffineMap::get(numDims, numSymbols, exprs); 3175 } 3176 3177 /// Parse an affine constraint. 3178 /// affine-constraint ::= affine-expr `>=` `0` 3179 /// | affine-expr `==` `0` 3180 /// 3181 /// isEq is set to true if the parsed constraint is an equality, false if it 3182 /// is an inequality (greater than or equal). 
3183 /// 3184 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 3185 AffineExpr expr = parseAffineExpr(); 3186 if (!expr) 3187 return nullptr; 3188 3189 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 3190 getToken().is(Token::integer)) { 3191 auto dim = getToken().getUnsignedIntegerValue(); 3192 if (dim.hasValue() && dim.getValue() == 0) { 3193 consumeToken(Token::integer); 3194 *isEq = false; 3195 return expr; 3196 } 3197 return (emitError("expected '0' after '>='"), nullptr); 3198 } 3199 3200 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 3201 getToken().is(Token::integer)) { 3202 auto dim = getToken().getUnsignedIntegerValue(); 3203 if (dim.hasValue() && dim.getValue() == 0) { 3204 consumeToken(Token::integer); 3205 *isEq = true; 3206 return expr; 3207 } 3208 return (emitError("expected '0' after '=='"), nullptr); 3209 } 3210 3211 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 3212 nullptr); 3213 } 3214 3215 /// Parse the constraints that are part of an integer set definition. 3216 /// integer-set-inline 3217 /// ::= dim-and-symbol-id-lists `:` 3218 /// '(' affine-constraint-conjunction? ')' 3219 /// affine-constraint-conjunction ::= affine-constraint (`,` 3220 /// affine-constraint)* 3221 /// 3222 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 3223 unsigned numSymbols) { 3224 if (parseToken(Token::l_paren, 3225 "expected '(' at start of integer set constraint list")) 3226 return IntegerSet(); 3227 3228 SmallVector<AffineExpr, 4> constraints; 3229 SmallVector<bool, 4> isEqs; 3230 auto parseElt = [&]() -> ParseResult { 3231 bool isEq; 3232 auto elt = parseAffineConstraint(&isEq); 3233 ParseResult res = elt ? success() : failure(); 3234 if (elt) { 3235 constraints.push_back(elt); 3236 isEqs.push_back(isEq); 3237 } 3238 return res; 3239 }; 3240 3241 // Parse a list of affine constraints (comma-separated). 
3242 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 3243 return IntegerSet(); 3244 3245 // If no constraints were parsed, then treat this as a degenerate 'true' case. 3246 if (constraints.empty()) { 3247 /* 0 == 0 */ 3248 auto zero = getAffineConstantExpr(0, getContext()); 3249 return IntegerSet::get(numDims, numSymbols, zero, true); 3250 } 3251 3252 // Parsed a valid integer set. 3253 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 3254 } 3255 3256 /// Parse an ambiguous reference to either and affine map or an integer set. 3257 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 3258 IntegerSet &set) { 3259 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 3260 } 3261 ParseResult Parser::parseAffineMapReference(AffineMap &map) { 3262 llvm::SMLoc curLoc = getToken().getLoc(); 3263 IntegerSet set; 3264 if (parseAffineMapOrIntegerSetReference(map, set)) 3265 return failure(); 3266 if (set) 3267 return emitError(curLoc, "expected AffineMap, but got IntegerSet"); 3268 return success(); 3269 } 3270 ParseResult Parser::parseIntegerSetReference(IntegerSet &set) { 3271 llvm::SMLoc curLoc = getToken().getLoc(); 3272 AffineMap map; 3273 if (parseAffineMapOrIntegerSetReference(map, set)) 3274 return failure(); 3275 if (map) 3276 return emitError(curLoc, "expected IntegerSet, but got AffineMap"); 3277 return success(); 3278 } 3279 3280 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 3281 /// parse SSA value uses encountered while parsing affine expressions. 
3282 ParseResult 3283 Parser::parseAffineMapOfSSAIds(AffineMap &map, 3284 function_ref<ParseResult(bool)> parseElement, 3285 OpAsmParser::Delimiter delimiter) { 3286 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 3287 .parseAffineMapOfSSAIds(map, delimiter); 3288 } 3289 3290 //===----------------------------------------------------------------------===// 3291 // OperationParser 3292 //===----------------------------------------------------------------------===// 3293 3294 namespace { 3295 /// This class provides support for parsing operations and regions of 3296 /// operations. 3297 class OperationParser : public Parser { 3298 public: 3299 OperationParser(ParserState &state, ModuleOp moduleOp) 3300 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 3301 } 3302 3303 ~OperationParser(); 3304 3305 /// After parsing is finished, this function must be called to see if there 3306 /// are any remaining issues. 3307 ParseResult finalize(); 3308 3309 //===--------------------------------------------------------------------===// 3310 // SSA Value Handling 3311 //===--------------------------------------------------------------------===// 3312 3313 /// This represents a use of an SSA value in the program. The first two 3314 /// entries in the tuple are the name and result number of a reference. The 3315 /// third is the location of the reference, which is used in case this ends 3316 /// up being a use of an undefined value. 3317 struct SSAUseInfo { 3318 StringRef name; // Value name, e.g. %42 or %abc 3319 unsigned number; // Number, specified with #12 3320 SMLoc loc; // Location of first definition or use. 3321 }; 3322 3323 /// Push a new SSA name scope to the parser. 3324 void pushSSANameScope(bool isIsolated); 3325 3326 /// Pop the last SSA name scope from the parser. 3327 ParseResult popSSANameScope(); 3328 3329 /// Register a definition of a value with the symbol table. 
3330 ParseResult addDefinition(SSAUseInfo useInfo, Value value); 3331 3332 /// Parse an optional list of SSA uses into 'results'. 3333 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 3334 3335 /// Parse a single SSA use into 'result'. 3336 ParseResult parseSSAUse(SSAUseInfo &result); 3337 3338 /// Given a reference to an SSA value and its type, return a reference. This 3339 /// returns null on failure. 3340 Value resolveSSAUse(SSAUseInfo useInfo, Type type); 3341 3342 ParseResult parseSSADefOrUseAndType( 3343 const std::function<ParseResult(SSAUseInfo, Type)> &action); 3344 3345 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results); 3346 3347 /// Return the location of the value identified by its name and number if it 3348 /// has been already reference. 3349 Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 3350 auto &values = isolatedNameScopes.back().values; 3351 if (!values.count(name) || number >= values[name].size()) 3352 return {}; 3353 if (values[name][number].first) 3354 return values[name][number].second; 3355 return {}; 3356 } 3357 3358 //===--------------------------------------------------------------------===// 3359 // Operation Parsing 3360 //===--------------------------------------------------------------------===// 3361 3362 /// Parse an operation instance. 3363 ParseResult parseOperation(); 3364 3365 /// Parse a single operation successor. 3366 ParseResult parseSuccessor(Block *&dest); 3367 3368 /// Parse a comma-separated list of operation successors in brackets. 3369 ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations); 3370 3371 /// Parse an operation instance that is in the generic form. 3372 Operation *parseGenericOperation(); 3373 3374 /// Parse an operation instance that is in the generic form and insert it at 3375 /// the provided insertion point. 
3376 Operation *parseGenericOperation(Block *insertBlock, 3377 Block::iterator insertPt); 3378 3379 /// This is the structure of a result specifier in the assembly syntax, 3380 /// including the name, number of results, and location. 3381 typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord; 3382 3383 /// Parse an operation instance that is in the op-defined custom form. 3384 /// resultInfo specifies information about the "%name =" specifiers. 3385 Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo); 3386 3387 //===--------------------------------------------------------------------===// 3388 // Region Parsing 3389 //===--------------------------------------------------------------------===// 3390 3391 /// Parse a region into 'region' with the provided entry block arguments. 3392 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 3393 /// isolated from those above. 3394 ParseResult parseRegion(Region ®ion, 3395 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 3396 bool isIsolatedNameScope = false); 3397 3398 /// Parse a region body into 'region'. 3399 ParseResult parseRegionBody(Region ®ion); 3400 3401 //===--------------------------------------------------------------------===// 3402 // Block Parsing 3403 //===--------------------------------------------------------------------===// 3404 3405 /// Parse a new block into 'block'. 3406 ParseResult parseBlock(Block *&block); 3407 3408 /// Parse a list of operations into 'block'. 3409 ParseResult parseBlockBody(Block *block); 3410 3411 /// Parse a (possibly empty) list of block arguments. 3412 ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, 3413 Block *owner); 3414 3415 /// Get the block with the specified name, creating it if it doesn't 3416 /// already exist. The location specified is the point of use, which allows 3417 /// us to diagnose references to blocks that are not defined precisely. 
3418 Block *getBlockNamed(StringRef name, SMLoc loc); 3419 3420 /// Define the block with the specified name. Returns the Block* or nullptr in 3421 /// the case of redefinition. 3422 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 3423 3424 private: 3425 /// Returns the info for a block at the current scope for the given name. 3426 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 3427 return blocksByName.back()[name]; 3428 } 3429 3430 /// Insert a new forward reference to the given block. 3431 void insertForwardRef(Block *block, SMLoc loc) { 3432 forwardRef.back().try_emplace(block, loc); 3433 } 3434 3435 /// Erase any forward reference to the given block. 3436 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 3437 3438 /// Record that a definition was added at the current scope. 3439 void recordDefinition(StringRef def); 3440 3441 /// Get the value entry for the given SSA name. 3442 SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); 3443 3444 /// Create a forward reference placeholder value with the given location and 3445 /// result type. 3446 Value createForwardRefPlaceholder(SMLoc loc, Type type); 3447 3448 /// Return true if this is a forward reference. 3449 bool isForwardRefPlaceholder(Value value) { 3450 return forwardRefPlaceholders.count(value); 3451 } 3452 3453 /// This struct represents an isolated SSA name scope. This scope may contain 3454 /// other nested non-isolated scopes. These scopes are used for operations 3455 /// that are known to be isolated to allow for reusing names within their 3456 /// regions, even if those names are used above. 3457 struct IsolatedSSANameScope { 3458 /// Record that a definition was added at the current scope. 3459 void recordDefinition(StringRef def) { 3460 definitionsPerScope.back().insert(def); 3461 } 3462 3463 /// Push a nested name scope. 
3464 void pushSSANameScope() { definitionsPerScope.push_back({}); } 3465 3466 /// Pop a nested name scope. 3467 void popSSANameScope() { 3468 for (auto &def : definitionsPerScope.pop_back_val()) 3469 values.erase(def.getKey()); 3470 } 3471 3472 /// This keeps track of all of the SSA values we are tracking for each name 3473 /// scope, indexed by their name. This has one entry per result number. 3474 llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; 3475 3476 /// This keeps track of all of the values defined by a specific name scope. 3477 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 3478 }; 3479 3480 /// A list of isolated name scopes. 3481 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 3482 3483 /// This keeps track of the block names as well as the location of the first 3484 /// reference for each nested name scope. This is used to diagnose invalid 3485 /// block references and memorize them. 3486 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 3487 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 3488 3489 /// These are all of the placeholders we've made along with the location of 3490 /// their first reference, to allow checking for use of undefined values. 3491 DenseMap<Value, SMLoc> forwardRefPlaceholders; 3492 3493 /// The builder used when creating parsed operation instances. 3494 OpBuilder opBuilder; 3495 3496 /// The top level module operation. 3497 ModuleOp moduleOp; 3498 }; 3499 } // end anonymous namespace 3500 3501 OperationParser::~OperationParser() { 3502 for (auto &fwd : forwardRefPlaceholders) { 3503 // Drop all uses of undefined forward declared reference and destroy 3504 // defining operation. 3505 fwd.first.dropAllUses(); 3506 fwd.first.getDefiningOp()->destroy(); 3507 } 3508 } 3509 3510 /// After parsing is finished, this function must be called to see if there are 3511 /// any remaining issues. 
3512 ParseResult OperationParser::finalize() { 3513 // Check for any forward references that are left. If we find any, error 3514 // out. 3515 if (!forwardRefPlaceholders.empty()) { 3516 SmallVector<std::pair<const char *, Value>, 4> errors; 3517 // Iteration over the map isn't deterministic, so sort by source location. 3518 for (auto entry : forwardRefPlaceholders) 3519 errors.push_back({entry.second.getPointer(), entry.first}); 3520 llvm::array_pod_sort(errors.begin(), errors.end()); 3521 3522 for (auto entry : errors) { 3523 auto loc = SMLoc::getFromPointer(entry.first); 3524 emitError(loc, "use of undeclared SSA value name"); 3525 } 3526 return failure(); 3527 } 3528 3529 return success(); 3530 } 3531 3532 //===----------------------------------------------------------------------===// 3533 // SSA Value Handling 3534 //===----------------------------------------------------------------------===// 3535 3536 void OperationParser::pushSSANameScope(bool isIsolated) { 3537 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 3538 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 3539 3540 // Push back a new name definition scope. 3541 if (isIsolated) 3542 isolatedNameScopes.push_back({}); 3543 isolatedNameScopes.back().pushSSANameScope(); 3544 } 3545 3546 ParseResult OperationParser::popSSANameScope() { 3547 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 3548 3549 // Verify that all referenced blocks were defined. 3550 if (!forwardRefInCurrentScope.empty()) { 3551 SmallVector<std::pair<const char *, Block *>, 4> errors; 3552 // Iteration over the map isn't deterministic, so sort by source location. 3553 for (auto entry : forwardRefInCurrentScope) { 3554 errors.push_back({entry.second.getPointer(), entry.first}); 3555 // Add this block to the top-level region to allow for automatic cleanup. 
3556 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 3557 } 3558 llvm::array_pod_sort(errors.begin(), errors.end()); 3559 3560 for (auto entry : errors) { 3561 auto loc = SMLoc::getFromPointer(entry.first); 3562 emitError(loc, "reference to an undefined block"); 3563 } 3564 return failure(); 3565 } 3566 3567 // Pop the next nested namescope. If there is only one internal namescope, 3568 // just pop the isolated scope. 3569 auto ¤tNameScope = isolatedNameScopes.back(); 3570 if (currentNameScope.definitionsPerScope.size() == 1) 3571 isolatedNameScopes.pop_back(); 3572 else 3573 currentNameScope.popSSANameScope(); 3574 3575 blocksByName.pop_back(); 3576 return success(); 3577 } 3578 3579 /// Register a definition of a value with the symbol table. 3580 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) { 3581 auto &entries = getSSAValueEntry(useInfo.name); 3582 3583 // Make sure there is a slot for this value. 3584 if (entries.size() <= useInfo.number) 3585 entries.resize(useInfo.number + 1); 3586 3587 // If we already have an entry for this, check to see if it was a definition 3588 // or a forward reference. 3589 if (auto existing = entries[useInfo.number].first) { 3590 if (!isForwardRefPlaceholder(existing)) { 3591 return emitError(useInfo.loc) 3592 .append("redefinition of SSA value '", useInfo.name, "'") 3593 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3594 .append("previously defined here"); 3595 } 3596 3597 // If it was a forward reference, update everything that used it to use 3598 // the actual definition instead, delete the forward ref, and remove it 3599 // from our set of forward references we track. 3600 existing.replaceAllUsesWith(value); 3601 existing.getDefiningOp()->destroy(); 3602 forwardRefPlaceholders.erase(existing); 3603 } 3604 3605 /// Record this definition for the current scope. 
3606 entries[useInfo.number] = {value, useInfo.loc}; 3607 recordDefinition(useInfo.name); 3608 return success(); 3609 } 3610 3611 /// Parse a (possibly empty) list of SSA operands. 3612 /// 3613 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 3614 /// ssa-use-list-opt ::= ssa-use-list? 3615 /// 3616 ParseResult 3617 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 3618 if (getToken().isNot(Token::percent_identifier)) 3619 return success(); 3620 return parseCommaSeparatedList([&]() -> ParseResult { 3621 SSAUseInfo result; 3622 if (parseSSAUse(result)) 3623 return failure(); 3624 results.push_back(result); 3625 return success(); 3626 }); 3627 } 3628 3629 /// Parse a SSA operand for an operation. 3630 /// 3631 /// ssa-use ::= ssa-id 3632 /// 3633 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 3634 result.name = getTokenSpelling(); 3635 result.number = 0; 3636 result.loc = getToken().getLoc(); 3637 if (parseToken(Token::percent_identifier, "expected SSA operand")) 3638 return failure(); 3639 3640 // If we have an attribute ID, it is a result number. 3641 if (getToken().is(Token::hash_identifier)) { 3642 if (auto value = getToken().getHashIdentifierNumber()) 3643 result.number = value.getValue(); 3644 else 3645 return emitError("invalid SSA value result number"); 3646 consumeToken(Token::hash_identifier); 3647 } 3648 3649 return success(); 3650 } 3651 3652 /// Given an unbound reference to an SSA value and its type, return the value 3653 /// it specifies. This returns null on failure. 3654 Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 3655 auto &entries = getSSAValueEntry(useInfo.name); 3656 3657 // If we have already seen a value of this name, return it. 3658 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 3659 auto result = entries[useInfo.number].first; 3660 // Check that the type matches the other uses. 
3661 if (result.getType() == type) 3662 return result; 3663 3664 emitError(useInfo.loc, "use of value '") 3665 .append(useInfo.name, 3666 "' expects different type than prior uses: ", type, " vs ", 3667 result.getType()) 3668 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3669 .append("prior use here"); 3670 return nullptr; 3671 } 3672 3673 // Make sure we have enough slots for this. 3674 if (entries.size() <= useInfo.number) 3675 entries.resize(useInfo.number + 1); 3676 3677 // If the value has already been defined and this is an overly large result 3678 // number, diagnose that. 3679 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3680 return (emitError(useInfo.loc, "reference to invalid result number"), 3681 nullptr); 3682 3683 // Otherwise, this is a forward reference. Create a placeholder and remember 3684 // that we did so. 3685 auto result = createForwardRefPlaceholder(useInfo.loc, type); 3686 entries[useInfo.number].first = result; 3687 entries[useInfo.number].second = useInfo.loc; 3688 return result; 3689 } 3690 3691 /// Parse an SSA use with an associated type. 3692 /// 3693 /// ssa-use-and-type ::= ssa-use `:` type 3694 ParseResult OperationParser::parseSSADefOrUseAndType( 3695 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3696 SSAUseInfo useInfo; 3697 if (parseSSAUse(useInfo) || 3698 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3699 return failure(); 3700 3701 auto type = parseType(); 3702 if (!type) 3703 return failure(); 3704 3705 return action(useInfo, type); 3706 } 3707 3708 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3709 /// followed by a type list. 
3710 /// 3711 /// ssa-use-and-type-list 3712 /// ::= ssa-use-list ':' type-list-no-parens 3713 /// 3714 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3715 SmallVectorImpl<Value> &results) { 3716 SmallVector<SSAUseInfo, 4> valueIDs; 3717 if (parseOptionalSSAUseList(valueIDs)) 3718 return failure(); 3719 3720 // If there were no operands, then there is no colon or type lists. 3721 if (valueIDs.empty()) 3722 return success(); 3723 3724 SmallVector<Type, 4> types; 3725 if (parseToken(Token::colon, "expected ':' in operand list") || 3726 parseTypeListNoParens(types)) 3727 return failure(); 3728 3729 if (valueIDs.size() != types.size()) 3730 return emitError("expected ") 3731 << valueIDs.size() << " types to match operand list"; 3732 3733 results.reserve(valueIDs.size()); 3734 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3735 if (auto value = resolveSSAUse(valueIDs[i], types[i])) 3736 results.push_back(value); 3737 else 3738 return failure(); 3739 } 3740 3741 return success(); 3742 } 3743 3744 /// Record that a definition was added at the current scope. 3745 void OperationParser::recordDefinition(StringRef def) { 3746 isolatedNameScopes.back().recordDefinition(def); 3747 } 3748 3749 /// Get the value entry for the given SSA name. 3750 SmallVectorImpl<std::pair<Value, SMLoc>> & 3751 OperationParser::getSSAValueEntry(StringRef name) { 3752 return isolatedNameScopes.back().values[name]; 3753 } 3754 3755 /// Create and remember a new placeholder for a forward reference. 3756 Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3757 // Forward references are always created as operations, because we just need 3758 // something with a def/use chain. 3759 // 3760 // We create these placeholders as having an empty name, which we know 3761 // cannot be created through normal user input, allowing us to distinguish 3762 // them. 
3763 auto name = OperationName("placeholder", getContext()); 3764 auto *op = Operation::create( 3765 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3766 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3767 /*resizableOperandList=*/false); 3768 forwardRefPlaceholders[op->getResult(0)] = loc; 3769 return op->getResult(0); 3770 } 3771 3772 //===----------------------------------------------------------------------===// 3773 // Operation Parsing 3774 //===----------------------------------------------------------------------===// 3775 3776 /// Parse an operation. 3777 /// 3778 /// operation ::= op-result-list? 3779 /// (generic-operation | custom-operation) 3780 /// trailing-location? 3781 /// generic-operation ::= string-literal `(` ssa-use-list? `)` 3782 /// successor-list? (`(` region-list `)`)? 3783 /// attribute-dict? `:` function-type 3784 /// custom-operation ::= bare-id custom-operation-format 3785 /// op-result-list ::= op-result (`,` op-result)* `=` 3786 /// op-result ::= ssa-id (`:` integer-literal) 3787 /// 3788 ParseResult OperationParser::parseOperation() { 3789 auto loc = getToken().getLoc(); 3790 SmallVector<ResultRecord, 1> resultIDs; 3791 size_t numExpectedResults = 0; 3792 if (getToken().is(Token::percent_identifier)) { 3793 // Parse the group of result ids. 3794 auto parseNextResult = [&]() -> ParseResult { 3795 // Parse the next result id. 3796 if (!getToken().is(Token::percent_identifier)) 3797 return emitError("expected valid ssa identifier"); 3798 3799 Token nameTok = getToken(); 3800 consumeToken(Token::percent_identifier); 3801 3802 // If the next token is a ':', we parse the expected result count. 3803 size_t expectedSubResults = 1; 3804 if (consumeIf(Token::colon)) { 3805 // Check that the next token is an integer. 3806 if (!getToken().is(Token::integer)) 3807 return emitError("expected integer number of results"); 3808 3809 // Check that number of results is > 0. 
3810 auto val = getToken().getUInt64IntegerValue(); 3811 if (!val.hasValue() || val.getValue() < 1) 3812 return emitError("expected named operation to have atleast 1 result"); 3813 consumeToken(Token::integer); 3814 expectedSubResults = *val; 3815 } 3816 3817 resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, 3818 nameTok.getLoc()); 3819 numExpectedResults += expectedSubResults; 3820 return success(); 3821 }; 3822 if (parseCommaSeparatedList(parseNextResult)) 3823 return failure(); 3824 3825 if (parseToken(Token::equal, "expected '=' after SSA name")) 3826 return failure(); 3827 } 3828 3829 Operation *op; 3830 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3831 op = parseCustomOperation(resultIDs); 3832 else if (getToken().is(Token::string)) 3833 op = parseGenericOperation(); 3834 else 3835 return emitError("expected operation name in quotes"); 3836 3837 // If parsing of the basic operation failed, then this whole thing fails. 3838 if (!op) 3839 return failure(); 3840 3841 // If the operation had a name, register it. 3842 if (!resultIDs.empty()) { 3843 if (op->getNumResults() == 0) 3844 return emitError(loc, "cannot name an operation with no results"); 3845 if (numExpectedResults != op->getNumResults()) 3846 return emitError(loc, "operation defines ") 3847 << op->getNumResults() << " results but was provided " 3848 << numExpectedResults << " to bind"; 3849 3850 // Add definitions for each of the result groups. 3851 unsigned opResI = 0; 3852 for (ResultRecord &resIt : resultIDs) { 3853 for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { 3854 if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, 3855 op->getResult(opResI++))) 3856 return failure(); 3857 } 3858 } 3859 } 3860 3861 return success(); 3862 } 3863 3864 /// Parse a single operation successor. 
3865 /// 3866 /// successor ::= block-id 3867 /// 3868 ParseResult OperationParser::parseSuccessor(Block *&dest) { 3869 // Verify branch is identifier and get the matching block. 3870 if (!getToken().is(Token::caret_identifier)) 3871 return emitError("expected block name"); 3872 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3873 consumeToken(); 3874 return success(); 3875 } 3876 3877 /// Parse a comma-separated list of operation successors in brackets. 3878 /// 3879 /// successor-list ::= `[` successor (`,` successor )* `]` 3880 /// 3881 ParseResult 3882 OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) { 3883 if (parseToken(Token::l_square, "expected '['")) 3884 return failure(); 3885 3886 auto parseElt = [this, &destinations] { 3887 Block *dest; 3888 ParseResult res = parseSuccessor(dest); 3889 destinations.push_back(dest); 3890 return res; 3891 }; 3892 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3893 /*allowEmptyList=*/false); 3894 } 3895 3896 namespace { 3897 // RAII-style guard for cleaning up the regions in the operation state before 3898 // deleting them. Within the parser, regions may get deleted if parsing failed, 3899 // and other errors may be present, in particular undominated uses. This makes 3900 // sure such uses are deleted. 3901 struct CleanupOpStateRegions { 3902 ~CleanupOpStateRegions() { 3903 SmallVector<Region *, 4> regionsToClean; 3904 regionsToClean.reserve(state.regions.size()); 3905 for (auto ®ion : state.regions) 3906 if (region) 3907 for (auto &block : *region) 3908 block.dropAllDefinedValueUses(); 3909 } 3910 OperationState &state; 3911 }; 3912 } // namespace 3913 3914 Operation *OperationParser::parseGenericOperation() { 3915 // Get location information for the operation. 
3916 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3917 3918 auto name = getToken().getStringValue(); 3919 if (name.empty()) 3920 return (emitError("empty operation name is invalid"), nullptr); 3921 if (name.find('\0') != StringRef::npos) 3922 return (emitError("null character not allowed in operation name"), nullptr); 3923 3924 consumeToken(Token::string); 3925 3926 OperationState result(srcLocation, name); 3927 3928 // Generic operations have a resizable operation list. 3929 result.setOperandListToResizable(); 3930 3931 // Parse the operand list. 3932 SmallVector<SSAUseInfo, 8> operandInfos; 3933 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3934 parseOptionalSSAUseList(operandInfos) || 3935 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3936 return nullptr; 3937 } 3938 3939 // Parse the successor list. 3940 if (getToken().is(Token::l_square)) { 3941 // Check if the operation is a known terminator. 3942 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3943 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3944 return emitError("successors in non-terminator"), nullptr; 3945 3946 SmallVector<Block *, 2> successors; 3947 if (parseSuccessors(successors)) 3948 return nullptr; 3949 result.addSuccessors(successors); 3950 } 3951 3952 // Parse the region list. 3953 CleanupOpStateRegions guard{result}; 3954 if (consumeIf(Token::l_paren)) { 3955 do { 3956 // Create temporary regions with the top level region as parent. 
3957 result.regions.emplace_back(new Region(moduleOp)); 3958 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3959 return nullptr; 3960 } while (consumeIf(Token::comma)); 3961 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3962 return nullptr; 3963 } 3964 3965 if (getToken().is(Token::l_brace)) { 3966 if (parseAttributeDict(result.attributes)) 3967 return nullptr; 3968 } 3969 3970 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3971 return nullptr; 3972 3973 auto typeLoc = getToken().getLoc(); 3974 auto type = parseType(); 3975 if (!type) 3976 return nullptr; 3977 auto fnType = type.dyn_cast<FunctionType>(); 3978 if (!fnType) 3979 return (emitError(typeLoc, "expected function type"), nullptr); 3980 3981 result.addTypes(fnType.getResults()); 3982 3983 // Check that we have the right number of types for the operands. 3984 auto operandTypes = fnType.getInputs(); 3985 if (operandTypes.size() != operandInfos.size()) { 3986 auto plural = "s"[operandInfos.size() == 1]; 3987 return (emitError(typeLoc, "expected ") 3988 << operandInfos.size() << " operand type" << plural 3989 << " but had " << operandTypes.size(), 3990 nullptr); 3991 } 3992 3993 // Resolve all of the operands. 3994 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3995 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3996 if (!result.operands.back()) 3997 return nullptr; 3998 } 3999 4000 // Parse a location if one is present. 
4001 if (parseOptionalTrailingLocation(result.location)) 4002 return nullptr; 4003 4004 return opBuilder.createOperation(result); 4005 } 4006 4007 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 4008 Block::iterator insertPt) { 4009 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 4010 opBuilder.setInsertionPoint(insertBlock, insertPt); 4011 return parseGenericOperation(); 4012 } 4013 4014 namespace { 4015 class CustomOpAsmParser : public OpAsmParser { 4016 public: 4017 CustomOpAsmParser(SMLoc nameLoc, 4018 ArrayRef<OperationParser::ResultRecord> resultIDs, 4019 const AbstractOperation *opDefinition, 4020 OperationParser &parser) 4021 : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition), 4022 parser(parser) {} 4023 4024 /// Parse an instance of the operation described by 'opDefinition' into the 4025 /// provided operation state. 4026 ParseResult parseOperation(OperationState &opState) { 4027 if (opDefinition->parseAssembly(*this, opState)) 4028 return failure(); 4029 return success(); 4030 } 4031 4032 Operation *parseGenericOperation(Block *insertBlock, 4033 Block::iterator insertPt) final { 4034 return parser.parseGenericOperation(insertBlock, insertPt); 4035 } 4036 4037 //===--------------------------------------------------------------------===// 4038 // Utilities 4039 //===--------------------------------------------------------------------===// 4040 4041 /// Return if any errors were emitted during parsing. 4042 bool didEmitError() const { return emittedError; } 4043 4044 /// Emit a diagnostic at the specified location and return failure. 
4045 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 4046 emittedError = true; 4047 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 4048 message); 4049 } 4050 4051 llvm::SMLoc getCurrentLocation() override { 4052 return parser.getToken().getLoc(); 4053 } 4054 4055 Builder &getBuilder() const override { return parser.builder; } 4056 4057 /// Return the name of the specified result in the specified syntax, as well 4058 /// as the subelement in the name. For example, in this operation: 4059 /// 4060 /// %x, %y:2, %z = foo.op 4061 /// 4062 /// getResultName(0) == {"x", 0 } 4063 /// getResultName(1) == {"y", 0 } 4064 /// getResultName(2) == {"y", 1 } 4065 /// getResultName(3) == {"z", 0 } 4066 std::pair<StringRef, unsigned> 4067 getResultName(unsigned resultNo) const override { 4068 // Scan for the resultID that contains this result number. 4069 for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) { 4070 const auto &entry = resultIDs[nameID]; 4071 if (resultNo < std::get<1>(entry)) { 4072 // Don't pass on the leading %. 4073 StringRef name = std::get<0>(entry).drop_front(); 4074 return {name, resultNo}; 4075 } 4076 resultNo -= std::get<1>(entry); 4077 } 4078 4079 // Invalid result number. 4080 return {"", ~0U}; 4081 } 4082 4083 /// Return the number of declared SSA results. This returns 4 for the foo.op 4084 /// example in the comment for getResultName. 4085 size_t getNumResults() const override { 4086 size_t count = 0; 4087 for (auto &entry : resultIDs) 4088 count += std::get<1>(entry); 4089 return count; 4090 } 4091 4092 llvm::SMLoc getNameLoc() const override { return nameLoc; } 4093 4094 //===--------------------------------------------------------------------===// 4095 // Token Parsing 4096 //===--------------------------------------------------------------------===// 4097 4098 /// Parse a `->` token. 
4099 ParseResult parseArrow() override { 4100 return parser.parseToken(Token::arrow, "expected '->'"); 4101 } 4102 4103 /// Parses a `->` if present. 4104 ParseResult parseOptionalArrow() override { 4105 return success(parser.consumeIf(Token::arrow)); 4106 } 4107 4108 /// Parse a `:` token. 4109 ParseResult parseColon() override { 4110 return parser.parseToken(Token::colon, "expected ':'"); 4111 } 4112 4113 /// Parse a `:` token if present. 4114 ParseResult parseOptionalColon() override { 4115 return success(parser.consumeIf(Token::colon)); 4116 } 4117 4118 /// Parse a `,` token. 4119 ParseResult parseComma() override { 4120 return parser.parseToken(Token::comma, "expected ','"); 4121 } 4122 4123 /// Parse a `,` token if present. 4124 ParseResult parseOptionalComma() override { 4125 return success(parser.consumeIf(Token::comma)); 4126 } 4127 4128 /// Parses a `...` if present. 4129 ParseResult parseOptionalEllipsis() override { 4130 return success(parser.consumeIf(Token::ellipsis)); 4131 } 4132 4133 /// Parse a `=` token. 4134 ParseResult parseEqual() override { 4135 return parser.parseToken(Token::equal, "expected '='"); 4136 } 4137 4138 /// Parse a '<' token. 4139 ParseResult parseLess() override { 4140 return parser.parseToken(Token::less, "expected '<'"); 4141 } 4142 4143 /// Parse a '>' token. 4144 ParseResult parseGreater() override { 4145 return parser.parseToken(Token::greater, "expected '>'"); 4146 } 4147 4148 /// Parse a `(` token. 4149 ParseResult parseLParen() override { 4150 return parser.parseToken(Token::l_paren, "expected '('"); 4151 } 4152 4153 /// Parses a '(' if present. 4154 ParseResult parseOptionalLParen() override { 4155 return success(parser.consumeIf(Token::l_paren)); 4156 } 4157 4158 /// Parse a `)` token. 4159 ParseResult parseRParen() override { 4160 return parser.parseToken(Token::r_paren, "expected ')'"); 4161 } 4162 4163 /// Parses a ')' if present. 
4164 ParseResult parseOptionalRParen() override { 4165 return success(parser.consumeIf(Token::r_paren)); 4166 } 4167 4168 /// Parse a `[` token. 4169 ParseResult parseLSquare() override { 4170 return parser.parseToken(Token::l_square, "expected '['"); 4171 } 4172 4173 /// Parses a '[' if present. 4174 ParseResult parseOptionalLSquare() override { 4175 return success(parser.consumeIf(Token::l_square)); 4176 } 4177 4178 /// Parse a `]` token. 4179 ParseResult parseRSquare() override { 4180 return parser.parseToken(Token::r_square, "expected ']'"); 4181 } 4182 4183 /// Parses a ']' if present. 4184 ParseResult parseOptionalRSquare() override { 4185 return success(parser.consumeIf(Token::r_square)); 4186 } 4187 4188 //===--------------------------------------------------------------------===// 4189 // Attribute Parsing 4190 //===--------------------------------------------------------------------===// 4191 4192 /// Parse an arbitrary attribute of a given type and return it in result. This 4193 /// also adds the attribute to the specified attribute list with the specified 4194 /// name. 4195 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 4196 SmallVectorImpl<NamedAttribute> &attrs) override { 4197 result = parser.parseAttribute(type); 4198 if (!result) 4199 return failure(); 4200 4201 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 4202 return success(); 4203 } 4204 4205 /// Parse a named dictionary into 'result' if it is present. 4206 ParseResult 4207 parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) override { 4208 if (parser.getToken().isNot(Token::l_brace)) 4209 return success(); 4210 return parser.parseAttributeDict(result); 4211 } 4212 4213 /// Parse a named dictionary into 'result' if the `attributes` keyword is 4214 /// present. 
4215 ParseResult parseOptionalAttrDictWithKeyword( 4216 SmallVectorImpl<NamedAttribute> &result) override { 4217 if (failed(parseOptionalKeyword("attributes"))) 4218 return success(); 4219 return parser.parseAttributeDict(result); 4220 } 4221 4222 /// Parse an affine map instance into 'map'. 4223 ParseResult parseAffineMap(AffineMap &map) override { 4224 return parser.parseAffineMapReference(map); 4225 } 4226 4227 /// Parse an integer set instance into 'set'. 4228 ParseResult printIntegerSet(IntegerSet &set) override { 4229 return parser.parseIntegerSetReference(set); 4230 } 4231 4232 //===--------------------------------------------------------------------===// 4233 // Identifier Parsing 4234 //===--------------------------------------------------------------------===// 4235 4236 /// Returns if the current token corresponds to a keyword. 4237 bool isCurrentTokenAKeyword() const { 4238 return parser.getToken().is(Token::bare_identifier) || 4239 parser.getToken().isKeyword(); 4240 } 4241 4242 /// Parse the given keyword if present. 4243 ParseResult parseOptionalKeyword(StringRef keyword) override { 4244 // Check that the current token has the same spelling. 4245 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 4246 return failure(); 4247 parser.consumeToken(); 4248 return success(); 4249 } 4250 4251 /// Parse a keyword, if present, into 'keyword'. 4252 ParseResult parseOptionalKeyword(StringRef *keyword) override { 4253 // Check that the current token is a keyword. 4254 if (!isCurrentTokenAKeyword()) 4255 return failure(); 4256 4257 *keyword = parser.getTokenSpelling(); 4258 parser.consumeToken(); 4259 return success(); 4260 } 4261 4262 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 4263 /// string attribute named 'attrName'. 
4264 ParseResult 4265 parseOptionalSymbolName(StringAttr &result, StringRef attrName, 4266 SmallVectorImpl<NamedAttribute> &attrs) override { 4267 Token atToken = parser.getToken(); 4268 if (atToken.isNot(Token::at_identifier)) 4269 return failure(); 4270 4271 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 4272 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 4273 parser.consumeToken(); 4274 return success(); 4275 } 4276 4277 //===--------------------------------------------------------------------===// 4278 // Operand Parsing 4279 //===--------------------------------------------------------------------===// 4280 4281 /// Parse a single operand. 4282 ParseResult parseOperand(OperandType &result) override { 4283 OperationParser::SSAUseInfo useInfo; 4284 if (parser.parseSSAUse(useInfo)) 4285 return failure(); 4286 4287 result = {useInfo.loc, useInfo.name, useInfo.number}; 4288 return success(); 4289 } 4290 4291 /// Parse a single operand if present. 4292 OptionalParseResult parseOptionalOperand(OperandType &result) override { 4293 if (parser.getToken().is(Token::percent_identifier)) 4294 return parseOperand(result); 4295 return llvm::None; 4296 } 4297 4298 /// Parse zero or more SSA comma-separated operand references with a specified 4299 /// surrounding delimiter, and an optional required operand count. 4300 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 4301 int requiredOperandCount = -1, 4302 Delimiter delimiter = Delimiter::None) override { 4303 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 4304 requiredOperandCount, delimiter); 4305 } 4306 4307 /// Parse zero or more SSA comma-separated operand or region arguments with 4308 /// optional surrounding delimiter and required operand count. 
4309 ParseResult 4310 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 4311 bool isOperandList, int requiredOperandCount = -1, 4312 Delimiter delimiter = Delimiter::None) { 4313 auto startLoc = parser.getToken().getLoc(); 4314 4315 // Handle delimiters. 4316 switch (delimiter) { 4317 case Delimiter::None: 4318 // Don't check for the absence of a delimiter if the number of operands 4319 // is unknown (and hence the operand list could be empty). 4320 if (requiredOperandCount == -1) 4321 break; 4322 // Token already matches an identifier and so can't be a delimiter. 4323 if (parser.getToken().is(Token::percent_identifier)) 4324 break; 4325 // Test against known delimiters. 4326 if (parser.getToken().is(Token::l_paren) || 4327 parser.getToken().is(Token::l_square)) 4328 return emitError(startLoc, "unexpected delimiter"); 4329 return emitError(startLoc, "invalid operand"); 4330 case Delimiter::OptionalParen: 4331 if (parser.getToken().isNot(Token::l_paren)) 4332 return success(); 4333 LLVM_FALLTHROUGH; 4334 case Delimiter::Paren: 4335 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 4336 return failure(); 4337 break; 4338 case Delimiter::OptionalSquare: 4339 if (parser.getToken().isNot(Token::l_square)) 4340 return success(); 4341 LLVM_FALLTHROUGH; 4342 case Delimiter::Square: 4343 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 4344 return failure(); 4345 break; 4346 } 4347 4348 // Check for zero operands. 4349 if (parser.getToken().is(Token::percent_identifier)) { 4350 do { 4351 OperandType operandOrArg; 4352 if (isOperandList ? parseOperand(operandOrArg) 4353 : parseRegionArgument(operandOrArg)) 4354 return failure(); 4355 result.push_back(operandOrArg); 4356 } while (parser.consumeIf(Token::comma)); 4357 } 4358 4359 // Handle delimiters. If we reach here, the optional delimiters were 4360 // present, so we need to parse their closing one. 
4361 switch (delimiter) { 4362 case Delimiter::None: 4363 break; 4364 case Delimiter::OptionalParen: 4365 case Delimiter::Paren: 4366 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 4367 return failure(); 4368 break; 4369 case Delimiter::OptionalSquare: 4370 case Delimiter::Square: 4371 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 4372 return failure(); 4373 break; 4374 } 4375 4376 if (requiredOperandCount != -1 && 4377 result.size() != static_cast<size_t>(requiredOperandCount)) 4378 return emitError(startLoc, "expected ") 4379 << requiredOperandCount << " operands"; 4380 return success(); 4381 } 4382 4383 /// Parse zero or more trailing SSA comma-separated trailing operand 4384 /// references with a specified surrounding delimiter, and an optional 4385 /// required operand count. A leading comma is expected before the operands. 4386 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 4387 int requiredOperandCount, 4388 Delimiter delimiter) override { 4389 if (parser.getToken().is(Token::comma)) { 4390 parseComma(); 4391 return parseOperandList(result, requiredOperandCount, delimiter); 4392 } 4393 if (requiredOperandCount != -1) 4394 return emitError(parser.getToken().getLoc(), "expected ") 4395 << requiredOperandCount << " operands"; 4396 return success(); 4397 } 4398 4399 /// Resolve an operand to an SSA value, emitting an error on failure. 4400 ParseResult resolveOperand(const OperandType &operand, Type type, 4401 SmallVectorImpl<Value> &result) override { 4402 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4403 operand.location}; 4404 if (auto value = parser.resolveSSAUse(operandInfo, type)) { 4405 result.push_back(value); 4406 return success(); 4407 } 4408 return failure(); 4409 } 4410 4411 /// Parse an AffineMap of SSA ids. 
4412 ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 4413 Attribute &mapAttr, StringRef attrName, 4414 SmallVectorImpl<NamedAttribute> &attrs, 4415 Delimiter delimiter) override { 4416 SmallVector<OperandType, 2> dimOperands; 4417 SmallVector<OperandType, 1> symOperands; 4418 4419 auto parseElement = [&](bool isSymbol) -> ParseResult { 4420 OperandType operand; 4421 if (parseOperand(operand)) 4422 return failure(); 4423 if (isSymbol) 4424 symOperands.push_back(operand); 4425 else 4426 dimOperands.push_back(operand); 4427 return success(); 4428 }; 4429 4430 AffineMap map; 4431 if (parser.parseAffineMapOfSSAIds(map, parseElement, delimiter)) 4432 return failure(); 4433 // Add AffineMap attribute. 4434 if (map) { 4435 mapAttr = AffineMapAttr::get(map); 4436 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 4437 } 4438 4439 // Add dim operands before symbol operands in 'operands'. 4440 operands.assign(dimOperands.begin(), dimOperands.end()); 4441 operands.append(symOperands.begin(), symOperands.end()); 4442 return success(); 4443 } 4444 4445 //===--------------------------------------------------------------------===// 4446 // Region Parsing 4447 //===--------------------------------------------------------------------===// 4448 4449 /// Parse a region that takes `arguments` of `argTypes` types. This 4450 /// effectively defines the SSA values of `arguments` and assigns their type. 
4451 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 4452 ArrayRef<Type> argTypes, 4453 bool enableNameShadowing) override { 4454 assert(arguments.size() == argTypes.size() && 4455 "mismatching number of arguments and types"); 4456 4457 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 4458 regionArguments; 4459 for (auto pair : llvm::zip(arguments, argTypes)) { 4460 const OperandType &operand = std::get<0>(pair); 4461 Type type = std::get<1>(pair); 4462 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 4463 operand.location}; 4464 regionArguments.emplace_back(operandInfo, type); 4465 } 4466 4467 // Try to parse the region. 4468 assert((!enableNameShadowing || 4469 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 4470 "name shadowing is only allowed on isolated regions"); 4471 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 4472 return failure(); 4473 return success(); 4474 } 4475 4476 /// Parses a region if present. 4477 ParseResult parseOptionalRegion(Region ®ion, 4478 ArrayRef<OperandType> arguments, 4479 ArrayRef<Type> argTypes, 4480 bool enableNameShadowing) override { 4481 if (parser.getToken().isNot(Token::l_brace)) 4482 return success(); 4483 return parseRegion(region, arguments, argTypes, enableNameShadowing); 4484 } 4485 4486 /// Parse a region argument. The type of the argument will be resolved later 4487 /// by a call to `parseRegion`. 4488 ParseResult parseRegionArgument(OperandType &argument) override { 4489 return parseOperand(argument); 4490 } 4491 4492 /// Parse a region argument if present. 
4493 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 4494 if (parser.getToken().isNot(Token::percent_identifier)) 4495 return success(); 4496 return parseRegionArgument(argument); 4497 } 4498 4499 ParseResult 4500 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 4501 int requiredOperandCount = -1, 4502 Delimiter delimiter = Delimiter::None) override { 4503 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 4504 requiredOperandCount, delimiter); 4505 } 4506 4507 //===--------------------------------------------------------------------===// 4508 // Successor Parsing 4509 //===--------------------------------------------------------------------===// 4510 4511 /// Parse a single operation successor. 4512 ParseResult parseSuccessor(Block *&dest) override { 4513 return parser.parseSuccessor(dest); 4514 } 4515 4516 /// Parse an optional operation successor and its operand list. 4517 OptionalParseResult parseOptionalSuccessor(Block *&dest) override { 4518 if (parser.getToken().isNot(Token::caret_identifier)) 4519 return llvm::None; 4520 return parseSuccessor(dest); 4521 } 4522 4523 /// Parse a single operation successor and its operand list. 4524 ParseResult 4525 parseSuccessorAndUseList(Block *&dest, 4526 SmallVectorImpl<Value> &operands) override { 4527 if (parseSuccessor(dest)) 4528 return failure(); 4529 4530 // Handle optional arguments. 4531 if (succeeded(parseOptionalLParen()) && 4532 (parser.parseOptionalSSAUseAndTypeList(operands) || parseRParen())) { 4533 return failure(); 4534 } 4535 return success(); 4536 } 4537 4538 //===--------------------------------------------------------------------===// 4539 // Type Parsing 4540 //===--------------------------------------------------------------------===// 4541 4542 /// Parse a type. 4543 ParseResult parseType(Type &result) override { 4544 return failure(!(result = parser.parseType())); 4545 } 4546 4547 /// Parse an optional type. 
4548 OptionalParseResult parseOptionalType(Type &result) override { 4549 return parser.parseOptionalType(result); 4550 } 4551 4552 /// Parse an arrow followed by a type list. 4553 ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override { 4554 if (parseArrow() || parser.parseFunctionResultTypes(result)) 4555 return failure(); 4556 return success(); 4557 } 4558 4559 /// Parse an optional arrow followed by a type list. 4560 ParseResult 4561 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 4562 if (!parser.consumeIf(Token::arrow)) 4563 return success(); 4564 return parser.parseFunctionResultTypes(result); 4565 } 4566 4567 /// Parse a colon followed by a type. 4568 ParseResult parseColonType(Type &result) override { 4569 return failure(parser.parseToken(Token::colon, "expected ':'") || 4570 !(result = parser.parseType())); 4571 } 4572 4573 /// Parse a colon followed by a type list, which must have at least one type. 4574 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 4575 if (parser.parseToken(Token::colon, "expected ':'")) 4576 return failure(); 4577 return parser.parseTypeListNoParens(result); 4578 } 4579 4580 /// Parse an optional colon followed by a type list, which if present must 4581 /// have at least one type. 4582 ParseResult 4583 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 4584 if (!parser.consumeIf(Token::colon)) 4585 return success(); 4586 return parser.parseTypeListNoParens(result); 4587 } 4588 4589 /// Parse a list of assignments of the form 4590 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...). 
4591 /// The list must contain at least one entry 4592 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 4593 SmallVectorImpl<OperandType> &rhs) override { 4594 auto parseElt = [&]() -> ParseResult { 4595 OperandType regionArg, operand; 4596 if (parseRegionArgument(regionArg) || parseEqual() || 4597 parseOperand(operand)) 4598 return failure(); 4599 lhs.push_back(regionArg); 4600 rhs.push_back(operand); 4601 return success(); 4602 }; 4603 if (parseLParen()) 4604 return failure(); 4605 return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt); 4606 } 4607 4608 private: 4609 /// The source location of the operation name. 4610 SMLoc nameLoc; 4611 4612 /// Information about the result name specifiers. 4613 ArrayRef<OperationParser::ResultRecord> resultIDs; 4614 4615 /// The abstract information of the operation. 4616 const AbstractOperation *opDefinition; 4617 4618 /// The main operation parser. 4619 OperationParser &parser; 4620 4621 /// A flag that indicates if any errors were emitted during parsing. 4622 bool emittedError = false; 4623 }; 4624 } // end anonymous namespace. 4625 4626 Operation * 4627 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) { 4628 auto opLoc = getToken().getLoc(); 4629 auto opName = getTokenSpelling(); 4630 4631 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 4632 if (!opDefinition && !opName.contains('.')) { 4633 // If the operation name has no namespace prefix we treat it as a standard 4634 // operation and prefix it with "std". 4635 // TODO: Would it be better to just build a mapping of the registered 4636 // operations in the standard dialect? 4637 opDefinition = 4638 AbstractOperation::lookup(Twine("std." 
+ opName).str(), getContext()); 4639 } 4640 4641 if (!opDefinition) { 4642 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 4643 return nullptr; 4644 } 4645 4646 consumeToken(); 4647 4648 // If the custom op parser crashes, produce some indication to help 4649 // debugging. 4650 std::string opNameStr = opName.str(); 4651 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 4652 opNameStr.c_str()); 4653 4654 // Get location information for the operation. 4655 auto srcLocation = getEncodedSourceLocation(opLoc); 4656 4657 // Have the op implementation take a crack and parsing this. 4658 OperationState opState(srcLocation, opDefinition->name); 4659 CleanupOpStateRegions guard{opState}; 4660 CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this); 4661 if (opAsmParser.parseOperation(opState)) 4662 return nullptr; 4663 4664 // If it emitted an error, we failed. 4665 if (opAsmParser.didEmitError()) 4666 return nullptr; 4667 4668 // Parse a location if one is present. 4669 if (parseOptionalTrailingLocation(opState.location)) 4670 return nullptr; 4671 4672 // Otherwise, we succeeded. Use the state it parsed as our op information. 4673 return opBuilder.createOperation(opState); 4674 } 4675 4676 //===----------------------------------------------------------------------===// 4677 // Region Parsing 4678 //===----------------------------------------------------------------------===// 4679 4680 /// Region. 4681 /// 4682 /// region ::= '{' region-body 4683 /// 4684 ParseResult OperationParser::parseRegion( 4685 Region ®ion, 4686 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 4687 bool isIsolatedNameScope) { 4688 // Parse the '{'. 4689 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 4690 return failure(); 4691 4692 // Check for an empty region. 
4693 if (entryArguments.empty() && consumeIf(Token::r_brace)) 4694 return success(); 4695 auto currentPt = opBuilder.saveInsertionPoint(); 4696 4697 // Push a new named value scope. 4698 pushSSANameScope(isIsolatedNameScope); 4699 4700 // Parse the first block directly to allow for it to be unnamed. 4701 Block *block = new Block(); 4702 4703 // Add arguments to the entry block. 4704 if (!entryArguments.empty()) { 4705 for (auto &placeholderArgPair : entryArguments) { 4706 auto &argInfo = placeholderArgPair.first; 4707 // Ensure that the argument was not already defined. 4708 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 4709 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 4710 "' is already in use") 4711 .attachNote(getEncodedSourceLocation(*defLoc)) 4712 << "previously referenced here"; 4713 } 4714 if (addDefinition(placeholderArgPair.first, 4715 block->addArgument(placeholderArgPair.second))) { 4716 delete block; 4717 return failure(); 4718 } 4719 } 4720 4721 // If we had named arguments, then don't allow a block name. 4722 if (getToken().is(Token::caret_identifier)) 4723 return emitError("invalid block name in region with named arguments"); 4724 } 4725 4726 if (parseBlock(block)) { 4727 delete block; 4728 return failure(); 4729 } 4730 4731 // Verify that no other arguments were parsed. 4732 if (!entryArguments.empty() && 4733 block->getNumArguments() > entryArguments.size()) { 4734 delete block; 4735 return emitError("entry block arguments were already defined"); 4736 } 4737 4738 // Parse the rest of the region. 4739 region.push_back(block); 4740 if (parseRegionBody(region)) 4741 return failure(); 4742 4743 // Pop the SSA value scope for this region. 4744 if (popSSANameScope()) 4745 return failure(); 4746 4747 // Reset the original insertion point. 4748 opBuilder.restoreInsertionPoint(currentPt); 4749 return success(); 4750 } 4751 4752 /// Region. 
4753 /// 4754 /// region-body ::= block* '}' 4755 /// 4756 ParseResult OperationParser::parseRegionBody(Region ®ion) { 4757 // Parse the list of blocks. 4758 while (!consumeIf(Token::r_brace)) { 4759 Block *newBlock = nullptr; 4760 if (parseBlock(newBlock)) 4761 return failure(); 4762 region.push_back(newBlock); 4763 } 4764 return success(); 4765 } 4766 4767 //===----------------------------------------------------------------------===// 4768 // Block Parsing 4769 //===----------------------------------------------------------------------===// 4770 4771 /// Block declaration. 4772 /// 4773 /// block ::= block-label? operation* 4774 /// block-label ::= block-id block-arg-list? `:` 4775 /// block-id ::= caret-id 4776 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4777 /// 4778 ParseResult OperationParser::parseBlock(Block *&block) { 4779 // The first block of a region may already exist, if it does the caret 4780 // identifier is optional. 4781 if (block && getToken().isNot(Token::caret_identifier)) 4782 return parseBlockBody(block); 4783 4784 SMLoc nameLoc = getToken().getLoc(); 4785 auto name = getTokenSpelling(); 4786 if (parseToken(Token::caret_identifier, "expected block name")) 4787 return failure(); 4788 4789 block = defineBlockNamed(name, nameLoc, block); 4790 4791 // Fail if the block was already defined. 4792 if (!block) 4793 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4794 4795 // If an argument list is present, parse it. 4796 if (consumeIf(Token::l_paren)) { 4797 SmallVector<BlockArgument, 8> bbArgs; 4798 if (parseOptionalBlockArgList(bbArgs, block) || 4799 parseToken(Token::r_paren, "expected ')' to end argument list")) 4800 return failure(); 4801 } 4802 4803 if (parseToken(Token::colon, "expected ':' after block name")) 4804 return failure(); 4805 4806 return parseBlockBody(block); 4807 } 4808 4809 ParseResult OperationParser::parseBlockBody(Block *block) { 4810 // Set the insertion point to the end of the block to parse. 
4811 opBuilder.setInsertionPointToEnd(block); 4812 4813 // Parse the list of operations that make up the body of the block. 4814 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4815 if (parseOperation()) 4816 return failure(); 4817 4818 return success(); 4819 } 4820 4821 /// Get the block with the specified name, creating it if it doesn't already 4822 /// exist. The location specified is the point of use, which allows 4823 /// us to diagnose references to blocks that are not defined precisely. 4824 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4825 auto &blockAndLoc = getBlockInfoByName(name); 4826 if (!blockAndLoc.first) { 4827 blockAndLoc = {new Block(), loc}; 4828 insertForwardRef(blockAndLoc.first, loc); 4829 } 4830 4831 return blockAndLoc.first; 4832 } 4833 4834 /// Define the block with the specified name. Returns the Block* or nullptr in 4835 /// the case of redefinition. 4836 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4837 Block *existing) { 4838 auto &blockAndLoc = getBlockInfoByName(name); 4839 if (!blockAndLoc.first) { 4840 // If the caller provided a block, use it. Otherwise create a new one. 4841 if (!existing) 4842 existing = new Block(); 4843 blockAndLoc.first = existing; 4844 blockAndLoc.second = loc; 4845 return blockAndLoc.first; 4846 } 4847 4848 // Forward declarations are removed once defined, so if we are defining a 4849 // existing block and it is not a forward declaration, then it is a 4850 // redeclaration. 4851 if (!eraseForwardRef(blockAndLoc.first)) 4852 return nullptr; 4853 return blockAndLoc.first; 4854 } 4855 4856 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 
4857 /// 4858 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4859 /// 4860 ParseResult OperationParser::parseOptionalBlockArgList( 4861 SmallVectorImpl<BlockArgument> &results, Block *owner) { 4862 if (getToken().is(Token::r_brace)) 4863 return success(); 4864 4865 // If the block already has arguments, then we're handling the entry block. 4866 // Parse and register the names for the arguments, but do not add them. 4867 bool definingExistingArgs = owner->getNumArguments() != 0; 4868 unsigned nextArgument = 0; 4869 4870 return parseCommaSeparatedList([&]() -> ParseResult { 4871 return parseSSADefOrUseAndType( 4872 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4873 // If this block did not have existing arguments, define a new one. 4874 if (!definingExistingArgs) 4875 return addDefinition(useInfo, owner->addArgument(type)); 4876 4877 // Otherwise, ensure that this argument has already been created. 4878 if (nextArgument >= owner->getNumArguments()) 4879 return emitError("too many arguments specified in argument list"); 4880 4881 // Finally, make sure the existing argument has the correct type. 4882 auto arg = owner->getArgument(nextArgument++); 4883 if (arg.getType() != type) 4884 return emitError("argument and block argument type mismatch"); 4885 return addDefinition(useInfo, arg); 4886 }); 4887 }); 4888 } 4889 4890 //===----------------------------------------------------------------------===// 4891 // Top-level entity parsing. 4892 //===----------------------------------------------------------------------===// 4893 4894 namespace { 4895 /// This parser handles entities that are only valid at the top level of the 4896 /// file. 4897 class ModuleParser : public Parser { 4898 public: 4899 explicit ModuleParser(ParserState &state) : Parser(state) {} 4900 4901 ParseResult parseModule(ModuleOp module); 4902 4903 private: 4904 /// Parse an attribute alias declaration. 
4905 ParseResult parseAttributeAliasDef(); 4906 4907 /// Parse an attribute alias declaration. 4908 ParseResult parseTypeAliasDef(); 4909 }; 4910 } // end anonymous namespace 4911 4912 /// Parses an attribute alias declaration. 4913 /// 4914 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4915 /// 4916 ParseResult ModuleParser::parseAttributeAliasDef() { 4917 assert(getToken().is(Token::hash_identifier)); 4918 StringRef aliasName = getTokenSpelling().drop_front(); 4919 4920 // Check for redefinitions. 4921 if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) 4922 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4923 4924 // Make sure this isn't invading the dialect attribute namespace. 4925 if (aliasName.contains('.')) 4926 return emitError("attribute names with a '.' are reserved for " 4927 "dialect-defined names"); 4928 4929 consumeToken(Token::hash_identifier); 4930 4931 // Parse the '='. 4932 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4933 return failure(); 4934 4935 // Parse the attribute value. 4936 Attribute attr = parseAttribute(); 4937 if (!attr) 4938 return failure(); 4939 4940 getState().symbols.attributeAliasDefinitions[aliasName] = attr; 4941 return success(); 4942 } 4943 4944 /// Parse a type alias declaration. 4945 /// 4946 /// type-alias-def ::= '!' alias-name `=` 'type' type 4947 /// 4948 ParseResult ModuleParser::parseTypeAliasDef() { 4949 assert(getToken().is(Token::exclamation_identifier)); 4950 StringRef aliasName = getTokenSpelling().drop_front(); 4951 4952 // Check for redefinitions. 4953 if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) 4954 return emitError("redefinition of type alias id '" + aliasName + "'"); 4955 4956 // Make sure this isn't invading the dialect type namespace. 4957 if (aliasName.contains('.')) 4958 return emitError("type names with a '.' 
are reserved for " 4959 "dialect-defined names"); 4960 4961 consumeToken(Token::exclamation_identifier); 4962 4963 // Parse the '=' and 'type'. 4964 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4965 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4966 return failure(); 4967 4968 // Parse the type. 4969 Type aliasedType = parseType(); 4970 if (!aliasedType) 4971 return failure(); 4972 4973 // Register this alias with the parser state. 4974 getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4975 return success(); 4976 } 4977 4978 /// This is the top-level module parser. 4979 ParseResult ModuleParser::parseModule(ModuleOp module) { 4980 OperationParser opParser(getState(), module); 4981 4982 // Module itself is a name scope. 4983 opParser.pushSSANameScope(/*isIsolated=*/true); 4984 4985 while (true) { 4986 switch (getToken().getKind()) { 4987 default: 4988 // Parse a top-level operation. 4989 if (opParser.parseOperation()) 4990 return failure(); 4991 break; 4992 4993 // If we got to the end of the file, then we're done. 4994 case Token::eof: { 4995 if (opParser.finalize()) 4996 return failure(); 4997 4998 // Handle the case where the top level module was explicitly defined. 4999 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 5000 auto &operations = bodyBlocks.front().getOperations(); 5001 assert(!operations.empty() && "expected a valid module terminator"); 5002 5003 // Check that the first operation is a module, and it is the only 5004 // non-terminator operation. 5005 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 5006 if (nested && std::next(operations.begin(), 2) == operations.end()) { 5007 // Merge the data of the nested module operation into 'module'. 5008 module.setLoc(nested.getLoc()); 5009 module.setAttrs(nested.getOperation()->getAttrList()); 5010 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 5011 5012 // Erase the original module body. 
5013 bodyBlocks.pop_front(); 5014 } 5015 5016 return opParser.popSSANameScope(); 5017 } 5018 5019 // If we got an error token, then the lexer already emitted an error, just 5020 // stop. Someday we could introduce error recovery if there was demand 5021 // for it. 5022 case Token::error: 5023 return failure(); 5024 5025 // Parse an attribute alias. 5026 case Token::hash_identifier: 5027 if (parseAttributeAliasDef()) 5028 return failure(); 5029 break; 5030 5031 // Parse a type alias. 5032 case Token::exclamation_identifier: 5033 if (parseTypeAliasDef()) 5034 return failure(); 5035 break; 5036 } 5037 } 5038 } 5039 5040 //===----------------------------------------------------------------------===// 5041 5042 /// This parses the file specified by the indicated SourceMgr and returns an 5043 /// MLIR module if it was valid. If not, it emits diagnostics and returns 5044 /// null. 5045 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 5046 MLIRContext *context) { 5047 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 5048 5049 // This is the result module we are parsing into. 5050 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 5051 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 5052 5053 SymbolState aliasState; 5054 ParserState state(sourceMgr, context, aliasState); 5055 if (ModuleParser(state).parseModule(*module)) 5056 return nullptr; 5057 5058 // Make sure the parse module has no other structural problems detected by 5059 // the verifier. 5060 if (failed(verify(*module))) 5061 return nullptr; 5062 5063 return module; 5064 } 5065 5066 /// This parses the file specified by the indicated filename and returns an 5067 /// MLIR module if it was valid. If not, the error message is emitted through 5068 /// the error handler registered in the context, and a null pointer is returned. 
5069 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5070 MLIRContext *context) { 5071 llvm::SourceMgr sourceMgr; 5072 return parseSourceFile(filename, sourceMgr, context); 5073 } 5074 5075 /// This parses the file specified by the indicated filename using the provided 5076 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 5077 /// message is emitted through the error handler registered in the context, and 5078 /// a null pointer is returned. 5079 OwningModuleRef mlir::parseSourceFile(StringRef filename, 5080 llvm::SourceMgr &sourceMgr, 5081 MLIRContext *context) { 5082 if (sourceMgr.getNumBuffers() != 0) { 5083 // TODO(b/136086478): Extend to support multiple buffers. 5084 emitError(mlir::UnknownLoc::get(context), 5085 "only main buffer parsed at the moment"); 5086 return nullptr; 5087 } 5088 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 5089 if (std::error_code error = file_or_err.getError()) { 5090 emitError(mlir::UnknownLoc::get(context), 5091 "could not open input file " + filename); 5092 return nullptr; 5093 } 5094 5095 // Load the MLIR module. 5096 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 5097 return parseSourceFile(sourceMgr, context); 5098 } 5099 5100 /// This parses the program string to a MLIR module if it was valid. If not, 5101 /// it emits diagnostics and returns null. 5102 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 5103 MLIRContext *context) { 5104 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 5105 if (!memBuffer) 5106 return nullptr; 5107 5108 SourceMgr sourceMgr; 5109 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 5110 return parseSourceFile(sourceMgr, context); 5111 } 5112 5113 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 5114 /// parsing failed, nullptr is returned. The number of bytes read from the input 5115 /// string is returned in 'numRead'. 
5116 template <typename T, typename ParserFn> 5117 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 5118 ParserFn &&parserFn) { 5119 SymbolState aliasState; 5120 return parseSymbol<T>( 5121 inputStr, context, aliasState, 5122 [&](Parser &parser) { 5123 SourceMgrDiagnosticHandler handler( 5124 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 5125 parser.getContext()); 5126 return parserFn(parser); 5127 }, 5128 &numRead); 5129 } 5130 5131 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 5132 size_t numRead = 0; 5133 return parseAttribute(attrStr, context, numRead); 5134 } 5135 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 5136 size_t numRead = 0; 5137 return parseAttribute(attrStr, type, numRead); 5138 } 5139 5140 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 5141 size_t &numRead) { 5142 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 5143 return parser.parseAttribute(); 5144 }); 5145 } 5146 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 5147 return parseSymbol<Attribute>( 5148 attrStr, type.getContext(), numRead, 5149 [type](Parser &parser) { return parser.parseAttribute(type); }); 5150 } 5151 5152 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 5153 size_t numRead = 0; 5154 return parseType(typeStr, context, numRead); 5155 } 5156 5157 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 5158 return parseSymbol<Type>(typeStr, context, numRead, 5159 [](Parser &parser) { return parser.parseType(); }); 5160 } 5161