1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the parser for the MLIR textual form. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Parser.h" 23 #include "Lexer.h" 24 #include "mlir/Analysis/Verifier.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/Dialect.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Location.h" 32 #include "mlir/IR/MLIRContext.h" 33 #include "mlir/IR/Module.h" 34 #include "mlir/IR/OpImplementation.h" 35 #include "mlir/IR/StandardTypes.h" 36 #include "mlir/Support/STLExtras.h" 37 #include "llvm/ADT/APInt.h" 38 #include "llvm/ADT/DenseMap.h" 39 #include "llvm/ADT/StringSet.h" 40 #include "llvm/ADT/bit.h" 41 #include "llvm/Support/PrettyStackTrace.h" 42 #include "llvm/Support/SMLoc.h" 43 #include "llvm/Support/SourceMgr.h" 44 #include <algorithm> 45 using namespace mlir; 46 using llvm::MemoryBuffer; 47 using llvm::SMLoc; 48 using llvm::SourceMgr; 49 50 namespace { 51 class Parser; 52 53 //===----------------------------------------------------------------------===// 54 // ParserState 55 //===----------------------------------------------------------------------===// 56 57 /// This class refers to all of the state maintained globally by the parser, 58 /// such as the current lexer position etc. The Parser base class provides 59 /// methods to access this. 60 class ParserState { 61 public: 62 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx) 63 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {} 64 65 // A map from attribute alias identifier to Attribute. 66 llvm::StringMap<Attribute> attributeAliasDefinitions; 67 68 // A map from type alias identifier to Type. 69 llvm::StringMap<Type> typeAliasDefinitions; 70 71 private: 72 ParserState(const ParserState &) = delete; 73 void operator=(const ParserState &) = delete; 74 75 friend class Parser; 76 77 // The context we're parsing into. 78 MLIRContext *const context; 79 80 // The lexer for the source file we're parsing. 81 Lexer lex; 82 83 // This is the next token that hasn't been consumed yet. 84 Token curToken; 85 }; 86 87 //===----------------------------------------------------------------------===// 88 // Parser 89 //===----------------------------------------------------------------------===// 90 91 /// This class implement support for parsing global entities like types and 92 /// shared entities like SSA names. It is intended to be subclassed by 93 /// specialized subparsers that include state, e.g. when a local symbol table. 94 class Parser { 95 public: 96 Builder builder; 97 98 Parser(ParserState &state) : builder(state.context), state(state) {} 99 100 // Helper methods to get stuff from the parser-global state. 101 ParserState &getState() const { return state; } 102 MLIRContext *getContext() const { return state.context; } 103 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 104 105 /// Parse a comma-separated list of elements up until the specified end token. 106 ParseResult 107 parseCommaSeparatedListUntil(Token::Kind rightToken, 108 const std::function<ParseResult()> &parseElement, 109 bool allowEmptyList = true); 110 111 /// Parse a comma separated list of elements that must have at least one entry 112 /// in it. 113 ParseResult 114 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 115 116 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 117 118 // We have two forms of parsing methods - those that return a non-null 119 // pointer on success, and those that return a ParseResult to indicate whether 120 // they returned a failure. The second class fills in by-reference arguments 121 // as the results of their action. 122 123 //===--------------------------------------------------------------------===// 124 // Error Handling 125 //===--------------------------------------------------------------------===// 126 127 /// Emit an error and return failure. 128 InFlightDiagnostic emitError(const Twine &message = {}) { 129 return emitError(state.curToken.getLoc(), message); 130 } 131 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 132 133 /// Encode the specified source location information into an attribute for 134 /// attachment to the IR. 135 Location getEncodedSourceLocation(llvm::SMLoc loc) { 136 return state.lex.getEncodedSourceLocation(loc); 137 } 138 139 //===--------------------------------------------------------------------===// 140 // Token Parsing 141 //===--------------------------------------------------------------------===// 142 143 /// Return the current token the parser is inspecting. 144 const Token &getToken() const { return state.curToken; } 145 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 146 147 /// If the current token has the specified kind, consume it and return true. 148 /// If not, return false. 149 bool consumeIf(Token::Kind kind) { 150 if (state.curToken.isNot(kind)) 151 return false; 152 consumeToken(kind); 153 return true; 154 } 155 156 /// Advance the current lexer onto the next token. 157 void consumeToken() { 158 assert(state.curToken.isNot(Token::eof, Token::error) && 159 "shouldn't advance past EOF or errors"); 160 state.curToken = state.lex.lexToken(); 161 } 162 163 /// Advance the current lexer onto the next token, asserting what the expected 164 /// current token is. This is preferred to the above method because it leads 165 /// to more self-documenting code with better checking. 166 void consumeToken(Token::Kind kind) { 167 assert(state.curToken.is(kind) && "consumed an unexpected token"); 168 consumeToken(); 169 } 170 171 /// Consume the specified token if present and return success. On failure, 172 /// output a diagnostic and return failure. 173 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 174 175 //===--------------------------------------------------------------------===// 176 // Type Parsing 177 //===--------------------------------------------------------------------===// 178 179 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 180 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 181 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 182 183 /// Parse an arbitrary type. 184 Type parseType(); 185 186 /// Parse a complex type. 187 Type parseComplexType(); 188 189 /// Parse an extended type. 190 Type parseExtendedType(); 191 192 /// Parse a function type. 193 Type parseFunctionType(); 194 195 /// Parse a memref type. 196 Type parseMemRefType(); 197 198 /// Parse a non function type. 199 Type parseNonFunctionType(); 200 201 /// Parse a tensor type. 202 Type parseTensorType(); 203 204 /// Parse a tuple type. 205 Type parseTupleType(); 206 207 /// Parse a vector type. 208 VectorType parseVectorType(); 209 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 210 bool allowDynamic = true); 211 ParseResult parseXInDimensionList(); 212 213 /// Parse strided layout specification. 214 ParseResult parseStridedLayout(int64_t &offset, 215 SmallVectorImpl<int64_t> &strides); 216 217 // Parse a brace-delimiter list of comma-separated integers with `?` as an 218 // unknown marker. 219 ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions); 220 221 //===--------------------------------------------------------------------===// 222 // Attribute Parsing 223 //===--------------------------------------------------------------------===// 224 225 /// Parse an arbitrary attribute with an optional type. 226 Attribute parseAttribute(Type type = {}); 227 228 /// Parse an attribute dictionary. 229 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 230 231 /// Parse an extended attribute. 232 Attribute parseExtendedAttr(Type type); 233 234 /// Parse a float attribute. 235 Attribute parseFloatAttr(Type type, bool isNegative); 236 237 /// Parse a decimal or a hexadecimal literal, which can be either an integer 238 /// or a float attribute. 239 Attribute parseDecOrHexAttr(Type type, bool isNegative); 240 241 /// Parse an opaque elements attribute. 242 Attribute parseOpaqueElementsAttr(); 243 244 /// Parse a dense elements attribute. 245 Attribute parseDenseElementsAttr(); 246 ShapedType parseElementsLiteralType(); 247 248 /// Parse a sparse elements attribute. 249 Attribute parseSparseElementsAttr(); 250 251 //===--------------------------------------------------------------------===// 252 // Location Parsing 253 //===--------------------------------------------------------------------===// 254 255 /// Parse an inline location. 256 ParseResult parseLocation(LocationAttr &loc); 257 258 /// Parse a raw location instance. 259 ParseResult parseLocationInstance(LocationAttr &loc); 260 261 /// Parse a callsite location instance. 262 ParseResult parseCallSiteLocation(LocationAttr &loc); 263 264 /// Parse a fused location instance. 265 ParseResult parseFusedLocation(LocationAttr &loc); 266 267 /// Parse a name or FileLineCol location instance. 268 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 269 270 /// Parse an optional trailing location. 271 /// 272 /// trailing-location ::= location? 273 /// 274 template <typename Owner> 275 ParseResult parseOptionalTrailingLocation(Owner *owner) { 276 // If there is a 'loc' we parse a trailing location. 277 if (!getToken().is(Token::kw_loc)) 278 return success(); 279 280 // Parse the location. 281 LocationAttr directLoc; 282 if (parseLocation(directLoc)) 283 return failure(); 284 owner->setLoc(directLoc); 285 return success(); 286 } 287 288 //===--------------------------------------------------------------------===// 289 // Affine Parsing 290 //===--------------------------------------------------------------------===// 291 292 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 293 IntegerSet &set); 294 295 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 296 ParseResult 297 parseAffineMapOfSSAIds(AffineMap &map, 298 llvm::function_ref<ParseResult(bool)> parseElement); 299 300 private: 301 /// The Parser is subclassed and reinstantiated. Do not add additional 302 /// non-trivial state here, add it to the ParserState class. 303 ParserState &state; 304 }; 305 } // end anonymous namespace 306 307 //===----------------------------------------------------------------------===// 308 // Helper methods. 309 //===----------------------------------------------------------------------===// 310 311 /// Parse a comma separated list of elements that must have at least one entry 312 /// in it. 313 ParseResult Parser::parseCommaSeparatedList( 314 const std::function<ParseResult()> &parseElement) { 315 // Non-empty case starts with an element. 316 if (parseElement()) 317 return failure(); 318 319 // Otherwise we have a list of comma separated elements. 320 while (consumeIf(Token::comma)) { 321 if (parseElement()) 322 return failure(); 323 } 324 return success(); 325 } 326 327 /// Parse a comma-separated list of elements, terminated with an arbitrary 328 /// token. This allows empty lists if allowEmptyList is true. 329 /// 330 /// abstract-list ::= rightToken // if allowEmptyList == true 331 /// abstract-list ::= element (',' element)* rightToken 332 /// 333 ParseResult Parser::parseCommaSeparatedListUntil( 334 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 335 bool allowEmptyList) { 336 // Handle the empty case. 337 if (getToken().is(rightToken)) { 338 if (!allowEmptyList) 339 return emitError("expected list element"); 340 consumeToken(rightToken); 341 return success(); 342 } 343 344 if (parseCommaSeparatedList(parseElement) || 345 parseToken(rightToken, "expected ',' or '" + 346 Token::getTokenSpelling(rightToken) + "'")) 347 return failure(); 348 349 return success(); 350 } 351 352 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 353 /// and may be recursive. Return with the 'prettyName' StringRef encompassing 354 /// the entire pretty name. 355 /// 356 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 357 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 358 /// | '(' pretty-dialect-sym-contents+ ')' 359 /// | '[' pretty-dialect-sym-contents+ ']' 360 /// | '{' pretty-dialect-sym-contents+ '}' 361 /// | '[^[<({>\])}\0]+' 362 /// 363 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 364 // Pretty symbol names are a relatively unstructured format that contains a 365 // series of properly nested punctuation, with anything else in the middle. 366 // Scan ahead to find it and consume it if successful, otherwise emit an 367 // error. 368 auto *curPtr = getTokenSpelling().data(); 369 370 SmallVector<char, 8> nestedPunctuation; 371 372 // Scan over the nested punctuation, bailing out on error and consuming until 373 // we find the end. We know that we're currently looking at the '<', so we 374 // can go until we find the matching '>' character. 375 assert(*curPtr == '<'); 376 do { 377 char c = *curPtr++; 378 switch (c) { 379 case '\0': 380 // This also handles the EOF case. 381 return emitError("unexpected nul or EOF in pretty dialect name"); 382 case '<': 383 case '[': 384 case '(': 385 case '{': 386 nestedPunctuation.push_back(c); 387 continue; 388 389 case '-': 390 // The sequence `->` is treated as special token. 391 if (*curPtr == '>') 392 ++curPtr; 393 continue; 394 395 case '>': 396 if (nestedPunctuation.pop_back_val() != '<') 397 return emitError("unbalanced '>' character in pretty dialect name"); 398 break; 399 case ']': 400 if (nestedPunctuation.pop_back_val() != '[') 401 return emitError("unbalanced ']' character in pretty dialect name"); 402 break; 403 case ')': 404 if (nestedPunctuation.pop_back_val() != '(') 405 return emitError("unbalanced ')' character in pretty dialect name"); 406 break; 407 case '}': 408 if (nestedPunctuation.pop_back_val() != '{') 409 return emitError("unbalanced '}' character in pretty dialect name"); 410 break; 411 412 default: 413 continue; 414 } 415 } while (!nestedPunctuation.empty()); 416 417 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 418 // consuming all this stuff, and return. 419 state.lex.resetPointer(curPtr); 420 421 unsigned length = curPtr - prettyName.begin(); 422 prettyName = StringRef(prettyName.begin(), length); 423 consumeToken(); 424 return success(); 425 } 426 427 /// Parse an extended dialect symbol. 428 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 429 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 430 SymbolAliasMap &aliases, 431 CreateFn &&createSymbol) { 432 // Parse the dialect namespace. 433 StringRef identifier = p.getTokenSpelling().drop_front(); 434 auto loc = p.getToken().getLoc(); 435 p.consumeToken(identifierTok); 436 437 // If there is no '<' token following this, and if the typename contains no 438 // dot, then we are parsing a symbol alias. 439 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 440 // Check for an alias for this type. 441 auto aliasIt = aliases.find(identifier); 442 if (aliasIt == aliases.end()) 443 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 444 nullptr); 445 return aliasIt->second; 446 } 447 448 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 449 // a dot, then this is the "pretty" form. If not, it is the verbose form that 450 // looks like <"...">. 451 std::string symbolData; 452 auto dialectName = identifier; 453 454 // Handle the verbose form, where "identifier" is a simple dialect name. 455 if (!identifier.contains('.')) { 456 // Consume the '<'. 457 if (p.parseToken(Token::less, "expected '<' in dialect type")) 458 return nullptr; 459 460 // Parse the symbol specific data. 461 if (p.getToken().isNot(Token::string)) 462 return (p.emitError("expected string literal data in dialect symbol"), 463 nullptr); 464 symbolData = p.getToken().getStringValue(); 465 loc = p.getToken().getLoc(); 466 p.consumeToken(Token::string); 467 468 // Consume the '>'. 469 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 470 return nullptr; 471 } else { 472 // Ok, the dialect name is the part of the identifier before the dot, the 473 // part after the dot is the dialect's symbol, or the start thereof. 474 auto dotHalves = identifier.split('.'); 475 dialectName = dotHalves.first; 476 auto prettyName = dotHalves.second; 477 478 // If the dialect's symbol is followed immediately by a <, then lex the body 479 // of it into prettyName. 480 if (p.getToken().is(Token::less) && 481 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 482 if (p.parsePrettyDialectSymbolName(prettyName)) 483 return nullptr; 484 } 485 486 symbolData = prettyName.str(); 487 } 488 489 // Call into the provided symbol construction function. 490 auto encodedLoc = p.getEncodedSourceLocation(loc); 491 return createSymbol(dialectName, symbolData, encodedLoc); 492 } 493 494 //===----------------------------------------------------------------------===// 495 // Error Handling 496 //===----------------------------------------------------------------------===// 497 498 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 499 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 500 501 // If we hit a parse error in response to a lexer error, then the lexer 502 // already reported the error. 503 if (getToken().is(Token::error)) 504 diag.abandon(); 505 return diag; 506 } 507 508 //===----------------------------------------------------------------------===// 509 // Token Parsing 510 //===----------------------------------------------------------------------===// 511 512 /// Consume the specified token if present and return success. On failure, 513 /// output a diagnostic and return failure. 514 ParseResult Parser::parseToken(Token::Kind expectedToken, 515 const Twine &message) { 516 if (consumeIf(expectedToken)) 517 return success(); 518 return emitError(message); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // Type Parsing 523 //===----------------------------------------------------------------------===// 524 525 /// Parse an arbitrary type. 526 /// 527 /// type ::= function-type 528 /// | non-function-type 529 /// 530 Type Parser::parseType() { 531 if (getToken().is(Token::l_paren)) 532 return parseFunctionType(); 533 return parseNonFunctionType(); 534 } 535 536 /// Parse a function result type. 537 /// 538 /// function-result-type ::= type-list-parens 539 /// | non-function-type 540 /// 541 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 542 if (getToken().is(Token::l_paren)) 543 return parseTypeListParens(elements); 544 545 Type t = parseNonFunctionType(); 546 if (!t) 547 return failure(); 548 elements.push_back(t); 549 return success(); 550 } 551 552 /// Parse a list of types without an enclosing parenthesis. The list must have 553 /// at least one member. 554 /// 555 /// type-list-no-parens ::= type (`,` type)* 556 /// 557 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 558 auto parseElt = [&]() -> ParseResult { 559 auto elt = parseType(); 560 elements.push_back(elt); 561 return elt ? success() : failure(); 562 }; 563 564 return parseCommaSeparatedList(parseElt); 565 } 566 567 /// Parse a parenthesized list of types. 568 /// 569 /// type-list-parens ::= `(` `)` 570 /// | `(` type-list-no-parens `)` 571 /// 572 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 573 if (parseToken(Token::l_paren, "expected '('")) 574 return failure(); 575 576 // Handle empty lists. 577 if (getToken().is(Token::r_paren)) 578 return consumeToken(), success(); 579 580 if (parseTypeListNoParens(elements) || 581 parseToken(Token::r_paren, "expected ')'")) 582 return failure(); 583 return success(); 584 } 585 586 /// Parse a complex type. 587 /// 588 /// complex-type ::= `complex` `<` type `>` 589 /// 590 Type Parser::parseComplexType() { 591 consumeToken(Token::kw_complex); 592 593 // Parse the '<'. 594 if (parseToken(Token::less, "expected '<' in complex type")) 595 return nullptr; 596 597 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 598 auto elementType = parseType(); 599 if (!elementType || 600 parseToken(Token::greater, "expected '>' in complex type")) 601 return nullptr; 602 603 return ComplexType::getChecked(elementType, typeLocation); 604 } 605 606 /// Parse an extended type. 607 /// 608 /// extended-type ::= (dialect-type | type-alias) 609 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 610 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 611 /// type-alias ::= `!` alias-name 612 /// 613 Type Parser::parseExtendedType() { 614 return parseExtendedSymbol<Type>( 615 *this, Token::exclamation_identifier, state.typeAliasDefinitions, 616 [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type { 617 // If we found a registered dialect, then ask it to parse the type. 618 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 619 return dialect->parseType(symbolData, loc); 620 621 // Otherwise, form a new opaque type. 622 return OpaqueType::getChecked( 623 Identifier::get(dialectName, state.context), symbolData, 624 state.context, loc); 625 }); 626 } 627 628 /// Parse a function type. 629 /// 630 /// function-type ::= type-list-parens `->` function-result-type 631 /// 632 Type Parser::parseFunctionType() { 633 assert(getToken().is(Token::l_paren)); 634 635 SmallVector<Type, 4> arguments, results; 636 if (parseTypeListParens(arguments) || 637 parseToken(Token::arrow, "expected '->' in function type") || 638 parseFunctionResultTypes(results)) 639 return nullptr; 640 641 return builder.getFunctionType(arguments, results); 642 } 643 644 /// Parse the offset and strides from a strided layout specification. 645 /// 646 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 647 /// 648 ParseResult Parser::parseStridedLayout(int64_t &offset, 649 SmallVectorImpl<int64_t> &strides) { 650 // Parse offset. 651 consumeToken(Token::kw_offset); 652 if (!consumeIf(Token::colon)) 653 return emitError("expected colon after `offset` keyword"); 654 auto maybeOffset = getToken().getUnsignedIntegerValue(); 655 bool question = getToken().is(Token::question); 656 if (!maybeOffset && !question) 657 return emitError("invalid offset"); 658 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue()) 659 : MemRefType::getDynamicStrideOrOffset(); 660 consumeToken(); 661 662 if (!consumeIf(Token::comma)) 663 return emitError("expected comma after offset value"); 664 665 // Parse stride list. 666 if (!consumeIf(Token::kw_strides)) 667 return emitError("expected `strides` keyword after offset specification"); 668 if (!consumeIf(Token::colon)) 669 return emitError("expected colon after `strides` keyword"); 670 if (failed(parseStrideList(strides))) 671 return emitError("invalid braces-enclosed stride list"); 672 if (llvm::any_of(strides, [](int64_t st) { return st == 0; })) 673 return emitError("invalid memref stride"); 674 675 return success(); 676 } 677 678 /// Parse a memref type. 679 /// 680 /// memref-type ::= `memref` `<` dimension-list-ranked type 681 /// (`,` semi-affine-map-composition)? (`,` memory-space)? `>` 682 /// 683 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 684 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 685 /// 686 Type Parser::parseMemRefType() { 687 consumeToken(Token::kw_memref); 688 689 if (parseToken(Token::less, "expected '<' in memref type")) 690 return nullptr; 691 692 SmallVector<int64_t, 4> dimensions; 693 if (parseDimensionListRanked(dimensions)) 694 return nullptr; 695 696 // Parse the element type. 697 auto typeLoc = getToken().getLoc(); 698 auto elementType = parseType(); 699 if (!elementType) 700 return nullptr; 701 702 // Parse semi-affine-map-composition. 703 SmallVector<AffineMap, 2> affineMapComposition; 704 unsigned memorySpace = 0; 705 bool parsedMemorySpace = false; 706 707 auto parseElt = [&]() -> ParseResult { 708 if (getToken().is(Token::integer)) { 709 // Parse memory space. 710 if (parsedMemorySpace) 711 return emitError("multiple memory spaces specified in memref type"); 712 auto v = getToken().getUnsignedIntegerValue(); 713 if (!v.hasValue()) 714 return emitError("invalid memory space in memref type"); 715 memorySpace = v.getValue(); 716 consumeToken(Token::integer); 717 parsedMemorySpace = true; 718 } else { 719 if (parsedMemorySpace) 720 return emitError("expected memory space to be last in memref type"); 721 if (getToken().is(Token::kw_offset)) { 722 int64_t offset; 723 SmallVector<int64_t, 4> strides; 724 if (failed(parseStridedLayout(offset, strides))) 725 return failure(); 726 // Construct strided affine map. 727 auto map = makeStridedLinearLayoutMap(strides, offset, 728 elementType.getContext()); 729 affineMapComposition.push_back(map); 730 } else { 731 // Parse affine map. 732 auto affineMap = parseAttribute(); 733 if (!affineMap) 734 return failure(); 735 // Verify that the parsed attribute is an affine map. 736 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 737 affineMapComposition.push_back(affineMapAttr.getValue()); 738 else 739 return emitError("expected affine map in memref type"); 740 } 741 } 742 return success(); 743 }; 744 745 // Parse a list of mappings and address space if present. 746 if (consumeIf(Token::comma)) { 747 // Parse comma separated list of affine maps, followed by memory space. 748 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 749 /*allowEmptyList=*/false)) { 750 return nullptr; 751 } 752 } else { 753 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 754 return nullptr; 755 } 756 757 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 758 memorySpace, getEncodedSourceLocation(typeLoc)); 759 } 760 761 /// Parse any type except the function type. 762 /// 763 /// non-function-type ::= integer-type 764 /// | index-type 765 /// | float-type 766 /// | extended-type 767 /// | vector-type 768 /// | tensor-type 769 /// | memref-type 770 /// | complex-type 771 /// | tuple-type 772 /// | none-type 773 /// 774 /// index-type ::= `index` 775 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 776 /// none-type ::= `none` 777 /// 778 Type Parser::parseNonFunctionType() { 779 switch (getToken().getKind()) { 780 default: 781 return (emitError("expected non-function type"), nullptr); 782 case Token::kw_memref: 783 return parseMemRefType(); 784 case Token::kw_tensor: 785 return parseTensorType(); 786 case Token::kw_complex: 787 return parseComplexType(); 788 case Token::kw_tuple: 789 return parseTupleType(); 790 case Token::kw_vector: 791 return parseVectorType(); 792 // integer-type 793 case Token::inttype: { 794 auto width = getToken().getIntTypeBitwidth(); 795 if (!width.hasValue()) 796 return (emitError("invalid integer width"), nullptr); 797 auto loc = getEncodedSourceLocation(getToken().getLoc()); 798 consumeToken(Token::inttype); 799 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 800 } 801 802 // float-type 803 case Token::kw_bf16: 804 consumeToken(Token::kw_bf16); 805 return builder.getBF16Type(); 806 case Token::kw_f16: 807 consumeToken(Token::kw_f16); 808 return builder.getF16Type(); 809 case Token::kw_f32: 810 consumeToken(Token::kw_f32); 811 return builder.getF32Type(); 812 case Token::kw_f64: 813 consumeToken(Token::kw_f64); 814 return builder.getF64Type(); 815 816 // index-type 817 case Token::kw_index: 818 consumeToken(Token::kw_index); 819 return builder.getIndexType(); 820 821 // none-type 822 case Token::kw_none: 823 consumeToken(Token::kw_none); 824 return builder.getNoneType(); 825 826 // extended type 827 case Token::exclamation_identifier: 828 return parseExtendedType(); 829 } 830 } 831 832 /// Parse a tensor type. 833 /// 834 /// tensor-type ::= `tensor` `<` dimension-list type `>` 835 /// dimension-list ::= dimension-list-ranked | `*x` 836 /// 837 Type Parser::parseTensorType() { 838 consumeToken(Token::kw_tensor); 839 840 if (parseToken(Token::less, "expected '<' in tensor type")) 841 return nullptr; 842 843 bool isUnranked; 844 SmallVector<int64_t, 4> dimensions; 845 846 if (consumeIf(Token::star)) { 847 // This is an unranked tensor type. 848 isUnranked = true; 849 850 if (parseXInDimensionList()) 851 return nullptr; 852 853 } else { 854 isUnranked = false; 855 if (parseDimensionListRanked(dimensions)) 856 return nullptr; 857 } 858 859 // Parse the element type. 860 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 861 auto elementType = parseType(); 862 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 863 return nullptr; 864 865 if (isUnranked) 866 return UnrankedTensorType::getChecked(elementType, typeLocation); 867 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 868 } 869 870 /// Parse a tuple type. 871 /// 872 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 873 /// 874 Type Parser::parseTupleType() { 875 consumeToken(Token::kw_tuple); 876 877 // Parse the '<'. 878 if (parseToken(Token::less, "expected '<' in tuple type")) 879 return nullptr; 880 881 // Check for an empty tuple by directly parsing '>'. 882 if (consumeIf(Token::greater)) 883 return TupleType::get(getContext()); 884 885 // Parse the element types and the '>'. 886 SmallVector<Type, 4> types; 887 if (parseTypeListNoParens(types) || 888 parseToken(Token::greater, "expected '>' in tuple type")) 889 return nullptr; 890 891 return TupleType::get(types, getContext()); 892 } 893 894 /// Parse a vector type. 895 /// 896 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 897 /// non-empty-static-dimension-list ::= decimal-literal `x` 898 /// static-dimension-list 899 /// static-dimension-list ::= (decimal-literal `x`)* 900 /// 901 VectorType Parser::parseVectorType() { 902 consumeToken(Token::kw_vector); 903 904 if (parseToken(Token::less, "expected '<' in vector type")) 905 return nullptr; 906 907 SmallVector<int64_t, 4> dimensions; 908 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 909 return nullptr; 910 if (dimensions.empty()) 911 return (emitError("expected dimension size in vector type"), nullptr); 912 913 // Parse the element type. 914 auto typeLoc = getToken().getLoc(); 915 auto elementType = parseType(); 916 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 917 return nullptr; 918 919 return VectorType::getChecked(dimensions, elementType, 920 getEncodedSourceLocation(typeLoc)); 921 } 922 923 /// Parse a dimension list of a tensor or memref type. This populates the 924 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 925 /// errors out on `?` otherwise. 926 /// 927 /// dimension-list-ranked ::= (dimension `x`)* 928 /// dimension ::= `?` | decimal-literal 929 /// 930 /// When `allowDynamic` is not set, this is used to parse: 931 /// 932 /// static-dimension-list ::= (decimal-literal `x`)* 933 ParseResult 934 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 935 bool allowDynamic) { 936 while (getToken().isAny(Token::integer, Token::question)) { 937 if (consumeIf(Token::question)) { 938 if (!allowDynamic) 939 return emitError("expected static shape"); 940 dimensions.push_back(-1); 941 } else { 942 // Hexadecimal integer literals (starting with `0x`) are not allowed in 943 // aggregate type declarations. Therefore, `0xf32` should be processed as 944 // a sequence of separate elements `0`, `x`, `f32`. 945 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 946 // We can get here only if the token is an integer literal. Hexadecimal 947 // integer literals can only start with `0x` (`1x` wouldn't lex as a 948 // literal, just `1` would, at which point we don't get into this 949 // branch). 950 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 951 dimensions.push_back(0); 952 state.lex.resetPointer(getTokenSpelling().data() + 1); 953 consumeToken(); 954 } else { 955 // Make sure this integer value is in bound and valid. 956 auto dimension = getToken().getUnsignedIntegerValue(); 957 if (!dimension.hasValue()) 958 return emitError("invalid dimension"); 959 dimensions.push_back((int64_t)dimension.getValue()); 960 consumeToken(Token::integer); 961 } 962 } 963 964 // Make sure we have an 'x' or something like 'xbf32'. 965 if (parseXInDimensionList()) 966 return failure(); 967 } 968 969 return success(); 970 } 971 972 /// Parse an 'x' token in a dimension list, handling the case where the x is 973 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 974 /// token. 975 ParseResult Parser::parseXInDimensionList() { 976 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 977 return emitError("expected 'x' in dimension list"); 978 979 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 980 if (getTokenSpelling().size() != 1) 981 state.lex.resetPointer(getTokenSpelling().data() + 1); 982 983 // Consume the 'x'. 984 consumeToken(Token::bare_identifier); 985 986 return success(); 987 } 988 989 // Parse a comma-separated list of dimensions, possibly empty: 990 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 991 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 992 if (!consumeIf(Token::l_square)) 993 return failure(); 994 // Empty list early exit. 995 if (consumeIf(Token::r_square)) 996 return success(); 997 while (true) { 998 if (consumeIf(Token::question)) { 999 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 1000 } else { 1001 // This must be an integer value. 1002 int64_t val; 1003 if (getToken().getSpelling().getAsInteger(10, val)) 1004 return emitError("invalid integer value: ") << getToken().getSpelling(); 1005 // Make sure it is not the one value for `?`. 1006 if (ShapedType::isDynamic(val)) 1007 return emitError("invalid integer value: ") 1008 << getToken().getSpelling() 1009 << ", use `?` to specify a dynamic dimension"; 1010 dimensions.push_back(val); 1011 consumeToken(Token::integer); 1012 } 1013 if (!consumeIf(Token::comma)) 1014 break; 1015 } 1016 if (!consumeIf(Token::r_square)) 1017 return failure(); 1018 return success(); 1019 } 1020 1021 //===----------------------------------------------------------------------===// 1022 // Attribute parsing. 1023 //===----------------------------------------------------------------------===// 1024 1025 /// Return the symbol reference referred to by the given token, that is known to 1026 /// be an @-identifier. 1027 static std::string extractSymbolReference(Token tok) { 1028 assert(tok.is(Token::at_identifier) && "expected valid @-identifier"); 1029 StringRef nameStr = tok.getSpelling().drop_front(); 1030 1031 // Check to see if the reference is a string literal, or a bare identifier. 1032 if (nameStr.front() == '"') 1033 return tok.getStringValue(); 1034 return nameStr; 1035 } 1036 1037 /// Parse an arbitrary attribute. 1038 /// 1039 /// attribute-value ::= `unit` 1040 /// | bool-literal 1041 /// | integer-literal (`:` (index-type | integer-type))? 1042 /// | float-literal (`:` float-type)? 1043 /// | string-literal (`:` type)? 1044 /// | type 1045 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 1046 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 1047 /// | symbol-ref-id 1048 /// | `dense` `<` attribute-value `>` `:` 1049 /// (tensor-type | vector-type) 1050 /// | `sparse` `<` attribute-value `,` attribute-value `>` 1051 /// `:` (tensor-type | vector-type) 1052 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 1053 /// `>` `:` (tensor-type | vector-type) 1054 /// | extended-attribute 1055 /// 1056 Attribute Parser::parseAttribute(Type type) { 1057 switch (getToken().getKind()) { 1058 // Parse an AffineMap or IntegerSet attribute. 1059 case Token::l_paren: { 1060 // Try to parse an affine map or an integer set reference. 1061 AffineMap map; 1062 IntegerSet set; 1063 if (parseAffineMapOrIntegerSetReference(map, set)) 1064 return nullptr; 1065 if (map) 1066 return AffineMapAttr::get(map); 1067 assert(set); 1068 return IntegerSetAttr::get(set); 1069 } 1070 1071 // Parse an array attribute. 1072 case Token::l_square: { 1073 consumeToken(Token::l_square); 1074 1075 SmallVector<Attribute, 4> elements; 1076 auto parseElt = [&]() -> ParseResult { 1077 elements.push_back(parseAttribute()); 1078 return elements.back() ? success() : failure(); 1079 }; 1080 1081 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 1082 return nullptr; 1083 return builder.getArrayAttr(elements); 1084 } 1085 1086 // Parse a boolean attribute. 1087 case Token::kw_false: 1088 consumeToken(Token::kw_false); 1089 return builder.getBoolAttr(false); 1090 case Token::kw_true: 1091 consumeToken(Token::kw_true); 1092 return builder.getBoolAttr(true); 1093 1094 // Parse a dense elements attribute. 1095 case Token::kw_dense: 1096 return parseDenseElementsAttr(); 1097 1098 // Parse a dictionary attribute. 1099 case Token::l_brace: { 1100 SmallVector<NamedAttribute, 4> elements; 1101 if (parseAttributeDict(elements)) 1102 return nullptr; 1103 return builder.getDictionaryAttr(elements); 1104 } 1105 1106 // Parse an extended attribute, i.e. alias or dialect attribute. 1107 case Token::hash_identifier: 1108 return parseExtendedAttr(type); 1109 1110 // Parse floating point and integer attributes. 1111 case Token::floatliteral: 1112 return parseFloatAttr(type, /*isNegative=*/false); 1113 case Token::integer: 1114 return parseDecOrHexAttr(type, /*isNegative=*/false); 1115 case Token::minus: { 1116 consumeToken(Token::minus); 1117 if (getToken().is(Token::integer)) 1118 return parseDecOrHexAttr(type, /*isNegative=*/true); 1119 if (getToken().is(Token::floatliteral)) 1120 return parseFloatAttr(type, /*isNegative=*/true); 1121 1122 return (emitError("expected constant integer or floating point value"), 1123 nullptr); 1124 } 1125 1126 // Parse a location attribute. 1127 case Token::kw_loc: { 1128 LocationAttr attr; 1129 return failed(parseLocation(attr)) ? Attribute() : attr; 1130 } 1131 1132 // Parse an opaque elements attribute. 1133 case Token::kw_opaque: 1134 return parseOpaqueElementsAttr(); 1135 1136 // Parse a sparse elements attribute. 1137 case Token::kw_sparse: 1138 return parseSparseElementsAttr(); 1139 1140 // Parse a string attribute. 1141 case Token::string: { 1142 auto val = getToken().getStringValue(); 1143 consumeToken(Token::string); 1144 // Parse the optional trailing colon type if one wasn't explicitly provided. 1145 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1146 return Attribute(); 1147 1148 return type ? StringAttr::get(val, type) 1149 : StringAttr::get(val, getContext()); 1150 } 1151 1152 // Parse a symbol reference attribute. 1153 case Token::at_identifier: { 1154 std::string nameStr = extractSymbolReference(getToken()); 1155 consumeToken(Token::at_identifier); 1156 return builder.getSymbolRefAttr(nameStr); 1157 } 1158 1159 // Parse a 'unit' attribute. 1160 case Token::kw_unit: 1161 consumeToken(Token::kw_unit); 1162 return builder.getUnitAttr(); 1163 1164 default: 1165 // Parse a type attribute. 1166 if (Type type = parseType()) 1167 return TypeAttr::get(type); 1168 return nullptr; 1169 } 1170 } 1171 1172 /// Attribute dictionary. 1173 /// 1174 /// attribute-dict ::= `{` `}` 1175 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1176 /// attribute-entry ::= bare-id `=` attribute-value 1177 /// 1178 ParseResult 1179 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1180 if (!consumeIf(Token::l_brace)) 1181 return failure(); 1182 1183 auto parseElt = [&]() -> ParseResult { 1184 // We allow keywords as attribute names. 1185 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1186 !getToken().isKeyword()) 1187 return emitError("expected attribute name"); 1188 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1189 consumeToken(); 1190 1191 // Try to parse the '=' for the attribute value. 1192 if (!consumeIf(Token::equal)) { 1193 // If there is no '=', we treat this as a unit attribute. 1194 attributes.push_back({nameId, builder.getUnitAttr()}); 1195 return success(); 1196 } 1197 1198 auto attr = parseAttribute(); 1199 if (!attr) 1200 return failure(); 1201 1202 attributes.push_back({nameId, attr}); 1203 return success(); 1204 }; 1205 1206 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1207 return failure(); 1208 1209 return success(); 1210 } 1211 1212 /// Parse an extended attribute. 1213 /// 1214 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1215 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1216 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1217 /// attribute-alias ::= `#` alias-name 1218 /// 1219 Attribute Parser::parseExtendedAttr(Type type) { 1220 Attribute attr = parseExtendedSymbol<Attribute>( 1221 *this, Token::hash_identifier, state.attributeAliasDefinitions, 1222 [&](StringRef dialectName, StringRef symbolData, 1223 Location loc) -> Attribute { 1224 // Parse an optional trailing colon type. 1225 Type attrType = type; 1226 if (consumeIf(Token::colon) && !(attrType = parseType())) 1227 return Attribute(); 1228 1229 // If we found a registered dialect, then ask it to parse the attribute. 1230 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 1231 return dialect->parseAttribute(symbolData, attrType, loc); 1232 1233 // Otherwise, form a new opaque attribute. 1234 return OpaqueAttr::getChecked( 1235 Identifier::get(dialectName, state.context), symbolData, 1236 attrType ? attrType : NoneType::get(state.context), loc); 1237 }); 1238 1239 // Ensure that the attribute has the same type as requested. 1240 if (attr && type && attr.getType() != type) { 1241 emitError("attribute type different than expected: expected ") 1242 << type << ", but got " << attr.getType(); 1243 return nullptr; 1244 } 1245 return attr; 1246 } 1247 1248 /// Parse a float attribute. 1249 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1250 auto val = getToken().getFloatingPointValue(); 1251 if (!val.hasValue()) 1252 return (emitError("floating point value too large for attribute"), nullptr); 1253 consumeToken(Token::floatliteral); 1254 if (!type) { 1255 // Default to F64 when no type is specified. 1256 if (!consumeIf(Token::colon)) 1257 type = builder.getF64Type(); 1258 else if (!(type = parseType())) 1259 return nullptr; 1260 } 1261 if (!type.isa<FloatType>()) 1262 return (emitError("floating point value not valid for specified type"), 1263 nullptr); 1264 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1265 } 1266 1267 /// Construct a float attribute bitwise equivalent to the integer literal. 1268 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1269 uint64_t value) { 1270 int width = type.getIntOrFloatBitWidth(); 1271 APInt apInt(width, value); 1272 if (apInt != value) { 1273 p->emitError("hexadecimal float constant out of range for type"); 1274 return nullptr; 1275 } 1276 APFloat apFloat(type.getFloatSemantics(), apInt); 1277 return p->builder.getFloatAttr(type, apFloat); 1278 } 1279 1280 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1281 /// or a float attribute. 1282 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1283 auto val = getToken().getUInt64IntegerValue(); 1284 if (!val.hasValue()) 1285 return (emitError("integer constant out of range for attribute"), nullptr); 1286 1287 // Remember if the literal is hexadecimal. 1288 StringRef spelling = getToken().getSpelling(); 1289 auto loc = state.curToken.getLoc(); 1290 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1291 1292 consumeToken(Token::integer); 1293 if (!type) { 1294 // Default to i64 if not type is specified. 1295 if (!consumeIf(Token::colon)) 1296 type = builder.getIntegerType(64); 1297 else if (!(type = parseType())) 1298 return nullptr; 1299 } 1300 1301 if (auto floatType = type.dyn_cast<FloatType>()) { 1302 // TODO(zinenko): Update once hex format for bfloat16 is supported. 1303 if (type.isBF16()) 1304 return emitError(loc, 1305 "hexadecimal float literal not supported for bfloat16"), 1306 nullptr; 1307 if (isNegative) 1308 return emitError( 1309 loc, 1310 "hexadecimal float literal should not have a leading minus"), 1311 nullptr; 1312 if (!isHex) { 1313 emitError(loc, "unexpected decimal integer literal for a float attribute") 1314 .attachNote() 1315 << "add a trailing dot to make the literal a float"; 1316 return nullptr; 1317 } 1318 1319 // Construct a float attribute bitwise equivalent to the integer literal. 1320 return buildHexadecimalFloatLiteral(this, floatType, *val); 1321 } 1322 1323 if (!type.isIntOrIndex()) 1324 return emitError(loc, "integer literal not valid for specified type"), 1325 nullptr; 1326 1327 // Parse the integer literal. 1328 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1329 APInt apInt(width, *val, isNegative); 1330 if (apInt != *val) 1331 return emitError(loc, "integer constant out of range for attribute"), 1332 nullptr; 1333 1334 // Otherwise construct an integer attribute. 1335 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1336 return emitError(loc, "integer constant out of range for attribute"), 1337 nullptr; 1338 1339 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1340 } 1341 1342 /// Parse an opaque elements attribute. 1343 Attribute Parser::parseOpaqueElementsAttr() { 1344 consumeToken(Token::kw_opaque); 1345 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1346 return nullptr; 1347 1348 if (getToken().isNot(Token::string)) 1349 return (emitError("expected dialect namespace"), nullptr); 1350 1351 auto name = getToken().getStringValue(); 1352 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1353 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1354 // attribute. Otherwise, it can't be roundtripped without having the dialect 1355 // registered. 1356 if (!dialect) 1357 return (emitError("no registered dialect with namespace '" + name + "'"), 1358 nullptr); 1359 1360 consumeToken(Token::string); 1361 if (parseToken(Token::comma, "expected ','")) 1362 return nullptr; 1363 1364 if (getToken().getKind() != Token::string) 1365 return (emitError("opaque string should start with '0x'"), nullptr); 1366 1367 auto val = getToken().getStringValue(); 1368 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1369 return (emitError("opaque string should start with '0x'"), nullptr); 1370 1371 val = val.substr(2); 1372 if (!llvm::all_of(val, llvm::isHexDigit)) 1373 return (emitError("opaque string only contains hex digits"), nullptr); 1374 1375 consumeToken(Token::string); 1376 if (parseToken(Token::greater, "expected '>'") || 1377 parseToken(Token::colon, "expected ':'")) 1378 return nullptr; 1379 1380 auto type = parseElementsLiteralType(); 1381 if (!type) 1382 return nullptr; 1383 1384 return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); 1385 } 1386 1387 namespace { 1388 class TensorLiteralParser { 1389 public: 1390 TensorLiteralParser(Parser &p) : p(p) {} 1391 1392 ParseResult parse() { 1393 if (p.getToken().is(Token::l_square)) 1394 return parseList(shape); 1395 return parseElement(); 1396 } 1397 1398 /// Build a dense attribute instance with the parsed elements and the given 1399 /// shaped type. 1400 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1401 1402 ArrayRef<int64_t> getShape() const { return shape; } 1403 1404 private: 1405 enum class ElementKind { Boolean, Integer, Float }; 1406 1407 /// Return a string to represent the given element kind. 1408 const char *getElementKindStr(ElementKind kind) { 1409 switch (kind) { 1410 case ElementKind::Boolean: 1411 return "'boolean'"; 1412 case ElementKind::Integer: 1413 return "'integer'"; 1414 case ElementKind::Float: 1415 return "'float'"; 1416 } 1417 llvm_unreachable("unknown element kind"); 1418 } 1419 1420 /// Build a Dense Integer attribute for the given type. 1421 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1422 IntegerType eltTy); 1423 1424 /// Build a Dense Float attribute for the given type. 1425 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1426 FloatType eltTy); 1427 1428 /// Parse a single element, returning failure if it isn't a valid element 1429 /// literal. For example: 1430 /// parseElement(1) -> Success, 1 1431 /// parseElement([1]) -> Failure 1432 ParseResult parseElement(); 1433 1434 /// Parse a list of either lists or elements, returning the dimensions of the 1435 /// parsed sub-tensors in dims. For example: 1436 /// parseList([1, 2, 3]) -> Success, [3] 1437 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1438 /// parseList([[1, 2], 3]) -> Failure 1439 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1440 ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims); 1441 1442 Parser &p; 1443 1444 /// The shape inferred from the parsed elements. 1445 SmallVector<int64_t, 4> shape; 1446 1447 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1448 std::vector<std::pair<bool, Token>> storage; 1449 1450 /// A flag that indicates the type of elements that have been parsed. 1451 llvm::Optional<ElementKind> knownEltKind; 1452 }; 1453 } // namespace 1454 1455 /// Build a dense attribute instance with the parsed elements and the given 1456 /// shaped type. 1457 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1458 ShapedType type) { 1459 // Check that the parsed storage size has the same number of elements to the 1460 // type, or is a known splat. 1461 if (!shape.empty() && getShape() != type.getShape()) { 1462 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1463 << "]) does not match type ([" << type.getShape() << "])"; 1464 return nullptr; 1465 } 1466 1467 // If the type is an integer, build a set of APInt values from the storage 1468 // with the correct bitwidth. 1469 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1470 return getIntAttr(loc, type, intTy); 1471 1472 // Otherwise, this must be a floating point type. 1473 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1474 if (!floatTy) { 1475 p.emitError(loc) << "expected floating-point or integer element type, got " 1476 << type.getElementType(); 1477 return nullptr; 1478 } 1479 return getFloatAttr(loc, type, floatTy); 1480 } 1481 1482 /// Build a Dense Integer attribute for the given type. 1483 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1484 ShapedType type, 1485 IntegerType eltTy) { 1486 std::vector<APInt> intElements; 1487 intElements.reserve(storage.size()); 1488 for (const auto &signAndToken : storage) { 1489 bool isNegative = signAndToken.first; 1490 const Token &token = signAndToken.second; 1491 1492 // Check to see if floating point values were parsed. 1493 if (token.is(Token::floatliteral)) { 1494 p.emitError() << "expected integer elements, but parsed floating-point"; 1495 return nullptr; 1496 } 1497 1498 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1499 "unexpected token type"); 1500 if (token.isAny(Token::kw_true, Token::kw_false)) { 1501 if (!eltTy.isInteger(1)) 1502 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1503 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1504 /*isSigned=*/false); 1505 intElements.push_back(apInt); 1506 continue; 1507 } 1508 1509 // Create APInt values for each element with the correct bitwidth. 1510 auto val = token.getUInt64IntegerValue(); 1511 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1512 : (int64_t)val.getValue() < 0)) { 1513 p.emitError(token.getLoc(), 1514 "integer constant out of range for attribute"); 1515 return nullptr; 1516 } 1517 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1518 if (apInt != val.getValue()) 1519 return (p.emitError("integer constant out of range for type"), nullptr); 1520 intElements.push_back(isNegative ? -apInt : apInt); 1521 } 1522 1523 return DenseElementsAttr::get(type, intElements); 1524 } 1525 1526 /// Build a Dense Float attribute for the given type. 1527 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1528 ShapedType type, 1529 FloatType eltTy) { 1530 std::vector<Attribute> floatValues; 1531 floatValues.reserve(storage.size()); 1532 for (const auto &signAndToken : storage) { 1533 bool isNegative = signAndToken.first; 1534 const Token &token = signAndToken.second; 1535 1536 // Handle hexadecimal float literals. 1537 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1538 if (isNegative) { 1539 p.emitError(token.getLoc()) 1540 << "hexadecimal float literal should not have a leading minus"; 1541 return nullptr; 1542 } 1543 auto val = token.getUInt64IntegerValue(); 1544 if (!val.hasValue()) { 1545 p.emitError("hexadecimal float constant out of range for attribute"); 1546 return nullptr; 1547 } 1548 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1549 if (!attr) 1550 return nullptr; 1551 floatValues.push_back(attr); 1552 continue; 1553 } 1554 1555 // Check to see if any decimal integers or booleans were parsed. 1556 if (!token.is(Token::floatliteral)) { 1557 p.emitError() << "expected floating-point elements, but parsed integer"; 1558 return nullptr; 1559 } 1560 1561 // Build the float values from tokens. 1562 auto val = token.getFloatingPointValue(); 1563 if (!val.hasValue()) { 1564 p.emitError("floating point value too large for attribute"); 1565 return nullptr; 1566 } 1567 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 1568 } 1569 1570 return DenseElementsAttr::get(type, floatValues); 1571 } 1572 1573 ParseResult TensorLiteralParser::parseElement() { 1574 switch (p.getToken().getKind()) { 1575 // Parse a boolean element. 1576 case Token::kw_true: 1577 case Token::kw_false: 1578 case Token::floatliteral: 1579 case Token::integer: 1580 storage.emplace_back(/*isNegative=*/false, p.getToken()); 1581 p.consumeToken(); 1582 break; 1583 1584 // Parse a signed integer or a negative floating-point element. 1585 case Token::minus: 1586 p.consumeToken(Token::minus); 1587 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 1588 return p.emitError("expected integer or floating point literal"); 1589 storage.emplace_back(/*isNegative=*/true, p.getToken()); 1590 p.consumeToken(); 1591 break; 1592 1593 default: 1594 return p.emitError("expected element literal of primitive type"); 1595 } 1596 1597 return success(); 1598 } 1599 1600 /// Parse a list of either lists or elements, returning the dimensions of the 1601 /// parsed sub-tensors in dims. For example: 1602 /// parseList([1, 2, 3]) -> Success, [3] 1603 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1604 /// parseList([[1, 2], 3]) -> Failure 1605 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1606 ParseResult 1607 TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) { 1608 p.consumeToken(Token::l_square); 1609 1610 auto checkDims = 1611 [&](const llvm::SmallVectorImpl<int64_t> &prevDims, 1612 const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult { 1613 if (prevDims == newDims) 1614 return success(); 1615 return p.emitError("tensor literal is invalid; ranks are not consistent " 1616 "between elements"); 1617 }; 1618 1619 bool first = true; 1620 llvm::SmallVector<int64_t, 4> newDims; 1621 unsigned size = 0; 1622 auto parseCommaSeparatedList = [&]() -> ParseResult { 1623 llvm::SmallVector<int64_t, 4> thisDims; 1624 if (p.getToken().getKind() == Token::l_square) { 1625 if (parseList(thisDims)) 1626 return failure(); 1627 } else if (parseElement()) { 1628 return failure(); 1629 } 1630 ++size; 1631 if (!first) 1632 return checkDims(newDims, thisDims); 1633 newDims = thisDims; 1634 first = false; 1635 return success(); 1636 }; 1637 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 1638 return failure(); 1639 1640 // Return the sublists' dimensions with 'size' prepended. 1641 dims.clear(); 1642 dims.push_back(size); 1643 dims.append(newDims.begin(), newDims.end()); 1644 return success(); 1645 } 1646 1647 /// Parse a dense elements attribute. 1648 Attribute Parser::parseDenseElementsAttr() { 1649 consumeToken(Token::kw_dense); 1650 if (parseToken(Token::less, "expected '<' after 'dense'")) 1651 return nullptr; 1652 1653 // Parse the literal data. 1654 TensorLiteralParser literalParser(*this); 1655 if (literalParser.parse()) 1656 return nullptr; 1657 1658 if (parseToken(Token::greater, "expected '>'") || 1659 parseToken(Token::colon, "expected ':'")) 1660 return nullptr; 1661 1662 auto typeLoc = getToken().getLoc(); 1663 auto type = parseElementsLiteralType(); 1664 if (!type) 1665 return nullptr; 1666 return literalParser.getAttr(typeLoc, type); 1667 } 1668 1669 /// Shaped type for elements attribute. 1670 /// 1671 /// elements-literal-type ::= vector-type | ranked-tensor-type 1672 /// 1673 /// This method also checks the type has static shape. 1674 ShapedType Parser::parseElementsLiteralType() { 1675 auto type = parseType(); 1676 if (!type) 1677 return nullptr; 1678 1679 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 1680 emitError("elements literal must be a ranked tensor or vector type"); 1681 return nullptr; 1682 } 1683 1684 auto sType = type.cast<ShapedType>(); 1685 if (!sType.hasStaticShape()) 1686 return (emitError("elements literal type must have static shape"), nullptr); 1687 1688 return sType; 1689 } 1690 1691 /// Parse a sparse elements attribute. 1692 Attribute Parser::parseSparseElementsAttr() { 1693 consumeToken(Token::kw_sparse); 1694 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 1695 return nullptr; 1696 1697 /// Parse indices 1698 auto indicesLoc = getToken().getLoc(); 1699 TensorLiteralParser indiceParser(*this); 1700 if (indiceParser.parse()) 1701 return nullptr; 1702 1703 if (parseToken(Token::comma, "expected ','")) 1704 return nullptr; 1705 1706 /// Parse values. 1707 auto valuesLoc = getToken().getLoc(); 1708 TensorLiteralParser valuesParser(*this); 1709 if (valuesParser.parse()) 1710 return nullptr; 1711 1712 if (parseToken(Token::greater, "expected '>'") || 1713 parseToken(Token::colon, "expected ':'")) 1714 return nullptr; 1715 1716 auto type = parseElementsLiteralType(); 1717 if (!type) 1718 return nullptr; 1719 1720 // If the indices are a splat, i.e. the literal parser parsed an element and 1721 // not a list, we set the shape explicitly. The indices are represented by a 1722 // 2-dimensional shape where the second dimension is the rank of the type. 1723 // Given that the parsed indices is a splat, we know that we only have one 1724 // indice and thus one for the first dimension. 1725 auto indiceEltType = builder.getIntegerType(64); 1726 ShapedType indicesType; 1727 if (indiceParser.getShape().empty()) { 1728 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1729 } else { 1730 // Otherwise, set the shape to the one parsed by the literal parser. 1731 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1732 } 1733 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1734 1735 // If the values are a splat, set the shape explicitly based on the number of 1736 // indices. The number of indices is encoded in the first dimension of the 1737 // indice shape type. 1738 auto valuesEltType = type.getElementType(); 1739 ShapedType valuesType = 1740 valuesParser.getShape().empty() 1741 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1742 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1743 auto values = valuesParser.getAttr(valuesLoc, valuesType); 1744 1745 /// Sanity check. 1746 if (valuesType.getRank() != 1) 1747 return (emitError("expected 1-d tensor for values"), nullptr); 1748 1749 auto sameShape = (indicesType.getRank() == 1) || 1750 (type.getRank() == indicesType.getDimSize(1)); 1751 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 1752 if (!sameShape || !sameElementNum) { 1753 emitError() << "expected shape ([" << type.getShape() 1754 << "]); inferred shape of indices literal ([" 1755 << indicesType.getShape() 1756 << "]); inferred shape of values literal ([" 1757 << valuesType.getShape() << "])"; 1758 return nullptr; 1759 } 1760 1761 // Build the sparse elements attribute by the indices and values. 1762 return SparseElementsAttr::get(type, indices, values); 1763 } 1764 1765 //===----------------------------------------------------------------------===// 1766 // Location parsing. 1767 //===----------------------------------------------------------------------===// 1768 1769 /// Parse a location. 1770 /// 1771 /// location ::= `loc` inline-location 1772 /// inline-location ::= '(' location-inst ')' 1773 /// 1774 ParseResult Parser::parseLocation(LocationAttr &loc) { 1775 // Check for 'loc' identifier. 1776 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 1777 return emitError(); 1778 1779 // Parse the inline-location. 1780 if (parseToken(Token::l_paren, "expected '(' in inline location") || 1781 parseLocationInstance(loc) || 1782 parseToken(Token::r_paren, "expected ')' in inline location")) 1783 return failure(); 1784 return success(); 1785 } 1786 1787 /// Specific location instances. 1788 /// 1789 /// location-inst ::= filelinecol-location | 1790 /// name-location | 1791 /// callsite-location | 1792 /// fused-location | 1793 /// unknown-location 1794 /// filelinecol-location ::= string-literal ':' integer-literal 1795 /// ':' integer-literal 1796 /// name-location ::= string-literal 1797 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 1798 /// fused-location ::= fused ('<' attribute-value '>')? 1799 /// '[' location-inst (location-inst ',')* ']' 1800 /// unknown-location ::= 'unknown' 1801 /// 1802 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 1803 consumeToken(Token::bare_identifier); 1804 1805 // Parse the '('. 1806 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 1807 return failure(); 1808 1809 // Parse the callee location. 1810 LocationAttr calleeLoc; 1811 if (parseLocationInstance(calleeLoc)) 1812 return failure(); 1813 1814 // Parse the 'at'. 1815 if (getToken().isNot(Token::bare_identifier) || 1816 getToken().getSpelling() != "at") 1817 return emitError("expected 'at' in callsite location"); 1818 consumeToken(Token::bare_identifier); 1819 1820 // Parse the caller location. 1821 LocationAttr callerLoc; 1822 if (parseLocationInstance(callerLoc)) 1823 return failure(); 1824 1825 // Parse the ')'. 1826 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 1827 return failure(); 1828 1829 // Return the callsite location. 1830 loc = CallSiteLoc::get(calleeLoc, callerLoc); 1831 return success(); 1832 } 1833 1834 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 1835 consumeToken(Token::bare_identifier); 1836 1837 // Try to parse the optional metadata. 1838 Attribute metadata; 1839 if (consumeIf(Token::less)) { 1840 metadata = parseAttribute(); 1841 if (!metadata) 1842 return emitError("expected valid attribute metadata"); 1843 // Parse the '>' token. 1844 if (parseToken(Token::greater, 1845 "expected '>' after fused location metadata")) 1846 return failure(); 1847 } 1848 1849 llvm::SmallVector<Location, 4> locations; 1850 auto parseElt = [&] { 1851 LocationAttr newLoc; 1852 if (parseLocationInstance(newLoc)) 1853 return failure(); 1854 locations.push_back(newLoc); 1855 return success(); 1856 }; 1857 1858 if (parseToken(Token::l_square, "expected '[' in fused location") || 1859 parseCommaSeparatedList(parseElt) || 1860 parseToken(Token::r_square, "expected ']' in fused location")) 1861 return failure(); 1862 1863 // Return the fused location. 1864 loc = FusedLoc::get(locations, metadata, getContext()); 1865 return success(); 1866 } 1867 1868 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 1869 auto *ctx = getContext(); 1870 auto str = getToken().getStringValue(); 1871 consumeToken(Token::string); 1872 1873 // If the next token is ':' this is a filelinecol location. 1874 if (consumeIf(Token::colon)) { 1875 // Parse the line number. 1876 if (getToken().isNot(Token::integer)) 1877 return emitError("expected integer line number in FileLineColLoc"); 1878 auto line = getToken().getUnsignedIntegerValue(); 1879 if (!line.hasValue()) 1880 return emitError("expected integer line number in FileLineColLoc"); 1881 consumeToken(Token::integer); 1882 1883 // Parse the ':'. 1884 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 1885 return failure(); 1886 1887 // Parse the column number. 1888 if (getToken().isNot(Token::integer)) 1889 return emitError("expected integer column number in FileLineColLoc"); 1890 auto column = getToken().getUnsignedIntegerValue(); 1891 if (!column.hasValue()) 1892 return emitError("expected integer column number in FileLineColLoc"); 1893 consumeToken(Token::integer); 1894 1895 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 1896 return success(); 1897 } 1898 1899 // Otherwise, this is a NameLoc. 1900 1901 // Check for a child location. 1902 if (consumeIf(Token::l_paren)) { 1903 auto childSourceLoc = getToken().getLoc(); 1904 1905 // Parse the child location. 1906 LocationAttr childLoc; 1907 if (parseLocationInstance(childLoc)) 1908 return failure(); 1909 1910 // The child must not be another NameLoc. 1911 if (childLoc.isa<NameLoc>()) 1912 return emitError(childSourceLoc, 1913 "child of NameLoc cannot be another NameLoc"); 1914 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 1915 1916 // Parse the closing ')'. 1917 if (parseToken(Token::r_paren, 1918 "expected ')' after child location of NameLoc")) 1919 return failure(); 1920 } else { 1921 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 1922 } 1923 1924 return success(); 1925 } 1926 1927 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 1928 // Handle either name or filelinecol locations. 1929 if (getToken().is(Token::string)) 1930 return parseNameOrFileLineColLocation(loc); 1931 1932 // Bare tokens required for other cases. 1933 if (!getToken().is(Token::bare_identifier)) 1934 return emitError("expected location instance"); 1935 1936 // Check for the 'callsite' signifying a callsite location. 1937 if (getToken().getSpelling() == "callsite") 1938 return parseCallSiteLocation(loc); 1939 1940 // If the token is 'fused', then this is a fused location. 1941 if (getToken().getSpelling() == "fused") 1942 return parseFusedLocation(loc); 1943 1944 // Check for a 'unknown' for an unknown location. 1945 if (getToken().getSpelling() == "unknown") { 1946 consumeToken(Token::bare_identifier); 1947 loc = UnknownLoc::get(getContext()); 1948 return success(); 1949 } 1950 1951 return emitError("expected location instance"); 1952 } 1953 1954 //===----------------------------------------------------------------------===// 1955 // Affine parsing. 1956 //===----------------------------------------------------------------------===// 1957 1958 /// Lower precedence ops (all at the same precedence level). LNoOp is false in 1959 /// the boolean sense. 1960 enum AffineLowPrecOp { 1961 /// Null value. 1962 LNoOp, 1963 Add, 1964 Sub 1965 }; 1966 1967 /// Higher precedence ops - all at the same precedence level. HNoOp is false 1968 /// in the boolean sense. 1969 enum AffineHighPrecOp { 1970 /// Null value. 1971 HNoOp, 1972 Mul, 1973 FloorDiv, 1974 CeilDiv, 1975 Mod 1976 }; 1977 1978 namespace { 1979 /// This is a specialized parser for affine structures (affine maps, affine 1980 /// expressions, and integer sets), maintaining the state transient to their 1981 /// bodies. 1982 class AffineParser : public Parser { 1983 public: 1984 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 1985 llvm::function_ref<ParseResult(bool)> parseElement = nullptr) 1986 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 1987 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 1988 1989 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 1990 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 1991 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 1992 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 1993 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 1994 unsigned &numDims); 1995 1996 private: 1997 // Binary affine op parsing. 1998 AffineLowPrecOp consumeIfLowPrecOp(); 1999 AffineHighPrecOp consumeIfHighPrecOp(); 2000 2001 // Identifier lists for polyhedral structures. 2002 ParseResult parseDimIdList(unsigned &numDims); 2003 ParseResult parseSymbolIdList(unsigned &numSymbols); 2004 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 2005 unsigned &numSymbols); 2006 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 2007 2008 AffineExpr parseAffineExpr(); 2009 AffineExpr parseParentheticalExpr(); 2010 AffineExpr parseNegateExpression(AffineExpr lhs); 2011 AffineExpr parseIntegerExpr(); 2012 AffineExpr parseBareIdExpr(); 2013 AffineExpr parseSSAIdExpr(bool isSymbol); 2014 AffineExpr parseSymbolSSAIdExpr(); 2015 2016 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 2017 AffineExpr rhs, SMLoc opLoc); 2018 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 2019 AffineExpr rhs); 2020 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 2021 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 2022 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 2023 SMLoc llhsOpLoc); 2024 AffineExpr parseAffineConstraint(bool *isEq); 2025 2026 private: 2027 bool allowParsingSSAIds; 2028 llvm::function_ref<ParseResult(bool)> parseElement; 2029 unsigned numDimOperands; 2030 unsigned numSymbolOperands; 2031 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 2032 }; 2033 } // end anonymous namespace 2034 2035 /// Create an affine binary high precedence op expression (mul's, div's, mod). 2036 /// opLoc is the location of the op token to be used to report errors 2037 /// for non-conforming expressions. 2038 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 2039 AffineExpr lhs, AffineExpr rhs, 2040 SMLoc opLoc) { 2041 // TODO: make the error location info accurate. 2042 switch (op) { 2043 case Mul: 2044 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 2045 emitError(opLoc, "non-affine expression: at least one of the multiply " 2046 "operands has to be either a constant or symbolic"); 2047 return nullptr; 2048 } 2049 return lhs * rhs; 2050 case FloorDiv: 2051 if (!rhs.isSymbolicOrConstant()) { 2052 emitError(opLoc, "non-affine expression: right operand of floordiv " 2053 "has to be either a constant or symbolic"); 2054 return nullptr; 2055 } 2056 return lhs.floorDiv(rhs); 2057 case CeilDiv: 2058 if (!rhs.isSymbolicOrConstant()) { 2059 emitError(opLoc, "non-affine expression: right operand of ceildiv " 2060 "has to be either a constant or symbolic"); 2061 return nullptr; 2062 } 2063 return lhs.ceilDiv(rhs); 2064 case Mod: 2065 if (!rhs.isSymbolicOrConstant()) { 2066 emitError(opLoc, "non-affine expression: right operand of mod " 2067 "has to be either a constant or symbolic"); 2068 return nullptr; 2069 } 2070 return lhs % rhs; 2071 case HNoOp: 2072 llvm_unreachable("can't create affine expression for null high prec op"); 2073 return nullptr; 2074 } 2075 llvm_unreachable("Unknown AffineHighPrecOp"); 2076 } 2077 2078 /// Create an affine binary low precedence op expression (add, sub). 2079 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 2080 AffineExpr lhs, AffineExpr rhs) { 2081 switch (op) { 2082 case AffineLowPrecOp::Add: 2083 return lhs + rhs; 2084 case AffineLowPrecOp::Sub: 2085 return lhs - rhs; 2086 case AffineLowPrecOp::LNoOp: 2087 llvm_unreachable("can't create affine expression for null low prec op"); 2088 return nullptr; 2089 } 2090 llvm_unreachable("Unknown AffineLowPrecOp"); 2091 } 2092 2093 /// Consume this token if it is a lower precedence affine op (there are only 2094 /// two precedence levels). 2095 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2096 switch (getToken().getKind()) { 2097 case Token::plus: 2098 consumeToken(Token::plus); 2099 return AffineLowPrecOp::Add; 2100 case Token::minus: 2101 consumeToken(Token::minus); 2102 return AffineLowPrecOp::Sub; 2103 default: 2104 return AffineLowPrecOp::LNoOp; 2105 } 2106 } 2107 2108 /// Consume this token if it is a higher precedence affine op (there are only 2109 /// two precedence levels) 2110 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2111 switch (getToken().getKind()) { 2112 case Token::star: 2113 consumeToken(Token::star); 2114 return Mul; 2115 case Token::kw_floordiv: 2116 consumeToken(Token::kw_floordiv); 2117 return FloorDiv; 2118 case Token::kw_ceildiv: 2119 consumeToken(Token::kw_ceildiv); 2120 return CeilDiv; 2121 case Token::kw_mod: 2122 consumeToken(Token::kw_mod); 2123 return Mod; 2124 default: 2125 return HNoOp; 2126 } 2127 } 2128 2129 /// Parse a high precedence op expression list: mul, div, and mod are high 2130 /// precedence binary ops, i.e., parse a 2131 /// expr_1 op_1 expr_2 op_2 ... expr_n 2132 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2133 /// All affine binary ops are left associative. 2134 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2135 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2136 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2137 /// report an error for non-conforming expressions. 2138 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2139 AffineHighPrecOp llhsOp, 2140 SMLoc llhsOpLoc) { 2141 AffineExpr lhs = parseAffineOperandExpr(llhs); 2142 if (!lhs) 2143 return nullptr; 2144 2145 // Found an LHS. Parse the remaining expression. 2146 auto opLoc = getToken().getLoc(); 2147 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2148 if (llhs) { 2149 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2150 if (!expr) 2151 return nullptr; 2152 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2153 } 2154 // No LLHS, get RHS 2155 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2156 } 2157 2158 // This is the last operand in this expression. 2159 if (llhs) 2160 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2161 2162 // No llhs, 'lhs' itself is the expression. 2163 return lhs; 2164 } 2165 2166 /// Parse an affine expression inside parentheses. 2167 /// 2168 /// affine-expr ::= `(` affine-expr `)` 2169 AffineExpr AffineParser::parseParentheticalExpr() { 2170 if (parseToken(Token::l_paren, "expected '('")) 2171 return nullptr; 2172 if (getToken().is(Token::r_paren)) 2173 return (emitError("no expression inside parentheses"), nullptr); 2174 2175 auto expr = parseAffineExpr(); 2176 if (!expr) 2177 return nullptr; 2178 if (parseToken(Token::r_paren, "expected ')'")) 2179 return nullptr; 2180 2181 return expr; 2182 } 2183 2184 /// Parse the negation expression. 2185 /// 2186 /// affine-expr ::= `-` affine-expr 2187 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2188 if (parseToken(Token::minus, "expected '-'")) 2189 return nullptr; 2190 2191 AffineExpr operand = parseAffineOperandExpr(lhs); 2192 // Since negation has the highest precedence of all ops (including high 2193 // precedence ops) but lower than parentheses, we are only going to use 2194 // parseAffineOperandExpr instead of parseAffineExpr here. 2195 if (!operand) 2196 // Extra error message although parseAffineOperandExpr would have 2197 // complained. Leads to a better diagnostic. 2198 return (emitError("missing operand of negation"), nullptr); 2199 return (-1) * operand; 2200 } 2201 2202 /// Parse a bare id that may appear in an affine expression. 2203 /// 2204 /// affine-expr ::= bare-id 2205 AffineExpr AffineParser::parseBareIdExpr() { 2206 if (getToken().isNot(Token::bare_identifier)) 2207 return (emitError("expected bare identifier"), nullptr); 2208 2209 StringRef sRef = getTokenSpelling(); 2210 for (auto entry : dimsAndSymbols) { 2211 if (entry.first == sRef) { 2212 consumeToken(Token::bare_identifier); 2213 return entry.second; 2214 } 2215 } 2216 2217 return (emitError("use of undeclared identifier"), nullptr); 2218 } 2219 2220 /// Parse an SSA id which may appear in an affine expression. 2221 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2222 if (!allowParsingSSAIds) 2223 return (emitError("unexpected ssa identifier"), nullptr); 2224 if (getToken().isNot(Token::percent_identifier)) 2225 return (emitError("expected ssa identifier"), nullptr); 2226 auto name = getTokenSpelling(); 2227 // Check if we already parsed this SSA id. 2228 for (auto entry : dimsAndSymbols) { 2229 if (entry.first == name) { 2230 consumeToken(Token::percent_identifier); 2231 return entry.second; 2232 } 2233 } 2234 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2235 if (parseElement(isSymbol)) 2236 return (emitError("failed to parse ssa identifier"), nullptr); 2237 auto idExpr = isSymbol 2238 ? getAffineSymbolExpr(numSymbolOperands++, getContext()) 2239 : getAffineDimExpr(numDimOperands++, getContext()); 2240 dimsAndSymbols.push_back({name, idExpr}); 2241 return idExpr; 2242 } 2243 2244 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2245 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2246 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2247 return nullptr; 2248 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2249 if (!symbolExpr) 2250 return nullptr; 2251 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2252 return nullptr; 2253 return symbolExpr; 2254 } 2255 2256 /// Parse a positive integral constant appearing in an affine expression. 2257 /// 2258 /// affine-expr ::= integer-literal 2259 AffineExpr AffineParser::parseIntegerExpr() { 2260 auto val = getToken().getUInt64IntegerValue(); 2261 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2262 return (emitError("constant too large for index"), nullptr); 2263 2264 consumeToken(Token::integer); 2265 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2266 } 2267 2268 /// Parses an expression that can be a valid operand of an affine expression. 2269 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2270 /// operator, the rhs of which is being parsed. This is used to determine 2271 /// whether an error should be emitted for a missing right operand. 2272 // Eg: for an expression without parentheses (like i + j + k + l), each 2273 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2274 // operand expression, it's an op expression and will be parsed via 2275 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2276 // -l are valid operands that will be parsed by this function. 2277 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2278 switch (getToken().getKind()) { 2279 case Token::bare_identifier: 2280 return parseBareIdExpr(); 2281 case Token::kw_symbol: 2282 return parseSymbolSSAIdExpr(); 2283 case Token::percent_identifier: 2284 return parseSSAIdExpr(/*isSymbol=*/false); 2285 case Token::integer: 2286 return parseIntegerExpr(); 2287 case Token::l_paren: 2288 return parseParentheticalExpr(); 2289 case Token::minus: 2290 return parseNegateExpression(lhs); 2291 case Token::kw_ceildiv: 2292 case Token::kw_floordiv: 2293 case Token::kw_mod: 2294 case Token::plus: 2295 case Token::star: 2296 if (lhs) 2297 emitError("missing right operand of binary operator"); 2298 else 2299 emitError("missing left operand of binary operator"); 2300 return nullptr; 2301 default: 2302 if (lhs) 2303 emitError("missing right operand of binary operator"); 2304 else 2305 emitError("expected affine expression"); 2306 return nullptr; 2307 } 2308 } 2309 2310 /// Parse affine expressions that are bare-id's, integer constants, 2311 /// parenthetical affine expressions, and affine op expressions that are a 2312 /// composition of those. 2313 /// 2314 /// All binary op's associate from left to right. 2315 /// 2316 /// {add, sub} have lower precedence than {mul, div, and mod}. 2317 /// 2318 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2319 /// ceildiv, and mod are at the same higher precedence level. Negation has 2320 /// higher precedence than any binary op. 2321 /// 2322 /// llhs: the affine expression appearing on the left of the one being parsed. 2323 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2324 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2325 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2326 /// associativity. 2327 /// 2328 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2329 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2330 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2331 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2332 AffineLowPrecOp llhsOp) { 2333 AffineExpr lhs; 2334 if (!(lhs = parseAffineOperandExpr(llhs))) 2335 return nullptr; 2336 2337 // Found an LHS. Deal with the ops. 2338 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2339 if (llhs) { 2340 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2341 return parseAffineLowPrecOpExpr(sum, lOp); 2342 } 2343 // No LLHS, get RHS and form the expression. 2344 return parseAffineLowPrecOpExpr(lhs, lOp); 2345 } 2346 auto opLoc = getToken().getLoc(); 2347 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2348 // We have a higher precedence op here. Get the rhs operand for the llhs 2349 // through parseAffineHighPrecOpExpr. 2350 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2351 if (!highRes) 2352 return nullptr; 2353 2354 // If llhs is null, the product forms the first operand of the yet to be 2355 // found expression. If non-null, the op to associate with llhs is llhsOp. 2356 AffineExpr expr = 2357 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2358 2359 // Recurse for subsequent low prec op's after the affine high prec op 2360 // expression. 2361 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2362 return parseAffineLowPrecOpExpr(expr, nextOp); 2363 return expr; 2364 } 2365 // Last operand in the expression list. 2366 if (llhs) 2367 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2368 // No llhs, 'lhs' itself is the expression. 2369 return lhs; 2370 } 2371 2372 /// Parse an affine expression. 2373 /// affine-expr ::= `(` affine-expr `)` 2374 /// | `-` affine-expr 2375 /// | affine-expr `+` affine-expr 2376 /// | affine-expr `-` affine-expr 2377 /// | affine-expr `*` affine-expr 2378 /// | affine-expr `floordiv` affine-expr 2379 /// | affine-expr `ceildiv` affine-expr 2380 /// | affine-expr `mod` affine-expr 2381 /// | bare-id 2382 /// | integer-literal 2383 /// 2384 /// Additional conditions are checked depending on the production. For eg., 2385 /// one of the operands for `*` has to be either constant/symbolic; the second 2386 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2387 AffineExpr AffineParser::parseAffineExpr() { 2388 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2389 } 2390 2391 /// Parse a dim or symbol from the lists appearing before the actual 2392 /// expressions of the affine map. Update our state to store the 2393 /// dimensional/symbolic identifier. 2394 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2395 if (getToken().isNot(Token::bare_identifier)) 2396 return emitError("expected bare identifier"); 2397 2398 auto name = getTokenSpelling(); 2399 for (auto entry : dimsAndSymbols) { 2400 if (entry.first == name) 2401 return emitError("redefinition of identifier '" + name + "'"); 2402 } 2403 consumeToken(Token::bare_identifier); 2404 2405 dimsAndSymbols.push_back({name, idExpr}); 2406 return success(); 2407 } 2408 2409 /// Parse the list of dimensional identifiers to an affine map. 2410 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2411 if (parseToken(Token::l_paren, 2412 "expected '(' at start of dimensional identifiers list")) { 2413 return failure(); 2414 } 2415 2416 auto parseElt = [&]() -> ParseResult { 2417 auto dimension = getAffineDimExpr(numDims++, getContext()); 2418 return parseIdentifierDefinition(dimension); 2419 }; 2420 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2421 } 2422 2423 /// Parse the list of symbolic identifiers to an affine map. 2424 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2425 consumeToken(Token::l_square); 2426 auto parseElt = [&]() -> ParseResult { 2427 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2428 return parseIdentifierDefinition(symbol); 2429 }; 2430 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2431 } 2432 2433 /// Parse the list of symbolic identifiers to an affine map. 2434 ParseResult 2435 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2436 unsigned &numSymbols) { 2437 if (parseDimIdList(numDims)) { 2438 return failure(); 2439 } 2440 if (!getToken().is(Token::l_square)) { 2441 numSymbols = 0; 2442 return success(); 2443 } 2444 return parseSymbolIdList(numSymbols); 2445 } 2446 2447 /// Parses an ambiguous affine map or integer set definition inline. 2448 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2449 IntegerSet &set) { 2450 unsigned numDims = 0, numSymbols = 0; 2451 2452 // List of dimensional and optional symbol identifiers. 2453 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2454 return failure(); 2455 } 2456 2457 // This is needed for parsing attributes as we wouldn't know whether we would 2458 // be parsing an integer set attribute or an affine map attribute. 2459 bool isArrow = getToken().is(Token::arrow); 2460 bool isColon = getToken().is(Token::colon); 2461 if (!isArrow && !isColon) { 2462 return emitError("expected '->' or ':'"); 2463 } else if (isArrow) { 2464 parseToken(Token::arrow, "expected '->' or '['"); 2465 map = parseAffineMapRange(numDims, numSymbols); 2466 return map ? success() : failure(); 2467 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2468 return failure(); 2469 } 2470 2471 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2472 return success(); 2473 2474 return failure(); 2475 } 2476 2477 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2478 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2479 if (parseToken(Token::l_square, "expected '['")) 2480 return failure(); 2481 2482 SmallVector<AffineExpr, 4> exprs; 2483 auto parseElt = [&]() -> ParseResult { 2484 auto elt = parseAffineExpr(); 2485 exprs.push_back(elt); 2486 return elt ? success() : failure(); 2487 }; 2488 2489 // Parse a multi-dimensional affine expression (a comma-separated list of 2490 // 1-d affine expressions); the list cannot be empty. Grammar: 2491 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2492 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2493 /*allowEmptyList=*/true)) 2494 return failure(); 2495 // Parsed a valid affine map. 2496 if (exprs.empty()) 2497 map = AffineMap(); 2498 else 2499 map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands, 2500 exprs); 2501 return success(); 2502 } 2503 2504 /// Parse the range and sizes affine map definition inline. 2505 /// 2506 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2507 /// 2508 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2509 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2510 unsigned numSymbols) { 2511 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2512 2513 SmallVector<AffineExpr, 4> exprs; 2514 auto parseElt = [&]() -> ParseResult { 2515 auto elt = parseAffineExpr(); 2516 ParseResult res = elt ? success() : failure(); 2517 exprs.push_back(elt); 2518 return res; 2519 }; 2520 2521 // Parse a multi-dimensional affine expression (a comma-separated list of 2522 // 1-d affine expressions); the list cannot be empty. Grammar: 2523 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2524 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) 2525 return AffineMap(); 2526 2527 // Parsed a valid affine map. 2528 return AffineMap::get(numDims, numSymbols, exprs); 2529 } 2530 2531 /// Parse an affine constraint. 2532 /// affine-constraint ::= affine-expr `>=` `0` 2533 /// | affine-expr `==` `0` 2534 /// 2535 /// isEq is set to true if the parsed constraint is an equality, false if it 2536 /// is an inequality (greater than or equal). 2537 /// 2538 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2539 AffineExpr expr = parseAffineExpr(); 2540 if (!expr) 2541 return nullptr; 2542 2543 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2544 getToken().is(Token::integer)) { 2545 auto dim = getToken().getUnsignedIntegerValue(); 2546 if (dim.hasValue() && dim.getValue() == 0) { 2547 consumeToken(Token::integer); 2548 *isEq = false; 2549 return expr; 2550 } 2551 return (emitError("expected '0' after '>='"), nullptr); 2552 } 2553 2554 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 2555 getToken().is(Token::integer)) { 2556 auto dim = getToken().getUnsignedIntegerValue(); 2557 if (dim.hasValue() && dim.getValue() == 0) { 2558 consumeToken(Token::integer); 2559 *isEq = true; 2560 return expr; 2561 } 2562 return (emitError("expected '0' after '=='"), nullptr); 2563 } 2564 2565 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 2566 nullptr); 2567 } 2568 2569 /// Parse the constraints that are part of an integer set definition. 2570 /// integer-set-inline 2571 /// ::= dim-and-symbol-id-lists `:` 2572 /// '(' affine-constraint-conjunction? ')' 2573 /// affine-constraint-conjunction ::= affine-constraint (`,` 2574 /// affine-constraint)* 2575 /// 2576 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 2577 unsigned numSymbols) { 2578 if (parseToken(Token::l_paren, 2579 "expected '(' at start of integer set constraint list")) 2580 return IntegerSet(); 2581 2582 SmallVector<AffineExpr, 4> constraints; 2583 SmallVector<bool, 4> isEqs; 2584 auto parseElt = [&]() -> ParseResult { 2585 bool isEq; 2586 auto elt = parseAffineConstraint(&isEq); 2587 ParseResult res = elt ? success() : failure(); 2588 if (elt) { 2589 constraints.push_back(elt); 2590 isEqs.push_back(isEq); 2591 } 2592 return res; 2593 }; 2594 2595 // Parse a list of affine constraints (comma-separated). 2596 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2597 return IntegerSet(); 2598 2599 // If no constraints were parsed, then treat this as a degenerate 'true' case. 2600 if (constraints.empty()) { 2601 /* 0 == 0 */ 2602 auto zero = getAffineConstantExpr(0, getContext()); 2603 return IntegerSet::get(numDims, numSymbols, zero, true); 2604 } 2605 2606 // Parsed a valid integer set. 2607 return IntegerSet::get(numDims, numSymbols, constraints, isEqs); 2608 } 2609 2610 /// Parse an ambiguous reference to either and affine map or an integer set. 2611 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 2612 IntegerSet &set) { 2613 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 2614 } 2615 2616 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 2617 /// parse SSA value uses encountered while parsing affine expressions. 2618 ParseResult Parser::parseAffineMapOfSSAIds( 2619 AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) { 2620 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 2621 .parseAffineMapOfSSAIds(map); 2622 } 2623 2624 //===----------------------------------------------------------------------===// 2625 // OperationParser 2626 //===----------------------------------------------------------------------===// 2627 2628 namespace { 2629 /// This class provides support for parsing operations and regions of 2630 /// operations. 2631 class OperationParser : public Parser { 2632 public: 2633 OperationParser(ParserState &state, ModuleOp moduleOp) 2634 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 2635 } 2636 2637 ~OperationParser(); 2638 2639 /// After parsing is finished, this function must be called to see if there 2640 /// are any remaining issues. 2641 ParseResult finalize(); 2642 2643 //===--------------------------------------------------------------------===// 2644 // SSA Value Handling 2645 //===--------------------------------------------------------------------===// 2646 2647 /// This represents a use of an SSA value in the program. The first two 2648 /// entries in the tuple are the name and result number of a reference. The 2649 /// third is the location of the reference, which is used in case this ends 2650 /// up being a use of an undefined value. 2651 struct SSAUseInfo { 2652 StringRef name; // Value name, e.g. %42 or %abc 2653 unsigned number; // Number, specified with #12 2654 SMLoc loc; // Location of first definition or use. 2655 }; 2656 2657 /// Push a new SSA name scope to the parser. 2658 void pushSSANameScope(bool isIsolated); 2659 2660 /// Pop the last SSA name scope from the parser. 2661 ParseResult popSSANameScope(); 2662 2663 /// Register a definition of a value with the symbol table. 2664 ParseResult addDefinition(SSAUseInfo useInfo, Value *value); 2665 2666 /// Parse an optional list of SSA uses into 'results'. 2667 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 2668 2669 /// Parse a single SSA use into 'result'. 2670 ParseResult parseSSAUse(SSAUseInfo &result); 2671 2672 /// Given a reference to an SSA value and its type, return a reference. This 2673 /// returns null on failure. 2674 Value *resolveSSAUse(SSAUseInfo useInfo, Type type); 2675 2676 ParseResult parseSSADefOrUseAndType( 2677 const std::function<ParseResult(SSAUseInfo, Type)> &action); 2678 2679 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results); 2680 2681 /// Return the location of the value identified by its name and number if it 2682 /// has been already reference. 2683 llvm::Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 2684 auto &values = isolatedNameScopes.back().values; 2685 if (!values.count(name) || number >= values[name].size()) 2686 return {}; 2687 if (values[name][number].first) 2688 return values[name][number].second; 2689 return {}; 2690 } 2691 2692 //===--------------------------------------------------------------------===// 2693 // Operation Parsing 2694 //===--------------------------------------------------------------------===// 2695 2696 /// Parse an operation instance. 2697 ParseResult parseOperation(); 2698 2699 /// Parse a single operation successor and its operand list. 2700 ParseResult parseSuccessorAndUseList(Block *&dest, 2701 SmallVectorImpl<Value *> &operands); 2702 2703 /// Parse a comma-separated list of operation successors in brackets. 2704 ParseResult 2705 parseSuccessors(SmallVectorImpl<Block *> &destinations, 2706 SmallVectorImpl<SmallVector<Value *, 4>> &operands); 2707 2708 /// Parse an operation instance that is in the generic form. 2709 Operation *parseGenericOperation(); 2710 2711 /// Parse an operation instance that is in the generic form and insert it at 2712 /// the provided insertion point. 2713 Operation *parseGenericOperation(Block *insertBlock, 2714 Block::iterator insertPt); 2715 2716 /// Parse an operation instance that is in the op-defined custom form. 2717 Operation *parseCustomOperation(); 2718 2719 //===--------------------------------------------------------------------===// 2720 // Region Parsing 2721 //===--------------------------------------------------------------------===// 2722 2723 /// Parse a region into 'region' with the provided entry block arguments. 2724 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 2725 /// isolated from those above. 2726 ParseResult parseRegion(Region ®ion, 2727 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 2728 bool isIsolatedNameScope = false); 2729 2730 /// Parse a region body into 'region'. 2731 ParseResult parseRegionBody(Region ®ion); 2732 2733 //===--------------------------------------------------------------------===// 2734 // Block Parsing 2735 //===--------------------------------------------------------------------===// 2736 2737 /// Parse a new block into 'block'. 2738 ParseResult parseBlock(Block *&block); 2739 2740 /// Parse a list of operations into 'block'. 2741 ParseResult parseBlockBody(Block *block); 2742 2743 /// Parse a (possibly empty) list of block arguments. 2744 ParseResult 2745 parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, 2746 Block *owner); 2747 2748 /// Get the block with the specified name, creating it if it doesn't 2749 /// already exist. The location specified is the point of use, which allows 2750 /// us to diagnose references to blocks that are not defined precisely. 2751 Block *getBlockNamed(StringRef name, SMLoc loc); 2752 2753 /// Define the block with the specified name. Returns the Block* or nullptr in 2754 /// the case of redefinition. 2755 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 2756 2757 private: 2758 /// Returns the info for a block at the current scope for the given name. 2759 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 2760 return blocksByName.back()[name]; 2761 } 2762 2763 /// Insert a new forward reference to the given block. 2764 void insertForwardRef(Block *block, SMLoc loc) { 2765 forwardRef.back().try_emplace(block, loc); 2766 } 2767 2768 /// Erase any forward reference to the given block. 2769 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 2770 2771 /// Record that a definition was added at the current scope. 2772 void recordDefinition(StringRef def); 2773 2774 /// Get the value entry for the given SSA name. 2775 SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name); 2776 2777 /// Create a forward reference placeholder value with the given location and 2778 /// result type. 2779 Value *createForwardRefPlaceholder(SMLoc loc, Type type); 2780 2781 /// Return true if this is a forward reference. 2782 bool isForwardRefPlaceholder(Value *value) { 2783 return forwardRefPlaceholders.count(value); 2784 } 2785 2786 /// This struct represents an isolated SSA name scope. This scope may contain 2787 /// other nested non-isolated scopes. These scopes are used for operations 2788 /// that are known to be isolated to allow for reusing names within their 2789 /// regions, even if those names are used above. 2790 struct IsolatedSSANameScope { 2791 /// Record that a definition was added at the current scope. 2792 void recordDefinition(StringRef def) { 2793 definitionsPerScope.back().insert(def); 2794 } 2795 2796 /// Push a nested name scope. 2797 void pushSSANameScope() { definitionsPerScope.push_back({}); } 2798 2799 /// Pop a nested name scope. 2800 void popSSANameScope() { 2801 for (auto &def : definitionsPerScope.pop_back_val()) 2802 values.erase(def.getKey()); 2803 } 2804 2805 /// This keeps track of all of the SSA values we are tracking for each name 2806 /// scope, indexed by their name. This has one entry per result number. 2807 llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values; 2808 2809 /// This keeps track of all of the values defined by a specific name scope. 2810 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 2811 }; 2812 2813 /// A list of isolated name scopes. 2814 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 2815 2816 /// This keeps track of the block names as well as the location of the first 2817 /// reference for each nested name scope. This is used to diagnose invalid 2818 /// block references and memorize them. 2819 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 2820 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 2821 2822 /// These are all of the placeholders we've made along with the location of 2823 /// their first reference, to allow checking for use of undefined values. 2824 DenseMap<Value *, SMLoc> forwardRefPlaceholders; 2825 2826 /// The builder used when creating parsed operation instances. 2827 OpBuilder opBuilder; 2828 2829 /// The top level module operation. 2830 ModuleOp moduleOp; 2831 }; 2832 } // end anonymous namespace 2833 2834 OperationParser::~OperationParser() { 2835 for (auto &fwd : forwardRefPlaceholders) { 2836 // Drop all uses of undefined forward declared reference and destroy 2837 // defining operation. 2838 fwd.first->dropAllUses(); 2839 fwd.first->getDefiningOp()->destroy(); 2840 } 2841 } 2842 2843 /// After parsing is finished, this function must be called to see if there are 2844 /// any remaining issues. 2845 ParseResult OperationParser::finalize() { 2846 // Check for any forward references that are left. If we find any, error 2847 // out. 2848 if (!forwardRefPlaceholders.empty()) { 2849 SmallVector<std::pair<const char *, Value *>, 4> errors; 2850 // Iteration over the map isn't deterministic, so sort by source location. 2851 for (auto entry : forwardRefPlaceholders) 2852 errors.push_back({entry.second.getPointer(), entry.first}); 2853 llvm::array_pod_sort(errors.begin(), errors.end()); 2854 2855 for (auto entry : errors) { 2856 auto loc = SMLoc::getFromPointer(entry.first); 2857 emitError(loc, "use of undeclared SSA value name"); 2858 } 2859 return failure(); 2860 } 2861 2862 return success(); 2863 } 2864 2865 //===----------------------------------------------------------------------===// 2866 // SSA Value Handling 2867 //===----------------------------------------------------------------------===// 2868 2869 void OperationParser::pushSSANameScope(bool isIsolated) { 2870 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 2871 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 2872 2873 // Push back a new name definition scope. 2874 if (isIsolated) 2875 isolatedNameScopes.push_back({}); 2876 isolatedNameScopes.back().pushSSANameScope(); 2877 } 2878 2879 ParseResult OperationParser::popSSANameScope() { 2880 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 2881 2882 // Verify that all referenced blocks were defined. 2883 if (!forwardRefInCurrentScope.empty()) { 2884 SmallVector<std::pair<const char *, Block *>, 4> errors; 2885 // Iteration over the map isn't deterministic, so sort by source location. 2886 for (auto entry : forwardRefInCurrentScope) { 2887 errors.push_back({entry.second.getPointer(), entry.first}); 2888 // Add this block to the top-level region to allow for automatic cleanup. 2889 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 2890 } 2891 llvm::array_pod_sort(errors.begin(), errors.end()); 2892 2893 for (auto entry : errors) { 2894 auto loc = SMLoc::getFromPointer(entry.first); 2895 emitError(loc, "reference to an undefined block"); 2896 } 2897 return failure(); 2898 } 2899 2900 // Pop the next nested namescope. If there is only one internal namescope, 2901 // just pop the isolated scope. 2902 auto ¤tNameScope = isolatedNameScopes.back(); 2903 if (currentNameScope.definitionsPerScope.size() == 1) 2904 isolatedNameScopes.pop_back(); 2905 else 2906 currentNameScope.popSSANameScope(); 2907 2908 blocksByName.pop_back(); 2909 return success(); 2910 } 2911 2912 /// Register a definition of a value with the symbol table. 2913 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { 2914 auto &entries = getSSAValueEntry(useInfo.name); 2915 2916 // Make sure there is a slot for this value. 2917 if (entries.size() <= useInfo.number) 2918 entries.resize(useInfo.number + 1); 2919 2920 // If we already have an entry for this, check to see if it was a definition 2921 // or a forward reference. 2922 if (auto *existing = entries[useInfo.number].first) { 2923 if (!isForwardRefPlaceholder(existing)) { 2924 return emitError(useInfo.loc) 2925 .append("redefinition of SSA value '", useInfo.name, "'") 2926 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 2927 .append("previously defined here"); 2928 } 2929 2930 // If it was a forward reference, update everything that used it to use 2931 // the actual definition instead, delete the forward ref, and remove it 2932 // from our set of forward references we track. 2933 existing->replaceAllUsesWith(value); 2934 existing->getDefiningOp()->destroy(); 2935 forwardRefPlaceholders.erase(existing); 2936 } 2937 2938 /// Record this definition for the current scope. 2939 entries[useInfo.number] = {value, useInfo.loc}; 2940 recordDefinition(useInfo.name); 2941 return success(); 2942 } 2943 2944 /// Parse a (possibly empty) list of SSA operands. 2945 /// 2946 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 2947 /// ssa-use-list-opt ::= ssa-use-list? 2948 /// 2949 ParseResult 2950 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 2951 if (getToken().isNot(Token::percent_identifier)) 2952 return success(); 2953 return parseCommaSeparatedList([&]() -> ParseResult { 2954 SSAUseInfo result; 2955 if (parseSSAUse(result)) 2956 return failure(); 2957 results.push_back(result); 2958 return success(); 2959 }); 2960 } 2961 2962 /// Parse a SSA operand for an operation. 2963 /// 2964 /// ssa-use ::= ssa-id 2965 /// 2966 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 2967 result.name = getTokenSpelling(); 2968 result.number = 0; 2969 result.loc = getToken().getLoc(); 2970 if (parseToken(Token::percent_identifier, "expected SSA operand")) 2971 return failure(); 2972 2973 // If we have an attribute ID, it is a result number. 2974 if (getToken().is(Token::hash_identifier)) { 2975 if (auto value = getToken().getHashIdentifierNumber()) 2976 result.number = value.getValue(); 2977 else 2978 return emitError("invalid SSA value result number"); 2979 consumeToken(Token::hash_identifier); 2980 } 2981 2982 return success(); 2983 } 2984 2985 /// Given an unbound reference to an SSA value and its type, return the value 2986 /// it specifies. This returns null on failure. 2987 Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 2988 auto &entries = getSSAValueEntry(useInfo.name); 2989 2990 // If we have already seen a value of this name, return it. 2991 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 2992 auto *result = entries[useInfo.number].first; 2993 // Check that the type matches the other uses. 2994 if (result->getType() == type) 2995 return result; 2996 2997 emitError(useInfo.loc, "use of value '") 2998 .append(useInfo.name, 2999 "' expects different type than prior uses: ", type, " vs ", 3000 result->getType()) 3001 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 3002 .append("prior use here"); 3003 return nullptr; 3004 } 3005 3006 // Make sure we have enough slots for this. 3007 if (entries.size() <= useInfo.number) 3008 entries.resize(useInfo.number + 1); 3009 3010 // If the value has already been defined and this is an overly large result 3011 // number, diagnose that. 3012 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 3013 return (emitError(useInfo.loc, "reference to invalid result number"), 3014 nullptr); 3015 3016 // Otherwise, this is a forward reference. Create a placeholder and remember 3017 // that we did so. 3018 auto *result = createForwardRefPlaceholder(useInfo.loc, type); 3019 entries[useInfo.number].first = result; 3020 entries[useInfo.number].second = useInfo.loc; 3021 return result; 3022 } 3023 3024 /// Parse an SSA use with an associated type. 3025 /// 3026 /// ssa-use-and-type ::= ssa-use `:` type 3027 ParseResult OperationParser::parseSSADefOrUseAndType( 3028 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 3029 SSAUseInfo useInfo; 3030 if (parseSSAUse(useInfo) || 3031 parseToken(Token::colon, "expected ':' and type for SSA operand")) 3032 return failure(); 3033 3034 auto type = parseType(); 3035 if (!type) 3036 return failure(); 3037 3038 return action(useInfo, type); 3039 } 3040 3041 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 3042 /// followed by a type list. 3043 /// 3044 /// ssa-use-and-type-list 3045 /// ::= ssa-use-list ':' type-list-no-parens 3046 /// 3047 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 3048 SmallVectorImpl<Value *> &results) { 3049 SmallVector<SSAUseInfo, 4> valueIDs; 3050 if (parseOptionalSSAUseList(valueIDs)) 3051 return failure(); 3052 3053 // If there were no operands, then there is no colon or type lists. 3054 if (valueIDs.empty()) 3055 return success(); 3056 3057 SmallVector<Type, 4> types; 3058 if (parseToken(Token::colon, "expected ':' in operand list") || 3059 parseTypeListNoParens(types)) 3060 return failure(); 3061 3062 if (valueIDs.size() != types.size()) 3063 return emitError("expected ") 3064 << valueIDs.size() << " types to match operand list"; 3065 3066 results.reserve(valueIDs.size()); 3067 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 3068 if (auto *value = resolveSSAUse(valueIDs[i], types[i])) 3069 results.push_back(value); 3070 else 3071 return failure(); 3072 } 3073 3074 return success(); 3075 } 3076 3077 /// Record that a definition was added at the current scope. 3078 void OperationParser::recordDefinition(StringRef def) { 3079 isolatedNameScopes.back().recordDefinition(def); 3080 } 3081 3082 /// Get the value entry for the given SSA name. 3083 SmallVectorImpl<std::pair<Value *, SMLoc>> & 3084 OperationParser::getSSAValueEntry(StringRef name) { 3085 return isolatedNameScopes.back().values[name]; 3086 } 3087 3088 /// Create and remember a new placeholder for a forward reference. 3089 Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 3090 // Forward references are always created as operations, because we just need 3091 // something with a def/use chain. 3092 // 3093 // We create these placeholders as having an empty name, which we know 3094 // cannot be created through normal user input, allowing us to distinguish 3095 // them. 3096 auto name = OperationName("placeholder", getContext()); 3097 auto *op = Operation::create( 3098 getEncodedSourceLocation(loc), name, type, /*operands=*/{}, 3099 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3100 /*resizableOperandList=*/false); 3101 forwardRefPlaceholders[op->getResult(0)] = loc; 3102 return op->getResult(0); 3103 } 3104 3105 //===----------------------------------------------------------------------===// 3106 // Operation Parsing 3107 //===----------------------------------------------------------------------===// 3108 3109 /// Parse an operation. 3110 /// 3111 /// operation ::= 3112 /// operation-result? string '(' ssa-use-list? ')' attribute-dict? 3113 /// `:` function-type trailing-location? 3114 /// operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=` 3115 /// 3116 ParseResult OperationParser::parseOperation() { 3117 auto loc = getToken().getLoc(); 3118 SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs; 3119 size_t numExpectedResults; 3120 if (getToken().is(Token::percent_identifier)) { 3121 // Parse the first result id. 3122 resultIDs.emplace_back(getTokenSpelling(), loc); 3123 consumeToken(Token::percent_identifier); 3124 3125 // If the next token is a ':', we parse the expected result count. 3126 if (consumeIf(Token::colon)) { 3127 // Check that the next token is an integer. 3128 if (!getToken().is(Token::integer)) 3129 return emitError("expected integer number of results"); 3130 3131 // Check that number of results is > 0. 3132 auto val = getToken().getUInt64IntegerValue(); 3133 if (!val.hasValue() || val.getValue() < 1) 3134 return emitError("expected named operation to have atleast 1 result"); 3135 consumeToken(Token::integer); 3136 numExpectedResults = *val; 3137 } else { 3138 // Otherwise, this is a comma separated list of result ids. 3139 if (consumeIf(Token::comma)) { 3140 auto parseNextResult = [&]() -> ParseResult { 3141 // Parse the next result id. 3142 if (!getToken().is(Token::percent_identifier)) 3143 return emitError("expected valid ssa identifier"); 3144 3145 resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc()); 3146 consumeToken(Token::percent_identifier); 3147 return success(); 3148 }; 3149 3150 if (parseCommaSeparatedList(parseNextResult)) 3151 return failure(); 3152 } 3153 numExpectedResults = resultIDs.size(); 3154 } 3155 3156 if (parseToken(Token::equal, "expected '=' after SSA name")) 3157 return failure(); 3158 } 3159 3160 Operation *op; 3161 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3162 op = parseCustomOperation(); 3163 else if (getToken().is(Token::string)) 3164 op = parseGenericOperation(); 3165 else 3166 return emitError("expected operation name in quotes"); 3167 3168 // If parsing of the basic operation failed, then this whole thing fails. 3169 if (!op) 3170 return failure(); 3171 3172 // If the operation had a name, register it. 3173 if (!resultIDs.empty()) { 3174 if (op->getNumResults() == 0) 3175 return emitError(loc, "cannot name an operation with no results"); 3176 if (numExpectedResults != op->getNumResults()) 3177 return emitError(loc, "operation defines ") 3178 << op->getNumResults() << " results but was provided " 3179 << numExpectedResults << " to bind"; 3180 3181 // If the number of result names matches the number of operation results, we 3182 // can directly use the provided names. 3183 if (resultIDs.size() == op->getNumResults()) { 3184 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3185 if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second}, 3186 op->getResult(i))) 3187 return failure(); 3188 } else { 3189 // Otherwise, we use the same name for all results. 3190 StringRef name = resultIDs.front().first; 3191 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3192 if (addDefinition({name, i, loc}, op->getResult(i))) 3193 return failure(); 3194 } 3195 } 3196 3197 // Try to parse the optional trailing location. 3198 if (parseOptionalTrailingLocation(op)) 3199 return failure(); 3200 3201 return success(); 3202 } 3203 3204 /// Parse a single operation successor and its operand list. 3205 /// 3206 /// successor ::= block-id branch-use-list? 3207 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3208 /// 3209 ParseResult 3210 OperationParser::parseSuccessorAndUseList(Block *&dest, 3211 SmallVectorImpl<Value *> &operands) { 3212 // Verify branch is identifier and get the matching block. 3213 if (!getToken().is(Token::caret_identifier)) 3214 return emitError("expected block name"); 3215 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3216 consumeToken(); 3217 3218 // Handle optional arguments. 3219 if (consumeIf(Token::l_paren) && 3220 (parseOptionalSSAUseAndTypeList(operands) || 3221 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3222 return failure(); 3223 } 3224 3225 return success(); 3226 } 3227 3228 /// Parse a comma-separated list of operation successors in brackets. 3229 /// 3230 /// successor-list ::= `[` successor (`,` successor )* `]` 3231 /// 3232 ParseResult OperationParser::parseSuccessors( 3233 SmallVectorImpl<Block *> &destinations, 3234 SmallVectorImpl<SmallVector<Value *, 4>> &operands) { 3235 if (parseToken(Token::l_square, "expected '['")) 3236 return failure(); 3237 3238 auto parseElt = [this, &destinations, &operands]() { 3239 Block *dest; 3240 SmallVector<Value *, 4> destOperands; 3241 auto res = parseSuccessorAndUseList(dest, destOperands); 3242 destinations.push_back(dest); 3243 operands.push_back(destOperands); 3244 return res; 3245 }; 3246 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3247 /*allowEmptyList=*/false); 3248 } 3249 3250 namespace { 3251 // RAII-style guard for cleaning up the regions in the operation state before 3252 // deleting them. Within the parser, regions may get deleted if parsing failed, 3253 // and other errors may be present, in particular undominated uses. This makes 3254 // sure such uses are deleted. 3255 struct CleanupOpStateRegions { 3256 ~CleanupOpStateRegions() { 3257 SmallVector<Region *, 4> regionsToClean; 3258 regionsToClean.reserve(state.regions.size()); 3259 for (auto ®ion : state.regions) 3260 if (region) 3261 for (auto &block : *region) 3262 block.dropAllDefinedValueUses(); 3263 } 3264 OperationState &state; 3265 }; 3266 } // namespace 3267 3268 Operation *OperationParser::parseGenericOperation() { 3269 // Get location information for the operation. 3270 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3271 3272 auto name = getToken().getStringValue(); 3273 if (name.empty()) 3274 return (emitError("empty operation name is invalid"), nullptr); 3275 if (name.find('\0') != StringRef::npos) 3276 return (emitError("null character not allowed in operation name"), nullptr); 3277 3278 consumeToken(Token::string); 3279 3280 OperationState result(srcLocation, name); 3281 3282 // Generic operations have a resizable operation list. 3283 result.setOperandListToResizable(); 3284 3285 // Parse the operand list. 3286 SmallVector<SSAUseInfo, 8> operandInfos; 3287 3288 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3289 parseOptionalSSAUseList(operandInfos) || 3290 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3291 return nullptr; 3292 } 3293 3294 // Parse the successor list but don't add successors to the result yet to 3295 // avoid messing up with the argument order. 3296 SmallVector<Block *, 2> successors; 3297 SmallVector<SmallVector<Value *, 4>, 2> successorOperands; 3298 if (getToken().is(Token::l_square)) { 3299 // Check if the operation is a known terminator. 3300 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3301 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3302 return emitError("successors in non-terminator"), nullptr; 3303 if (parseSuccessors(successors, successorOperands)) 3304 return nullptr; 3305 } 3306 3307 // Parse the region list. 3308 CleanupOpStateRegions guard{result}; 3309 if (consumeIf(Token::l_paren)) { 3310 do { 3311 // Create temporary regions with the top level region as parent. 3312 result.regions.emplace_back(new Region(moduleOp)); 3313 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3314 return nullptr; 3315 } while (consumeIf(Token::comma)); 3316 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3317 return nullptr; 3318 } 3319 3320 if (getToken().is(Token::l_brace)) { 3321 if (parseAttributeDict(result.attributes)) 3322 return nullptr; 3323 } 3324 3325 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3326 return nullptr; 3327 3328 auto typeLoc = getToken().getLoc(); 3329 auto type = parseType(); 3330 if (!type) 3331 return nullptr; 3332 auto fnType = type.dyn_cast<FunctionType>(); 3333 if (!fnType) 3334 return (emitError(typeLoc, "expected function type"), nullptr); 3335 3336 result.addTypes(fnType.getResults()); 3337 3338 // Check that we have the right number of types for the operands. 3339 auto operandTypes = fnType.getInputs(); 3340 if (operandTypes.size() != operandInfos.size()) { 3341 auto plural = "s"[operandInfos.size() == 1]; 3342 return (emitError(typeLoc, "expected ") 3343 << operandInfos.size() << " operand type" << plural 3344 << " but had " << operandTypes.size(), 3345 nullptr); 3346 } 3347 3348 // Resolve all of the operands. 3349 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3350 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3351 if (!result.operands.back()) 3352 return nullptr; 3353 } 3354 3355 // Add the successors, and their operands after the proper operands. 3356 for (const auto &succ : llvm::zip(successors, successorOperands)) { 3357 Block *successor = std::get<0>(succ); 3358 const SmallVector<Value *, 4> &operands = std::get<1>(succ); 3359 result.addSuccessor(successor, operands); 3360 } 3361 3362 return opBuilder.createOperation(result); 3363 } 3364 3365 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3366 Block::iterator insertPt) { 3367 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3368 opBuilder.setInsertionPoint(insertBlock, insertPt); 3369 return parseGenericOperation(); 3370 } 3371 3372 namespace { 3373 class CustomOpAsmParser : public OpAsmParser { 3374 public: 3375 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3376 OperationParser &parser) 3377 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3378 3379 /// Parse an instance of the operation described by 'opDefinition' into the 3380 /// provided operation state. 3381 ParseResult parseOperation(OperationState &opState) { 3382 if (opDefinition->parseAssembly(*this, opState)) 3383 return failure(); 3384 return success(); 3385 } 3386 3387 Operation *parseGenericOperation(Block *insertBlock, 3388 Block::iterator insertPt) final { 3389 return parser.parseGenericOperation(insertBlock, insertPt); 3390 } 3391 3392 //===--------------------------------------------------------------------===// 3393 // Utilities 3394 //===--------------------------------------------------------------------===// 3395 3396 /// Return if any errors were emitted during parsing. 3397 bool didEmitError() const { return emittedError; } 3398 3399 /// Emit a diagnostic at the specified location and return failure. 3400 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3401 emittedError = true; 3402 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3403 message); 3404 } 3405 3406 llvm::SMLoc getCurrentLocation() override { 3407 return parser.getToken().getLoc(); 3408 } 3409 3410 Builder &getBuilder() const override { return parser.builder; } 3411 3412 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3413 3414 //===--------------------------------------------------------------------===// 3415 // Token Parsing 3416 //===--------------------------------------------------------------------===// 3417 3418 /// Parse a `->` token. 3419 ParseResult parseArrow() override { 3420 return parser.parseToken(Token::arrow, "expected '->'"); 3421 } 3422 3423 /// Parses a `->` if present. 3424 ParseResult parseOptionalArrow() override { 3425 return success(parser.consumeIf(Token::arrow)); 3426 } 3427 3428 /// Parse a `:` token. 3429 ParseResult parseColon() override { 3430 return parser.parseToken(Token::colon, "expected ':'"); 3431 } 3432 3433 /// Parse a `:` token if present. 3434 ParseResult parseOptionalColon() override { 3435 return success(parser.consumeIf(Token::colon)); 3436 } 3437 3438 /// Parse a `,` token. 3439 ParseResult parseComma() override { 3440 return parser.parseToken(Token::comma, "expected ','"); 3441 } 3442 3443 /// Parse a `,` token if present. 3444 ParseResult parseOptionalComma() override { 3445 return success(parser.consumeIf(Token::comma)); 3446 } 3447 3448 /// Parses a `...` if present. 3449 ParseResult parseOptionalEllipsis() override { 3450 return success(parser.consumeIf(Token::ellipsis)); 3451 } 3452 3453 /// Parse a `=` token. 3454 ParseResult parseEqual() override { 3455 return parser.parseToken(Token::equal, "expected '='"); 3456 } 3457 3458 /// Parse a `(` token. 3459 ParseResult parseLParen() override { 3460 return parser.parseToken(Token::l_paren, "expected '('"); 3461 } 3462 3463 /// Parses a '(' if present. 3464 ParseResult parseOptionalLParen() override { 3465 return success(parser.consumeIf(Token::l_paren)); 3466 } 3467 3468 /// Parse a `)` token. 3469 ParseResult parseRParen() override { 3470 return parser.parseToken(Token::r_paren, "expected ')'"); 3471 } 3472 3473 /// Parses a ')' if present. 3474 ParseResult parseOptionalRParen() override { 3475 return success(parser.consumeIf(Token::r_paren)); 3476 } 3477 3478 /// Parse a `[` token. 3479 ParseResult parseLSquare() override { 3480 return parser.parseToken(Token::l_square, "expected '['"); 3481 } 3482 3483 /// Parses a '[' if present. 3484 ParseResult parseOptionalLSquare() override { 3485 return success(parser.consumeIf(Token::l_square)); 3486 } 3487 3488 /// Parse a `]` token. 3489 ParseResult parseRSquare() override { 3490 return parser.parseToken(Token::r_square, "expected ']'"); 3491 } 3492 3493 /// Parses a ']' if present. 3494 ParseResult parseOptionalRSquare() override { 3495 return success(parser.consumeIf(Token::r_square)); 3496 } 3497 3498 //===--------------------------------------------------------------------===// 3499 // Attribute Parsing 3500 //===--------------------------------------------------------------------===// 3501 3502 /// Parse an arbitrary attribute of a given type and return it in result. This 3503 /// also adds the attribute to the specified attribute list with the specified 3504 /// name. 3505 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3506 SmallVectorImpl<NamedAttribute> &attrs) override { 3507 result = parser.parseAttribute(type); 3508 if (!result) 3509 return failure(); 3510 3511 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3512 return success(); 3513 } 3514 3515 /// Parse a named dictionary into 'result' if it is present. 3516 ParseResult 3517 parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override { 3518 if (parser.getToken().isNot(Token::l_brace)) 3519 return success(); 3520 return parser.parseAttributeDict(result); 3521 } 3522 3523 //===--------------------------------------------------------------------===// 3524 // Identifier Parsing 3525 //===--------------------------------------------------------------------===// 3526 3527 /// Returns if the current token corresponds to a keyword. 3528 bool isCurrentTokenAKeyword() const { 3529 return parser.getToken().is(Token::bare_identifier) || 3530 parser.getToken().isKeyword(); 3531 } 3532 3533 /// Parse the given keyword if present. 3534 ParseResult parseOptionalKeyword(StringRef keyword) override { 3535 // Check that the current token has the same spelling. 3536 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 3537 return failure(); 3538 parser.consumeToken(); 3539 return success(); 3540 } 3541 3542 /// Parse a keyword, if present, into 'keyword'. 3543 ParseResult parseOptionalKeyword(StringRef *keyword) override { 3544 // Check that the current token is a keyword. 3545 if (!isCurrentTokenAKeyword()) 3546 return failure(); 3547 3548 *keyword = parser.getTokenSpelling(); 3549 parser.consumeToken(); 3550 return success(); 3551 } 3552 3553 /// Parse an @-identifier and store it (without the '@' symbol) in a string 3554 /// attribute named 'attrName'. 3555 ParseResult parseSymbolName(StringAttr &result, StringRef attrName, 3556 SmallVectorImpl<NamedAttribute> &attrs) override { 3557 Token atToken = parser.getToken(); 3558 if (atToken.isNot(Token::at_identifier)) 3559 return failure(); 3560 3561 result = getBuilder().getStringAttr(extractSymbolReference(atToken)); 3562 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 3563 parser.consumeToken(); 3564 return success(); 3565 } 3566 3567 //===--------------------------------------------------------------------===// 3568 // Operand Parsing 3569 //===--------------------------------------------------------------------===// 3570 3571 /// Parse a single operand. 3572 ParseResult parseOperand(OperandType &result) override { 3573 OperationParser::SSAUseInfo useInfo; 3574 if (parser.parseSSAUse(useInfo)) 3575 return failure(); 3576 3577 result = {useInfo.loc, useInfo.name, useInfo.number}; 3578 return success(); 3579 } 3580 3581 /// Parse zero or more SSA comma-separated operand references with a specified 3582 /// surrounding delimiter, and an optional required operand count. 3583 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 3584 int requiredOperandCount = -1, 3585 Delimiter delimiter = Delimiter::None) override { 3586 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 3587 requiredOperandCount, delimiter); 3588 } 3589 3590 /// Parse zero or more SSA comma-separated operand or region arguments with 3591 /// optional surrounding delimiter and required operand count. 3592 ParseResult 3593 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 3594 bool isOperandList, int requiredOperandCount = -1, 3595 Delimiter delimiter = Delimiter::None) { 3596 auto startLoc = parser.getToken().getLoc(); 3597 3598 // Handle delimiters. 3599 switch (delimiter) { 3600 case Delimiter::None: 3601 // Don't check for the absence of a delimiter if the number of operands 3602 // is unknown (and hence the operand list could be empty). 3603 if (requiredOperandCount == -1) 3604 break; 3605 // Token already matches an identifier and so can't be a delimiter. 3606 if (parser.getToken().is(Token::percent_identifier)) 3607 break; 3608 // Test against known delimiters. 3609 if (parser.getToken().is(Token::l_paren) || 3610 parser.getToken().is(Token::l_square)) 3611 return emitError(startLoc, "unexpected delimiter"); 3612 return emitError(startLoc, "invalid operand"); 3613 case Delimiter::OptionalParen: 3614 if (parser.getToken().isNot(Token::l_paren)) 3615 return success(); 3616 LLVM_FALLTHROUGH; 3617 case Delimiter::Paren: 3618 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 3619 return failure(); 3620 break; 3621 case Delimiter::OptionalSquare: 3622 if (parser.getToken().isNot(Token::l_square)) 3623 return success(); 3624 LLVM_FALLTHROUGH; 3625 case Delimiter::Square: 3626 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 3627 return failure(); 3628 break; 3629 } 3630 3631 // Check for zero operands. 3632 if (parser.getToken().is(Token::percent_identifier)) { 3633 do { 3634 OperandType operandOrArg; 3635 if (isOperandList ? parseOperand(operandOrArg) 3636 : parseRegionArgument(operandOrArg)) 3637 return failure(); 3638 result.push_back(operandOrArg); 3639 } while (parser.consumeIf(Token::comma)); 3640 } 3641 3642 // Handle delimiters. If we reach here, the optional delimiters were 3643 // present, so we need to parse their closing one. 3644 switch (delimiter) { 3645 case Delimiter::None: 3646 break; 3647 case Delimiter::OptionalParen: 3648 case Delimiter::Paren: 3649 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 3650 return failure(); 3651 break; 3652 case Delimiter::OptionalSquare: 3653 case Delimiter::Square: 3654 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 3655 return failure(); 3656 break; 3657 } 3658 3659 if (requiredOperandCount != -1 && 3660 result.size() != static_cast<size_t>(requiredOperandCount)) 3661 return emitError(startLoc, "expected ") 3662 << requiredOperandCount << " operands"; 3663 return success(); 3664 } 3665 3666 /// Parse zero or more trailing SSA comma-separated trailing operand 3667 /// references with a specified surrounding delimiter, and an optional 3668 /// required operand count. A leading comma is expected before the operands. 3669 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 3670 int requiredOperandCount, 3671 Delimiter delimiter) override { 3672 if (parser.getToken().is(Token::comma)) { 3673 parseComma(); 3674 return parseOperandList(result, requiredOperandCount, delimiter); 3675 } 3676 if (requiredOperandCount != -1) 3677 return emitError(parser.getToken().getLoc(), "expected ") 3678 << requiredOperandCount << " operands"; 3679 return success(); 3680 } 3681 3682 /// Resolve an operand to an SSA value, emitting an error on failure. 3683 ParseResult resolveOperand(const OperandType &operand, Type type, 3684 SmallVectorImpl<Value *> &result) override { 3685 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 3686 operand.location}; 3687 if (auto *value = parser.resolveSSAUse(operandInfo, type)) { 3688 result.push_back(value); 3689 return success(); 3690 } 3691 return failure(); 3692 } 3693 3694 /// Parse an AffineMap of SSA ids. 3695 ParseResult 3696 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 3697 Attribute &mapAttr, StringRef attrName, 3698 SmallVectorImpl<NamedAttribute> &attrs) override { 3699 SmallVector<OperandType, 2> dimOperands; 3700 SmallVector<OperandType, 1> symOperands; 3701 3702 auto parseElement = [&](bool isSymbol) -> ParseResult { 3703 OperandType operand; 3704 if (parseOperand(operand)) 3705 return failure(); 3706 if (isSymbol) 3707 symOperands.push_back(operand); 3708 else 3709 dimOperands.push_back(operand); 3710 return success(); 3711 }; 3712 3713 AffineMap map; 3714 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 3715 return failure(); 3716 // Add AffineMap attribute. 3717 if (map) { 3718 mapAttr = AffineMapAttr::get(map); 3719 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 3720 } 3721 3722 // Add dim operands before symbol operands in 'operands'. 3723 operands.assign(dimOperands.begin(), dimOperands.end()); 3724 operands.append(symOperands.begin(), symOperands.end()); 3725 return success(); 3726 } 3727 3728 //===--------------------------------------------------------------------===// 3729 // Region Parsing 3730 //===--------------------------------------------------------------------===// 3731 3732 /// Parse a region that takes `arguments` of `argTypes` types. This 3733 /// effectively defines the SSA values of `arguments` and assigns their type. 3734 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 3735 ArrayRef<Type> argTypes, 3736 bool enableNameShadowing) override { 3737 assert(arguments.size() == argTypes.size() && 3738 "mismatching number of arguments and types"); 3739 3740 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 3741 regionArguments; 3742 for (const auto &pair : llvm::zip(arguments, argTypes)) { 3743 const OperandType &operand = std::get<0>(pair); 3744 Type type = std::get<1>(pair); 3745 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 3746 operand.location}; 3747 regionArguments.emplace_back(operandInfo, type); 3748 } 3749 3750 // Try to parse the region. 3751 assert((!enableNameShadowing || 3752 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 3753 "name shadowing is only allowed on isolated regions"); 3754 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 3755 return failure(); 3756 return success(); 3757 } 3758 3759 /// Parses a region if present. 3760 ParseResult parseOptionalRegion(Region ®ion, 3761 ArrayRef<OperandType> arguments, 3762 ArrayRef<Type> argTypes, 3763 bool enableNameShadowing) override { 3764 if (parser.getToken().isNot(Token::l_brace)) 3765 return success(); 3766 return parseRegion(region, arguments, argTypes, enableNameShadowing); 3767 } 3768 3769 /// Parse a region argument. The type of the argument will be resolved later 3770 /// by a call to `parseRegion`. 3771 ParseResult parseRegionArgument(OperandType &argument) override { 3772 return parseOperand(argument); 3773 } 3774 3775 /// Parse a region argument if present. 3776 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 3777 if (parser.getToken().isNot(Token::percent_identifier)) 3778 return success(); 3779 return parseRegionArgument(argument); 3780 } 3781 3782 ParseResult 3783 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 3784 int requiredOperandCount = -1, 3785 Delimiter delimiter = Delimiter::None) override { 3786 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 3787 requiredOperandCount, delimiter); 3788 } 3789 3790 //===--------------------------------------------------------------------===// 3791 // Successor Parsing 3792 //===--------------------------------------------------------------------===// 3793 3794 /// Parse a single operation successor and its operand list. 3795 ParseResult 3796 parseSuccessorAndUseList(Block *&dest, 3797 SmallVectorImpl<Value *> &operands) override { 3798 return parser.parseSuccessorAndUseList(dest, operands); 3799 } 3800 3801 //===--------------------------------------------------------------------===// 3802 // Type Parsing 3803 //===--------------------------------------------------------------------===// 3804 3805 /// Parse a type. 3806 ParseResult parseType(Type &result) override { 3807 return failure(!(result = parser.parseType())); 3808 } 3809 3810 /// Parse an optional arrow followed by a type list. 3811 ParseResult 3812 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 3813 if (!parser.consumeIf(Token::arrow)) 3814 return success(); 3815 return parser.parseFunctionResultTypes(result); 3816 } 3817 3818 /// Parse a colon followed by a type. 3819 ParseResult parseColonType(Type &result) override { 3820 return failure(parser.parseToken(Token::colon, "expected ':'") || 3821 !(result = parser.parseType())); 3822 } 3823 3824 /// Parse a colon followed by a type list, which must have at least one type. 3825 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 3826 if (parser.parseToken(Token::colon, "expected ':'")) 3827 return failure(); 3828 return parser.parseTypeListNoParens(result); 3829 } 3830 3831 /// Parse an optional colon followed by a type list, which if present must 3832 /// have at least one type. 3833 ParseResult 3834 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 3835 if (!parser.consumeIf(Token::colon)) 3836 return success(); 3837 return parser.parseTypeListNoParens(result); 3838 } 3839 3840 private: 3841 /// The source location of the operation name. 3842 SMLoc nameLoc; 3843 3844 /// The abstract information of the operation. 3845 const AbstractOperation *opDefinition; 3846 3847 /// The main operation parser. 3848 OperationParser &parser; 3849 3850 /// A flag that indicates if any errors were emitted during parsing. 3851 bool emittedError = false; 3852 }; 3853 } // end anonymous namespace. 3854 3855 Operation *OperationParser::parseCustomOperation() { 3856 auto opLoc = getToken().getLoc(); 3857 auto opName = getTokenSpelling(); 3858 3859 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 3860 if (!opDefinition && !opName.contains('.')) { 3861 // If the operation name has no namespace prefix we treat it as a standard 3862 // operation and prefix it with "std". 3863 // TODO: Would it be better to just build a mapping of the registered 3864 // operations in the standard dialect? 3865 opDefinition = 3866 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 3867 } 3868 3869 if (!opDefinition) { 3870 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 3871 return nullptr; 3872 } 3873 3874 consumeToken(); 3875 3876 // If the custom op parser crashes, produce some indication to help 3877 // debugging. 3878 std::string opNameStr = opName.str(); 3879 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 3880 opNameStr.c_str()); 3881 3882 // Get location information for the operation. 3883 auto srcLocation = getEncodedSourceLocation(opLoc); 3884 3885 // Have the op implementation take a crack and parsing this. 3886 OperationState opState(srcLocation, opDefinition->name); 3887 CleanupOpStateRegions guard{opState}; 3888 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 3889 if (opAsmParser.parseOperation(opState)) 3890 return nullptr; 3891 3892 // If it emitted an error, we failed. 3893 if (opAsmParser.didEmitError()) 3894 return nullptr; 3895 3896 // Otherwise, we succeeded. Use the state it parsed as our op information. 3897 return opBuilder.createOperation(opState); 3898 } 3899 3900 //===----------------------------------------------------------------------===// 3901 // Region Parsing 3902 //===----------------------------------------------------------------------===// 3903 3904 /// Region. 3905 /// 3906 /// region ::= '{' region-body 3907 /// 3908 ParseResult OperationParser::parseRegion( 3909 Region ®ion, 3910 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 3911 bool isIsolatedNameScope) { 3912 // Parse the '{'. 3913 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 3914 return failure(); 3915 3916 // Check for an empty region. 3917 if (entryArguments.empty() && consumeIf(Token::r_brace)) 3918 return success(); 3919 auto currentPt = opBuilder.saveInsertionPoint(); 3920 3921 // Push a new named value scope. 3922 pushSSANameScope(isIsolatedNameScope); 3923 3924 // Parse the first block directly to allow for it to be unnamed. 3925 Block *block = new Block(); 3926 3927 // Add arguments to the entry block. 3928 if (!entryArguments.empty()) { 3929 for (auto &placeholderArgPair : entryArguments) { 3930 auto &argInfo = placeholderArgPair.first; 3931 // Ensure that the argument was not already defined. 3932 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 3933 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 3934 "' is already in use") 3935 .attachNote(getEncodedSourceLocation(*defLoc)) 3936 << "previously referenced here"; 3937 } 3938 if (addDefinition(placeholderArgPair.first, 3939 block->addArgument(placeholderArgPair.second))) { 3940 delete block; 3941 return failure(); 3942 } 3943 } 3944 3945 // If we had named arguments, then don't allow a block name. 3946 if (getToken().is(Token::caret_identifier)) 3947 return emitError("invalid block name in region with named arguments"); 3948 } 3949 3950 if (parseBlock(block)) { 3951 delete block; 3952 return failure(); 3953 } 3954 3955 // Verify that no other arguments were parsed. 3956 if (!entryArguments.empty() && 3957 block->getNumArguments() > entryArguments.size()) { 3958 delete block; 3959 return emitError("entry block arguments were already defined"); 3960 } 3961 3962 // Parse the rest of the region. 3963 region.push_back(block); 3964 if (parseRegionBody(region)) 3965 return failure(); 3966 3967 // Pop the SSA value scope for this region. 3968 if (popSSANameScope()) 3969 return failure(); 3970 3971 // Reset the original insertion point. 3972 opBuilder.restoreInsertionPoint(currentPt); 3973 return success(); 3974 } 3975 3976 /// Region. 3977 /// 3978 /// region-body ::= block* '}' 3979 /// 3980 ParseResult OperationParser::parseRegionBody(Region ®ion) { 3981 // Parse the list of blocks. 3982 while (!consumeIf(Token::r_brace)) { 3983 Block *newBlock = nullptr; 3984 if (parseBlock(newBlock)) 3985 return failure(); 3986 region.push_back(newBlock); 3987 } 3988 return success(); 3989 } 3990 3991 //===----------------------------------------------------------------------===// 3992 // Block Parsing 3993 //===----------------------------------------------------------------------===// 3994 3995 /// Block declaration. 3996 /// 3997 /// block ::= block-label? operation* 3998 /// block-label ::= block-id block-arg-list? `:` 3999 /// block-id ::= caret-id 4000 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 4001 /// 4002 ParseResult OperationParser::parseBlock(Block *&block) { 4003 // The first block of a region may already exist, if it does the caret 4004 // identifier is optional. 4005 if (block && getToken().isNot(Token::caret_identifier)) 4006 return parseBlockBody(block); 4007 4008 SMLoc nameLoc = getToken().getLoc(); 4009 auto name = getTokenSpelling(); 4010 if (parseToken(Token::caret_identifier, "expected block name")) 4011 return failure(); 4012 4013 block = defineBlockNamed(name, nameLoc, block); 4014 4015 // Fail if the block was already defined. 4016 if (!block) 4017 return emitError(nameLoc, "redefinition of block '") << name << "'"; 4018 4019 // If an argument list is present, parse it. 4020 if (consumeIf(Token::l_paren)) { 4021 SmallVector<BlockArgument *, 8> bbArgs; 4022 if (parseOptionalBlockArgList(bbArgs, block) || 4023 parseToken(Token::r_paren, "expected ')' to end argument list")) 4024 return failure(); 4025 } 4026 4027 if (parseToken(Token::colon, "expected ':' after block name")) 4028 return failure(); 4029 4030 return parseBlockBody(block); 4031 } 4032 4033 ParseResult OperationParser::parseBlockBody(Block *block) { 4034 // Set the insertion point to the end of the block to parse. 4035 opBuilder.setInsertionPointToEnd(block); 4036 4037 // Parse the list of operations that make up the body of the block. 4038 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 4039 if (parseOperation()) 4040 return failure(); 4041 4042 return success(); 4043 } 4044 4045 /// Get the block with the specified name, creating it if it doesn't already 4046 /// exist. The location specified is the point of use, which allows 4047 /// us to diagnose references to blocks that are not defined precisely. 4048 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 4049 auto &blockAndLoc = getBlockInfoByName(name); 4050 if (!blockAndLoc.first) { 4051 blockAndLoc = {new Block(), loc}; 4052 insertForwardRef(blockAndLoc.first, loc); 4053 } 4054 4055 return blockAndLoc.first; 4056 } 4057 4058 /// Define the block with the specified name. Returns the Block* or nullptr in 4059 /// the case of redefinition. 4060 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 4061 Block *existing) { 4062 auto &blockAndLoc = getBlockInfoByName(name); 4063 if (!blockAndLoc.first) { 4064 // If the caller provided a block, use it. Otherwise create a new one. 4065 if (!existing) 4066 existing = new Block(); 4067 blockAndLoc.first = existing; 4068 blockAndLoc.second = loc; 4069 return blockAndLoc.first; 4070 } 4071 4072 // Forward declarations are removed once defined, so if we are defining a 4073 // existing block and it is not a forward declaration, then it is a 4074 // redeclaration. 4075 if (!eraseForwardRef(blockAndLoc.first)) 4076 return nullptr; 4077 return blockAndLoc.first; 4078 } 4079 4080 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 4081 /// 4082 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 4083 /// 4084 ParseResult OperationParser::parseOptionalBlockArgList( 4085 SmallVectorImpl<BlockArgument *> &results, Block *owner) { 4086 if (getToken().is(Token::r_brace)) 4087 return success(); 4088 4089 // If the block already has arguments, then we're handling the entry block. 4090 // Parse and register the names for the arguments, but do not add them. 4091 bool definingExistingArgs = owner->getNumArguments() != 0; 4092 unsigned nextArgument = 0; 4093 4094 return parseCommaSeparatedList([&]() -> ParseResult { 4095 return parseSSADefOrUseAndType( 4096 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4097 // If this block did not have existing arguments, define a new one. 4098 if (!definingExistingArgs) 4099 return addDefinition(useInfo, owner->addArgument(type)); 4100 4101 // Otherwise, ensure that this argument has already been created. 4102 if (nextArgument >= owner->getNumArguments()) 4103 return emitError("too many arguments specified in argument list"); 4104 4105 // Finally, make sure the existing argument has the correct type. 4106 auto *arg = owner->getArgument(nextArgument++); 4107 if (arg->getType() != type) 4108 return emitError("argument and block argument type mismatch"); 4109 return addDefinition(useInfo, arg); 4110 }); 4111 }); 4112 } 4113 4114 //===----------------------------------------------------------------------===// 4115 // Top-level entity parsing. 4116 //===----------------------------------------------------------------------===// 4117 4118 namespace { 4119 /// This parser handles entities that are only valid at the top level of the 4120 /// file. 4121 class ModuleParser : public Parser { 4122 public: 4123 explicit ModuleParser(ParserState &state) : Parser(state) {} 4124 4125 ParseResult parseModule(ModuleOp module); 4126 4127 private: 4128 /// Parse an attribute alias declaration. 4129 ParseResult parseAttributeAliasDef(); 4130 4131 /// Parse an attribute alias declaration. 4132 ParseResult parseTypeAliasDef(); 4133 }; 4134 } // end anonymous namespace 4135 4136 /// Parses an attribute alias declaration. 4137 /// 4138 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4139 /// 4140 ParseResult ModuleParser::parseAttributeAliasDef() { 4141 assert(getToken().is(Token::hash_identifier)); 4142 StringRef aliasName = getTokenSpelling().drop_front(); 4143 4144 // Check for redefinitions. 4145 if (getState().attributeAliasDefinitions.count(aliasName) > 0) 4146 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4147 4148 // Make sure this isn't invading the dialect attribute namespace. 4149 if (aliasName.contains('.')) 4150 return emitError("attribute names with a '.' are reserved for " 4151 "dialect-defined names"); 4152 4153 consumeToken(Token::hash_identifier); 4154 4155 // Parse the '='. 4156 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4157 return failure(); 4158 4159 // Parse the attribute value. 4160 Attribute attr = parseAttribute(); 4161 if (!attr) 4162 return failure(); 4163 4164 getState().attributeAliasDefinitions[aliasName] = attr; 4165 return success(); 4166 } 4167 4168 /// Parse a type alias declaration. 4169 /// 4170 /// type-alias-def ::= '!' alias-name `=` 'type' type 4171 /// 4172 ParseResult ModuleParser::parseTypeAliasDef() { 4173 assert(getToken().is(Token::exclamation_identifier)); 4174 StringRef aliasName = getTokenSpelling().drop_front(); 4175 4176 // Check for redefinitions. 4177 if (getState().typeAliasDefinitions.count(aliasName) > 0) 4178 return emitError("redefinition of type alias id '" + aliasName + "'"); 4179 4180 // Make sure this isn't invading the dialect type namespace. 4181 if (aliasName.contains('.')) 4182 return emitError("type names with a '.' are reserved for " 4183 "dialect-defined names"); 4184 4185 consumeToken(Token::exclamation_identifier); 4186 4187 // Parse the '=' and 'type'. 4188 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4189 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4190 return failure(); 4191 4192 // Parse the type. 4193 Type aliasedType = parseType(); 4194 if (!aliasedType) 4195 return failure(); 4196 4197 // Register this alias with the parser state. 4198 getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4199 return success(); 4200 } 4201 4202 /// This is the top-level module parser. 4203 ParseResult ModuleParser::parseModule(ModuleOp module) { 4204 OperationParser opParser(getState(), module); 4205 4206 // Module itself is a name scope. 4207 opParser.pushSSANameScope(/*isIsolated=*/true); 4208 4209 while (true) { 4210 switch (getToken().getKind()) { 4211 default: 4212 // Parse a top-level operation. 4213 if (opParser.parseOperation()) 4214 return failure(); 4215 break; 4216 4217 // If we got to the end of the file, then we're done. 4218 case Token::eof: { 4219 if (opParser.finalize()) 4220 return failure(); 4221 4222 // Handle the case where the top level module was explicitly defined. 4223 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4224 auto &operations = bodyBlocks.front().getOperations(); 4225 assert(!operations.empty() && "expected a valid module terminator"); 4226 4227 // Check that the first operation is a module, and it is the only 4228 // non-terminator operation. 4229 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4230 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4231 // Merge the data of the nested module operation into 'module'. 4232 module.setLoc(nested.getLoc()); 4233 module.setAttrs(nested.getOperation()->getAttrList()); 4234 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4235 4236 // Erase the original module body. 4237 bodyBlocks.pop_front(); 4238 } 4239 4240 return opParser.popSSANameScope(); 4241 } 4242 4243 // If we got an error token, then the lexer already emitted an error, just 4244 // stop. Someday we could introduce error recovery if there was demand 4245 // for it. 4246 case Token::error: 4247 return failure(); 4248 4249 // Parse an attribute alias. 4250 case Token::hash_identifier: 4251 if (parseAttributeAliasDef()) 4252 return failure(); 4253 break; 4254 4255 // Parse a type alias. 4256 case Token::exclamation_identifier: 4257 if (parseTypeAliasDef()) 4258 return failure(); 4259 break; 4260 } 4261 } 4262 } 4263 4264 //===----------------------------------------------------------------------===// 4265 4266 /// This parses the file specified by the indicated SourceMgr and returns an 4267 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4268 /// null. 4269 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4270 MLIRContext *context) { 4271 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4272 4273 // This is the result module we are parsing into. 4274 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4275 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4276 4277 ParserState state(sourceMgr, context); 4278 if (ModuleParser(state).parseModule(*module)) 4279 return nullptr; 4280 4281 // Make sure the parse module has no other structural problems detected by 4282 // the verifier. 4283 if (failed(verify(*module))) 4284 return nullptr; 4285 4286 return module; 4287 } 4288 4289 /// This parses the file specified by the indicated filename and returns an 4290 /// MLIR module if it was valid. If not, the error message is emitted through 4291 /// the error handler registered in the context, and a null pointer is returned. 4292 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4293 MLIRContext *context) { 4294 llvm::SourceMgr sourceMgr; 4295 return parseSourceFile(filename, sourceMgr, context); 4296 } 4297 4298 /// This parses the file specified by the indicated filename using the provided 4299 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4300 /// message is emitted through the error handler registered in the context, and 4301 /// a null pointer is returned. 4302 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4303 llvm::SourceMgr &sourceMgr, 4304 MLIRContext *context) { 4305 if (sourceMgr.getNumBuffers() != 0) { 4306 // TODO(b/136086478): Extend to support multiple buffers. 4307 emitError(mlir::UnknownLoc::get(context), 4308 "only main buffer parsed at the moment"); 4309 return nullptr; 4310 } 4311 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4312 if (std::error_code error = file_or_err.getError()) { 4313 emitError(mlir::UnknownLoc::get(context), 4314 "could not open input file " + filename); 4315 return nullptr; 4316 } 4317 4318 // Load the MLIR module. 4319 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4320 return parseSourceFile(sourceMgr, context); 4321 } 4322 4323 /// This parses the program string to a MLIR module if it was valid. If not, 4324 /// it emits diagnostics and returns null. 4325 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4326 MLIRContext *context) { 4327 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4328 if (!memBuffer) 4329 return nullptr; 4330 4331 SourceMgr sourceMgr; 4332 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4333 return parseSourceFile(sourceMgr, context); 4334 } 4335 4336 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context, 4337 size_t &numRead) { 4338 SourceMgr sourceMgr; 4339 auto memBuffer = 4340 MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>", 4341 /*RequiresNullTerminator=*/false); 4342 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4343 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context); 4344 ParserState state(sourceMgr, context); 4345 Parser parser(state); 4346 4347 auto start = parser.getToken().getLoc(); 4348 auto ty = parser.parseType(); 4349 if (!ty) 4350 return Type(); 4351 4352 auto end = parser.getToken().getLoc(); 4353 numRead = static_cast<size_t>(end.getPointer() - start.getPointer()); 4354 return ty; 4355 } 4356 4357 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) { 4358 size_t numRead = 0; 4359 return parseType(typeStr, context, numRead); 4360 } 4361