1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the parser for the MLIR textual form. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Parser.h" 23 #include "Lexer.h" 24 #include "mlir/Analysis/Verifier.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/Dialect.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Location.h" 32 #include "mlir/IR/MLIRContext.h" 33 #include "mlir/IR/Module.h" 34 #include "mlir/IR/OpImplementation.h" 35 #include "mlir/IR/StandardTypes.h" 36 #include "mlir/Support/STLExtras.h" 37 #include "llvm/ADT/APInt.h" 38 #include "llvm/ADT/DenseMap.h" 39 #include "llvm/ADT/StringSet.h" 40 #include "llvm/ADT/bit.h" 41 #include "llvm/Support/MemoryBuffer.h" 42 #include "llvm/Support/PrettyStackTrace.h" 43 #include "llvm/Support/SMLoc.h" 44 #include "llvm/Support/SourceMgr.h" 45 #include <algorithm> 46 using namespace mlir; 47 using llvm::MemoryBuffer; 48 using llvm::SMLoc; 49 using llvm::SourceMgr; 50 51 namespace { 52 class Parser; 53 54 
//===----------------------------------------------------------------------===// 55 // ParserState 56 //===----------------------------------------------------------------------===// 57 58 /// This class refers to all of the state maintained globally by the parser, 59 /// such as the current lexer position etc. The Parser base class provides 60 /// methods to access this. 61 class ParserState { 62 public: 63 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx) 64 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {} 65 66 // A map from attribute alias identifier to Attribute. 67 llvm::StringMap<Attribute> attributeAliasDefinitions; 68 69 // A map from type alias identifier to Type. 70 llvm::StringMap<Type> typeAliasDefinitions; 71 72 private: 73 ParserState(const ParserState &) = delete; 74 void operator=(const ParserState &) = delete; 75 76 friend class Parser; 77 78 // The context we're parsing into. 79 MLIRContext *const context; 80 81 // The lexer for the source file we're parsing. 82 Lexer lex; 83 84 // This is the next token that hasn't been consumed yet. 85 Token curToken; 86 }; 87 88 //===----------------------------------------------------------------------===// 89 // Parser 90 //===----------------------------------------------------------------------===// 91 92 /// This class implement support for parsing global entities like types and 93 /// shared entities like SSA names. It is intended to be subclassed by 94 /// specialized subparsers that include state, e.g. when a local symbol table. 95 class Parser { 96 public: 97 Builder builder; 98 99 Parser(ParserState &state) : builder(state.context), state(state) {} 100 101 // Helper methods to get stuff from the parser-global state. 
102 ParserState &getState() const { return state; } 103 MLIRContext *getContext() const { return state.context; } 104 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 105 106 /// Parse a comma-separated list of elements up until the specified end token. 107 ParseResult 108 parseCommaSeparatedListUntil(Token::Kind rightToken, 109 const std::function<ParseResult()> &parseElement, 110 bool allowEmptyList = true); 111 112 /// Parse a comma separated list of elements that must have at least one entry 113 /// in it. 114 ParseResult 115 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 116 117 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 118 119 // We have two forms of parsing methods - those that return a non-null 120 // pointer on success, and those that return a ParseResult to indicate whether 121 // they returned a failure. The second class fills in by-reference arguments 122 // as the results of their action. 123 124 //===--------------------------------------------------------------------===// 125 // Error Handling 126 //===--------------------------------------------------------------------===// 127 128 /// Emit an error and return failure. 129 InFlightDiagnostic emitError(const Twine &message = {}) { 130 return emitError(state.curToken.getLoc(), message); 131 } 132 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 133 134 /// Encode the specified source location information into an attribute for 135 /// attachment to the IR. 136 Location getEncodedSourceLocation(llvm::SMLoc loc) { 137 return state.lex.getEncodedSourceLocation(loc); 138 } 139 140 //===--------------------------------------------------------------------===// 141 // Token Parsing 142 //===--------------------------------------------------------------------===// 143 144 /// Return the current token the parser is inspecting. 
145 const Token &getToken() const { return state.curToken; } 146 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 147 148 /// If the current token has the specified kind, consume it and return true. 149 /// If not, return false. 150 bool consumeIf(Token::Kind kind) { 151 if (state.curToken.isNot(kind)) 152 return false; 153 consumeToken(kind); 154 return true; 155 } 156 157 /// Advance the current lexer onto the next token. 158 void consumeToken() { 159 assert(state.curToken.isNot(Token::eof, Token::error) && 160 "shouldn't advance past EOF or errors"); 161 state.curToken = state.lex.lexToken(); 162 } 163 164 /// Advance the current lexer onto the next token, asserting what the expected 165 /// current token is. This is preferred to the above method because it leads 166 /// to more self-documenting code with better checking. 167 void consumeToken(Token::Kind kind) { 168 assert(state.curToken.is(kind) && "consumed an unexpected token"); 169 consumeToken(); 170 } 171 172 /// Consume the specified token if present and return success. On failure, 173 /// output a diagnostic and return failure. 174 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 175 176 //===--------------------------------------------------------------------===// 177 // Type Parsing 178 //===--------------------------------------------------------------------===// 179 180 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 181 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 182 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 183 184 /// Parse an arbitrary type. 185 Type parseType(); 186 187 /// Parse a complex type. 188 Type parseComplexType(); 189 190 /// Parse an extended type. 191 Type parseExtendedType(); 192 193 /// Parse a function type. 194 Type parseFunctionType(); 195 196 /// Parse a memref type. 197 Type parseMemRefType(); 198 199 /// Parse a non function type. 
200 Type parseNonFunctionType(); 201 202 /// Parse a tensor type. 203 Type parseTensorType(); 204 205 /// Parse a tuple type. 206 Type parseTupleType(); 207 208 /// Parse a vector type. 209 VectorType parseVectorType(); 210 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 211 bool allowDynamic = true); 212 ParseResult parseXInDimensionList(); 213 214 //===--------------------------------------------------------------------===// 215 // Attribute Parsing 216 //===--------------------------------------------------------------------===// 217 218 /// Parse an arbitrary attribute with an optional type. 219 Attribute parseAttribute(Type type = {}); 220 221 /// Parse an attribute dictionary. 222 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 223 224 /// Parse an extended attribute. 225 Attribute parseExtendedAttr(Type type); 226 227 /// Parse a float attribute. 228 Attribute parseFloatAttr(Type type, bool isNegative); 229 230 /// Parse a decimal or a hexadecimal literal, which can be either an integer 231 /// or a float attribute. 232 Attribute parseDecOrHexAttr(Type type, bool isNegative); 233 234 /// Parse an opaque elements attribute. 235 Attribute parseOpaqueElementsAttr(); 236 237 /// Parse a dense elements attribute. 238 Attribute parseDenseElementsAttr(); 239 ShapedType parseElementsLiteralType(); 240 241 /// Parse a sparse elements attribute. 242 Attribute parseSparseElementsAttr(); 243 244 //===--------------------------------------------------------------------===// 245 // Location Parsing 246 //===--------------------------------------------------------------------===// 247 248 /// Parse an inline location. 249 ParseResult parseLocation(LocationAttr &loc); 250 251 /// Parse a raw location instance. 252 ParseResult parseLocationInstance(LocationAttr &loc); 253 254 /// Parse a callsite location instance. 255 ParseResult parseCallSiteLocation(LocationAttr &loc); 256 257 /// Parse a fused location instance. 
258 ParseResult parseFusedLocation(LocationAttr &loc); 259 260 /// Parse a name or FileLineCol location instance. 261 ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); 262 263 /// Parse an optional trailing location. 264 /// 265 /// trailing-location ::= location? 266 /// 267 template <typename Owner> 268 ParseResult parseOptionalTrailingLocation(Owner *owner) { 269 // If there is a 'loc' we parse a trailing location. 270 if (!getToken().is(Token::kw_loc)) 271 return success(); 272 273 // Parse the location. 274 LocationAttr directLoc; 275 if (parseLocation(directLoc)) 276 return failure(); 277 owner->setLoc(directLoc); 278 return success(); 279 } 280 281 //===--------------------------------------------------------------------===// 282 // Affine Parsing 283 //===--------------------------------------------------------------------===// 284 285 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 286 IntegerSet &set); 287 288 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 289 ParseResult 290 parseAffineMapOfSSAIds(AffineMap &map, 291 llvm::function_ref<ParseResult(bool)> parseElement); 292 293 private: 294 /// The Parser is subclassed and reinstantiated. Do not add additional 295 /// non-trivial state here, add it to the ParserState class. 296 ParserState &state; 297 }; 298 } // end anonymous namespace 299 300 //===----------------------------------------------------------------------===// 301 // Helper methods. 302 //===----------------------------------------------------------------------===// 303 304 /// Parse a comma separated list of elements that must have at least one entry 305 /// in it. 306 ParseResult Parser::parseCommaSeparatedList( 307 const std::function<ParseResult()> &parseElement) { 308 // Non-empty case starts with an element. 309 if (parseElement()) 310 return failure(); 311 312 // Otherwise we have a list of comma separated elements. 
313 while (consumeIf(Token::comma)) { 314 if (parseElement()) 315 return failure(); 316 } 317 return success(); 318 } 319 320 /// Parse a comma-separated list of elements, terminated with an arbitrary 321 /// token. This allows empty lists if allowEmptyList is true. 322 /// 323 /// abstract-list ::= rightToken // if allowEmptyList == true 324 /// abstract-list ::= element (',' element)* rightToken 325 /// 326 ParseResult Parser::parseCommaSeparatedListUntil( 327 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 328 bool allowEmptyList) { 329 // Handle the empty case. 330 if (getToken().is(rightToken)) { 331 if (!allowEmptyList) 332 return emitError("expected list element"); 333 consumeToken(rightToken); 334 return success(); 335 } 336 337 if (parseCommaSeparatedList(parseElement) || 338 parseToken(rightToken, "expected ',' or '" + 339 Token::getTokenSpelling(rightToken) + "'")) 340 return failure(); 341 342 return success(); 343 } 344 345 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 346 /// and may be recursive. Return with the 'prettyName' StringRef encompasing 347 /// the entire pretty name. 348 /// 349 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 350 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 351 /// | '(' pretty-dialect-sym-contents+ ')' 352 /// | '[' pretty-dialect-sym-contents+ ']' 353 /// | '{' pretty-dialect-sym-contents+ '}' 354 /// | '[^[<({>\])}\0]+' 355 /// 356 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 357 // Pretty symbol names are a relatively unstructured format that contains a 358 // series of properly nested punctuation, with anything else in the middle. 359 // Scan ahead to find it and consume it if successful, otherwise emit an 360 // error. 
361 auto *curPtr = getTokenSpelling().data(); 362 363 SmallVector<char, 8> nestedPunctuation; 364 365 // Scan over the nested punctuation, bailing out on error and consuming until 366 // we find the end. We know that we're currently looking at the '<', so we 367 // can go until we find the matching '>' character. 368 assert(*curPtr == '<'); 369 do { 370 char c = *curPtr++; 371 switch (c) { 372 case '\0': 373 // This also handles the EOF case. 374 return emitError("unexpected nul or EOF in pretty dialect name"); 375 case '<': 376 case '[': 377 case '(': 378 case '{': 379 nestedPunctuation.push_back(c); 380 continue; 381 382 case '-': 383 // The sequence `->` is treated as special token. 384 if (*curPtr == '>') 385 ++curPtr; 386 continue; 387 388 case '>': 389 if (nestedPunctuation.pop_back_val() != '<') 390 return emitError("unbalanced '>' character in pretty dialect name"); 391 break; 392 case ']': 393 if (nestedPunctuation.pop_back_val() != '[') 394 return emitError("unbalanced ']' character in pretty dialect name"); 395 break; 396 case ')': 397 if (nestedPunctuation.pop_back_val() != '(') 398 return emitError("unbalanced ')' character in pretty dialect name"); 399 break; 400 case '}': 401 if (nestedPunctuation.pop_back_val() != '{') 402 return emitError("unbalanced '}' character in pretty dialect name"); 403 break; 404 405 default: 406 continue; 407 } 408 } while (!nestedPunctuation.empty()); 409 410 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 411 // consuming all this stuff, and return. 412 state.lex.resetPointer(curPtr); 413 414 unsigned length = curPtr - prettyName.begin(); 415 prettyName = StringRef(prettyName.begin(), length); 416 consumeToken(); 417 return success(); 418 } 419 420 /// Parse an extended dialect symbol. 
421 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 422 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 423 SymbolAliasMap &aliases, 424 CreateFn &&createSymbol) { 425 // Parse the dialect namespace. 426 StringRef identifier = p.getTokenSpelling().drop_front(); 427 auto loc = p.getToken().getLoc(); 428 p.consumeToken(identifierTok); 429 430 // If there is no '<' token following this, and if the typename contains no 431 // dot, then we are parsing a symbol alias. 432 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 433 // Check for an alias for this type. 434 auto aliasIt = aliases.find(identifier); 435 if (aliasIt == aliases.end()) 436 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 437 nullptr); 438 return aliasIt->second; 439 } 440 441 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 442 // a dot, then this is the "pretty" form. If not, it is the verbose form that 443 // looks like <"...">. 444 std::string symbolData; 445 auto dialectName = identifier; 446 447 // Handle the verbose form, where "identifier" is a simple dialect name. 448 if (!identifier.contains('.')) { 449 // Consume the '<'. 450 if (p.parseToken(Token::less, "expected '<' in dialect type")) 451 return nullptr; 452 453 // Parse the symbol specific data. 454 if (p.getToken().isNot(Token::string)) 455 return (p.emitError("expected string literal data in dialect symbol"), 456 nullptr); 457 symbolData = p.getToken().getStringValue(); 458 loc = p.getToken().getLoc(); 459 p.consumeToken(Token::string); 460 461 // Consume the '>'. 462 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 463 return nullptr; 464 } else { 465 // Ok, the dialect name is the part of the identifier before the dot, the 466 // part after the dot is the dialect's symbol, or the start thereof. 
467 auto dotHalves = identifier.split('.'); 468 dialectName = dotHalves.first; 469 auto prettyName = dotHalves.second; 470 471 // If the dialect's symbol is followed immediately by a <, then lex the body 472 // of it into prettyName. 473 if (p.getToken().is(Token::less) && 474 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 475 if (p.parsePrettyDialectSymbolName(prettyName)) 476 return nullptr; 477 } 478 479 symbolData = prettyName.str(); 480 } 481 482 // Call into the provided symbol construction function. 483 auto encodedLoc = p.getEncodedSourceLocation(loc); 484 return createSymbol(dialectName, symbolData, encodedLoc); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // Error Handling 489 //===----------------------------------------------------------------------===// 490 491 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 492 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 493 494 // If we hit a parse error in response to a lexer error, then the lexer 495 // already reported the error. 496 if (getToken().is(Token::error)) 497 diag.abandon(); 498 return diag; 499 } 500 501 //===----------------------------------------------------------------------===// 502 // Token Parsing 503 //===----------------------------------------------------------------------===// 504 505 /// Consume the specified token if present and return success. On failure, 506 /// output a diagnostic and return failure. 507 ParseResult Parser::parseToken(Token::Kind expectedToken, 508 const Twine &message) { 509 if (consumeIf(expectedToken)) 510 return success(); 511 return emitError(message); 512 } 513 514 //===----------------------------------------------------------------------===// 515 // Type Parsing 516 //===----------------------------------------------------------------------===// 517 518 /// Parse an arbitrary type. 
519 /// 520 /// type ::= function-type 521 /// | non-function-type 522 /// 523 Type Parser::parseType() { 524 if (getToken().is(Token::l_paren)) 525 return parseFunctionType(); 526 return parseNonFunctionType(); 527 } 528 529 /// Parse a function result type. 530 /// 531 /// function-result-type ::= type-list-parens 532 /// | non-function-type 533 /// 534 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 535 if (getToken().is(Token::l_paren)) 536 return parseTypeListParens(elements); 537 538 Type t = parseNonFunctionType(); 539 if (!t) 540 return failure(); 541 elements.push_back(t); 542 return success(); 543 } 544 545 /// Parse a list of types without an enclosing parenthesis. The list must have 546 /// at least one member. 547 /// 548 /// type-list-no-parens ::= type (`,` type)* 549 /// 550 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 551 auto parseElt = [&]() -> ParseResult { 552 auto elt = parseType(); 553 elements.push_back(elt); 554 return elt ? success() : failure(); 555 }; 556 557 return parseCommaSeparatedList(parseElt); 558 } 559 560 /// Parse a parenthesized list of types. 561 /// 562 /// type-list-parens ::= `(` `)` 563 /// | `(` type-list-no-parens `)` 564 /// 565 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 566 if (parseToken(Token::l_paren, "expected '('")) 567 return failure(); 568 569 // Handle empty lists. 570 if (getToken().is(Token::r_paren)) 571 return consumeToken(), success(); 572 573 if (parseTypeListNoParens(elements) || 574 parseToken(Token::r_paren, "expected ')'")) 575 return failure(); 576 return success(); 577 } 578 579 /// Parse a complex type. 580 /// 581 /// complex-type ::= `complex` `<` type `>` 582 /// 583 Type Parser::parseComplexType() { 584 consumeToken(Token::kw_complex); 585 586 // Parse the '<'. 
587 if (parseToken(Token::less, "expected '<' in complex type")) 588 return nullptr; 589 590 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 591 auto elementType = parseType(); 592 if (!elementType || 593 parseToken(Token::greater, "expected '>' in complex type")) 594 return nullptr; 595 596 return ComplexType::getChecked(elementType, typeLocation); 597 } 598 599 /// Parse an extended type. 600 /// 601 /// extended-type ::= (dialect-type | type-alias) 602 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 603 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 604 /// type-alias ::= `!` alias-name 605 /// 606 Type Parser::parseExtendedType() { 607 return parseExtendedSymbol<Type>( 608 *this, Token::exclamation_identifier, state.typeAliasDefinitions, 609 [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type { 610 // If we found a registered dialect, then ask it to parse the type. 611 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 612 return dialect->parseType(symbolData, loc); 613 614 // Otherwise, form a new opaque type. 615 return OpaqueType::getChecked( 616 Identifier::get(dialectName, state.context), symbolData, 617 state.context, loc); 618 }); 619 } 620 621 /// Parse a function type. 622 /// 623 /// function-type ::= type-list-parens `->` function-result-type 624 /// 625 Type Parser::parseFunctionType() { 626 assert(getToken().is(Token::l_paren)); 627 628 SmallVector<Type, 4> arguments, results; 629 if (parseTypeListParens(arguments) || 630 parseToken(Token::arrow, "expected '->' in function type") || 631 parseFunctionResultTypes(results)) 632 return nullptr; 633 634 return builder.getFunctionType(arguments, results); 635 } 636 637 /// Parse a memref type. 638 /// 639 /// memref-type ::= `memref` `<` dimension-list-ranked type 640 /// (`,` semi-affine-map-composition)? (`,` memory-space)? 
`>` 641 /// 642 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 643 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 644 /// 645 Type Parser::parseMemRefType() { 646 consumeToken(Token::kw_memref); 647 648 if (parseToken(Token::less, "expected '<' in memref type")) 649 return nullptr; 650 651 SmallVector<int64_t, 4> dimensions; 652 if (parseDimensionListRanked(dimensions)) 653 return nullptr; 654 655 // Parse the element type. 656 auto typeLoc = getToken().getLoc(); 657 auto elementType = parseType(); 658 if (!elementType) 659 return nullptr; 660 661 // Parse semi-affine-map-composition. 662 SmallVector<AffineMap, 2> affineMapComposition; 663 unsigned memorySpace = 0; 664 bool parsedMemorySpace = false; 665 666 auto parseElt = [&]() -> ParseResult { 667 if (getToken().is(Token::integer)) { 668 // Parse memory space. 669 if (parsedMemorySpace) 670 return emitError("multiple memory spaces specified in memref type"); 671 auto v = getToken().getUnsignedIntegerValue(); 672 if (!v.hasValue()) 673 return emitError("invalid memory space in memref type"); 674 memorySpace = v.getValue(); 675 consumeToken(Token::integer); 676 parsedMemorySpace = true; 677 } else { 678 // Parse affine map. 679 if (parsedMemorySpace) 680 return emitError("affine map after memory space in memref type"); 681 auto affineMap = parseAttribute(); 682 if (!affineMap) 683 return failure(); 684 685 // Verify that the parsed attribute is an affine map. 686 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 687 affineMapComposition.push_back(affineMapAttr.getValue()); 688 else 689 return emitError("expected affine map in memref type"); 690 } 691 return success(); 692 }; 693 694 // Parse a list of mappings and address space if present. 695 if (consumeIf(Token::comma)) { 696 // Parse comma separated list of affine maps, followed by memory space. 
697 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 698 /*allowEmptyList=*/false)) { 699 return nullptr; 700 } 701 } else { 702 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 703 return nullptr; 704 } 705 706 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 707 memorySpace, getEncodedSourceLocation(typeLoc)); 708 } 709 710 /// Parse any type except the function type. 711 /// 712 /// non-function-type ::= integer-type 713 /// | index-type 714 /// | float-type 715 /// | extended-type 716 /// | vector-type 717 /// | tensor-type 718 /// | memref-type 719 /// | complex-type 720 /// | tuple-type 721 /// | none-type 722 /// 723 /// index-type ::= `index` 724 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 725 /// none-type ::= `none` 726 /// 727 Type Parser::parseNonFunctionType() { 728 switch (getToken().getKind()) { 729 default: 730 return (emitError("expected non-function type"), nullptr); 731 case Token::kw_memref: 732 return parseMemRefType(); 733 case Token::kw_tensor: 734 return parseTensorType(); 735 case Token::kw_complex: 736 return parseComplexType(); 737 case Token::kw_tuple: 738 return parseTupleType(); 739 case Token::kw_vector: 740 return parseVectorType(); 741 // integer-type 742 case Token::inttype: { 743 auto width = getToken().getIntTypeBitwidth(); 744 if (!width.hasValue()) 745 return (emitError("invalid integer width"), nullptr); 746 auto loc = getEncodedSourceLocation(getToken().getLoc()); 747 consumeToken(Token::inttype); 748 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 749 } 750 751 // float-type 752 case Token::kw_bf16: 753 consumeToken(Token::kw_bf16); 754 return builder.getBF16Type(); 755 case Token::kw_f16: 756 consumeToken(Token::kw_f16); 757 return builder.getF16Type(); 758 case Token::kw_f32: 759 consumeToken(Token::kw_f32); 760 return builder.getF32Type(); 761 case Token::kw_f64: 762 consumeToken(Token::kw_f64); 763 return builder.getF64Type(); 
764 765 // index-type 766 case Token::kw_index: 767 consumeToken(Token::kw_index); 768 return builder.getIndexType(); 769 770 // none-type 771 case Token::kw_none: 772 consumeToken(Token::kw_none); 773 return builder.getNoneType(); 774 775 // extended type 776 case Token::exclamation_identifier: 777 return parseExtendedType(); 778 } 779 } 780 781 /// Parse a tensor type. 782 /// 783 /// tensor-type ::= `tensor` `<` dimension-list type `>` 784 /// dimension-list ::= dimension-list-ranked | `*x` 785 /// 786 Type Parser::parseTensorType() { 787 consumeToken(Token::kw_tensor); 788 789 if (parseToken(Token::less, "expected '<' in tensor type")) 790 return nullptr; 791 792 bool isUnranked; 793 SmallVector<int64_t, 4> dimensions; 794 795 if (consumeIf(Token::star)) { 796 // This is an unranked tensor type. 797 isUnranked = true; 798 799 if (parseXInDimensionList()) 800 return nullptr; 801 802 } else { 803 isUnranked = false; 804 if (parseDimensionListRanked(dimensions)) 805 return nullptr; 806 } 807 808 // Parse the element type. 809 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 810 auto elementType = parseType(); 811 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 812 return nullptr; 813 814 if (isUnranked) 815 return UnrankedTensorType::getChecked(elementType, typeLocation); 816 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 817 } 818 819 /// Parse a tuple type. 820 /// 821 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 822 /// 823 Type Parser::parseTupleType() { 824 consumeToken(Token::kw_tuple); 825 826 // Parse the '<'. 827 if (parseToken(Token::less, "expected '<' in tuple type")) 828 return nullptr; 829 830 // Check for an empty tuple by directly parsing '>'. 831 if (consumeIf(Token::greater)) 832 return TupleType::get(getContext()); 833 834 // Parse the element types and the '>'. 
835 SmallVector<Type, 4> types; 836 if (parseTypeListNoParens(types) || 837 parseToken(Token::greater, "expected '>' in tuple type")) 838 return nullptr; 839 840 return TupleType::get(types, getContext()); 841 } 842 843 /// Parse a vector type. 844 /// 845 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` 846 /// non-empty-static-dimension-list ::= decimal-literal `x` 847 /// static-dimension-list 848 /// static-dimension-list ::= (decimal-literal `x`)* 849 /// 850 VectorType Parser::parseVectorType() { 851 consumeToken(Token::kw_vector); 852 853 if (parseToken(Token::less, "expected '<' in vector type")) 854 return nullptr; 855 856 SmallVector<int64_t, 4> dimensions; 857 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 858 return nullptr; 859 if (dimensions.empty()) 860 return (emitError("expected dimension size in vector type"), nullptr); 861 862 // Parse the element type. 863 auto typeLoc = getToken().getLoc(); 864 auto elementType = parseType(); 865 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 866 return nullptr; 867 868 return VectorType::getChecked(dimensions, elementType, 869 getEncodedSourceLocation(typeLoc)); 870 } 871 872 /// Parse a dimension list of a tensor or memref type. This populates the 873 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 874 /// errors out on `?` otherwise. 
875 /// 876 /// dimension-list-ranked ::= (dimension `x`)* 877 /// dimension ::= `?` | decimal-literal 878 /// 879 /// When `allowDynamic` is not set, this is used to parse: 880 /// 881 /// static-dimension-list ::= (decimal-literal `x`)* 882 ParseResult 883 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 884 bool allowDynamic) { 885 while (getToken().isAny(Token::integer, Token::question)) { 886 if (consumeIf(Token::question)) { 887 if (!allowDynamic) 888 return emitError("expected static shape"); 889 dimensions.push_back(-1); 890 } else { 891 // Hexadecimal integer literals (starting with `0x`) are not allowed in 892 // aggregate type declarations. Therefore, `0xf32` should be processed as 893 // a sequence of separate elements `0`, `x`, `f32`. 894 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 895 // We can get here only if the token is an integer literal. Hexadecimal 896 // integer literals can only start with `0x` (`1x` wouldn't lex as a 897 // literal, just `1` would, at which point we don't get into this 898 // branch). 899 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 900 dimensions.push_back(0); 901 state.lex.resetPointer(getTokenSpelling().data() + 1); 902 consumeToken(); 903 } else { 904 // Make sure this integer value is in bound and valid. 905 auto dimension = getToken().getUnsignedIntegerValue(); 906 if (!dimension.hasValue()) 907 return emitError("invalid dimension"); 908 dimensions.push_back((int64_t)dimension.getValue()); 909 consumeToken(Token::integer); 910 } 911 } 912 913 // Make sure we have an 'x' or something like 'xbf32'. 914 if (parseXInDimensionList()) 915 return failure(); 916 } 917 918 return success(); 919 } 920 921 /// Parse an 'x' token in a dimension list, handling the case where the x is 922 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 923 /// token. 
924 ParseResult Parser::parseXInDimensionList() { 925 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 926 return emitError("expected 'x' in dimension list"); 927 928 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 929 if (getTokenSpelling().size() != 1) 930 state.lex.resetPointer(getTokenSpelling().data() + 1); 931 932 // Consume the 'x'. 933 consumeToken(Token::bare_identifier); 934 935 return success(); 936 } 937 938 //===----------------------------------------------------------------------===// 939 // Attribute parsing. 940 //===----------------------------------------------------------------------===// 941 942 /// Parse an arbitrary attribute. 943 /// 944 /// attribute-value ::= `unit` 945 /// | bool-literal 946 /// | integer-literal (`:` (index-type | integer-type))? 947 /// | float-literal (`:` float-type)? 948 /// | string-literal (`:` type)? 949 /// | type 950 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 951 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 952 /// | symbol-ref-id 953 /// | `dense` `<` attribute-value `>` `:` 954 /// (tensor-type | vector-type) 955 /// | `sparse` `<` attribute-value `,` attribute-value `>` 956 /// `:` (tensor-type | vector-type) 957 /// | `opaque` `<` dialect-namespace `,` hex-string-literal 958 /// `>` `:` (tensor-type | vector-type) 959 /// | extended-attribute 960 /// 961 Attribute Parser::parseAttribute(Type type) { 962 switch (getToken().getKind()) { 963 // Parse an AffineMap or IntegerSet attribute. 964 case Token::l_paren: { 965 // Try to parse an affine map or an integer set reference. 966 AffineMap map; 967 IntegerSet set; 968 if (parseAffineMapOrIntegerSetReference(map, set)) 969 return nullptr; 970 if (map) 971 return builder.getAffineMapAttr(map); 972 assert(set); 973 return builder.getIntegerSetAttr(set); 974 } 975 976 // Parse an array attribute. 
977 case Token::l_square: { 978 consumeToken(Token::l_square); 979 980 SmallVector<Attribute, 4> elements; 981 auto parseElt = [&]() -> ParseResult { 982 elements.push_back(parseAttribute()); 983 return elements.back() ? success() : failure(); 984 }; 985 986 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 987 return nullptr; 988 return builder.getArrayAttr(elements); 989 } 990 991 // Parse a boolean attribute. 992 case Token::kw_false: 993 consumeToken(Token::kw_false); 994 return builder.getBoolAttr(false); 995 case Token::kw_true: 996 consumeToken(Token::kw_true); 997 return builder.getBoolAttr(true); 998 999 // Parse a dense elements attribute. 1000 case Token::kw_dense: 1001 return parseDenseElementsAttr(); 1002 1003 // Parse a dictionary attribute. 1004 case Token::l_brace: { 1005 SmallVector<NamedAttribute, 4> elements; 1006 if (parseAttributeDict(elements)) 1007 return nullptr; 1008 return builder.getDictionaryAttr(elements); 1009 } 1010 1011 // Parse an extended attribute, i.e. alias or dialect attribute. 1012 case Token::hash_identifier: 1013 return parseExtendedAttr(type); 1014 1015 // Parse floating point and integer attributes. 1016 case Token::floatliteral: 1017 return parseFloatAttr(type, /*isNegative=*/false); 1018 case Token::integer: 1019 return parseDecOrHexAttr(type, /*isNegative=*/false); 1020 case Token::minus: { 1021 consumeToken(Token::minus); 1022 if (getToken().is(Token::integer)) 1023 return parseDecOrHexAttr(type, /*isNegative=*/true); 1024 if (getToken().is(Token::floatliteral)) 1025 return parseFloatAttr(type, /*isNegative=*/true); 1026 1027 return (emitError("expected constant integer or floating point value"), 1028 nullptr); 1029 } 1030 1031 // Parse a location attribute. 1032 case Token::kw_loc: { 1033 LocationAttr attr; 1034 return failed(parseLocation(attr)) ? Attribute() : attr; 1035 } 1036 1037 // Parse an opaque elements attribute. 
1038 case Token::kw_opaque: 1039 return parseOpaqueElementsAttr(); 1040 1041 // Parse a sparse elements attribute. 1042 case Token::kw_sparse: 1043 return parseSparseElementsAttr(); 1044 1045 // Parse a string attribute. 1046 case Token::string: { 1047 auto val = getToken().getStringValue(); 1048 consumeToken(Token::string); 1049 // Parse the optional trailing colon type if one wasn't explicitly provided. 1050 if (!type && consumeIf(Token::colon) && !(type = parseType())) 1051 return Attribute(); 1052 1053 return type ? StringAttr::get(val, type) 1054 : StringAttr::get(val, getContext()); 1055 } 1056 1057 // Parse a symbol reference attribute. 1058 case Token::at_identifier: { 1059 auto nameStr = getTokenSpelling(); 1060 consumeToken(Token::at_identifier); 1061 return builder.getSymbolRefAttr(nameStr.drop_front()); 1062 } 1063 1064 // Parse a 'unit' attribute. 1065 case Token::kw_unit: 1066 consumeToken(Token::kw_unit); 1067 return builder.getUnitAttr(); 1068 1069 default: 1070 // Parse a type attribute. 1071 if (Type type = parseType()) 1072 return builder.getTypeAttr(type); 1073 return nullptr; 1074 } 1075 } 1076 1077 /// Attribute dictionary. 1078 /// 1079 /// attribute-dict ::= `{` `}` 1080 /// | `{` attribute-entry (`,` attribute-entry)* `}` 1081 /// attribute-entry ::= bare-id `=` attribute-value 1082 /// 1083 ParseResult 1084 Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) { 1085 if (!consumeIf(Token::l_brace)) 1086 return failure(); 1087 1088 auto parseElt = [&]() -> ParseResult { 1089 // We allow keywords as attribute names. 1090 if (getToken().isNot(Token::bare_identifier, Token::inttype) && 1091 !getToken().isKeyword()) 1092 return emitError("expected attribute name"); 1093 Identifier nameId = builder.getIdentifier(getTokenSpelling()); 1094 consumeToken(); 1095 1096 // Try to parse the '=' for the attribute value. 1097 if (!consumeIf(Token::equal)) { 1098 // If there is no '=', we treat this as a unit attribute. 
1099 attributes.push_back({nameId, builder.getUnitAttr()}); 1100 return success(); 1101 } 1102 1103 auto attr = parseAttribute(); 1104 if (!attr) 1105 return failure(); 1106 1107 attributes.push_back({nameId, attr}); 1108 return success(); 1109 }; 1110 1111 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1112 return failure(); 1113 1114 return success(); 1115 } 1116 1117 /// Parse an extended attribute. 1118 /// 1119 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1120 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1121 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1122 /// attribute-alias ::= `#` alias-name 1123 /// 1124 Attribute Parser::parseExtendedAttr(Type type) { 1125 Attribute attr = parseExtendedSymbol<Attribute>( 1126 *this, Token::hash_identifier, state.attributeAliasDefinitions, 1127 [&](StringRef dialectName, StringRef symbolData, 1128 Location loc) -> Attribute { 1129 // Parse an optional trailing colon type. 1130 Type attrType = type; 1131 if (consumeIf(Token::colon) && !(attrType = parseType())) 1132 return Attribute(); 1133 1134 // If we found a registered dialect, then ask it to parse the attribute. 1135 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 1136 return dialect->parseAttribute(symbolData, attrType, loc); 1137 1138 // Otherwise, form a new opaque attribute. 1139 return OpaqueAttr::getChecked( 1140 Identifier::get(dialectName, state.context), symbolData, 1141 attrType ? attrType : NoneType::get(state.context), loc); 1142 }); 1143 1144 // Ensure that the attribute has the same type as requested. 1145 if (attr && type && attr.getType() != type) { 1146 emitError("attribute type different than expected: expected ") 1147 << type << ", but got " << attr.getType(); 1148 return nullptr; 1149 } 1150 return attr; 1151 } 1152 1153 /// Parse a float attribute. 
1154 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1155 auto val = getToken().getFloatingPointValue(); 1156 if (!val.hasValue()) 1157 return (emitError("floating point value too large for attribute"), nullptr); 1158 consumeToken(Token::floatliteral); 1159 if (!type) { 1160 // Default to F64 when no type is specified. 1161 if (!consumeIf(Token::colon)) 1162 type = builder.getF64Type(); 1163 else if (!(type = parseType())) 1164 return nullptr; 1165 } 1166 if (!type.isa<FloatType>()) 1167 return (emitError("floating point value not valid for specified type"), 1168 nullptr); 1169 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1170 } 1171 1172 /// Construct a float attribute bitwise equivalent to the integer literal. 1173 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1174 uint64_t value) { 1175 int width = type.getIntOrFloatBitWidth(); 1176 APInt apInt(width, value); 1177 if (apInt != value) { 1178 p->emitError("hexadecimal float constant out of range for type"); 1179 return nullptr; 1180 } 1181 APFloat apFloat(type.getFloatSemantics(), apInt); 1182 return p->builder.getFloatAttr(type, apFloat); 1183 } 1184 1185 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1186 /// or a float attribute. 1187 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1188 auto val = getToken().getUInt64IntegerValue(); 1189 if (!val.hasValue()) 1190 return (emitError("integer constant out of range for attribute"), nullptr); 1191 1192 // Remember if the literal is hexadecimal. 1193 StringRef spelling = getToken().getSpelling(); 1194 auto loc = state.curToken.getLoc(); 1195 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1196 1197 consumeToken(Token::integer); 1198 if (!type) { 1199 // Default to i64 if not type is specified. 
1200 if (!consumeIf(Token::colon)) 1201 type = builder.getIntegerType(64); 1202 else if (!(type = parseType())) 1203 return nullptr; 1204 } 1205 1206 if (auto floatType = type.dyn_cast<FloatType>()) { 1207 // TODO(zinenko): Update once hex format for bfloat16 is supported. 1208 if (type.isBF16()) 1209 return emitError(loc, 1210 "hexadecimal float literal not supported for bfloat16"), 1211 nullptr; 1212 if (isNegative) 1213 return emitError( 1214 loc, 1215 "hexadecimal float literal should not have a leading minus"), 1216 nullptr; 1217 if (!isHex) { 1218 emitError(loc, "unexpected decimal integer literal for a float attribute") 1219 .attachNote() 1220 << "add a trailing dot to make the literal a float"; 1221 return nullptr; 1222 } 1223 1224 // Construct a float attribute bitwise equivalent to the integer literal. 1225 return buildHexadecimalFloatLiteral(this, floatType, *val); 1226 } 1227 1228 if (!type.isIntOrIndex()) 1229 return emitError(loc, "integer literal not valid for specified type"), 1230 nullptr; 1231 1232 // Parse the integer literal. 1233 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1234 APInt apInt(width, *val, isNegative); 1235 if (apInt != *val) 1236 return emitError(loc, "integer constant out of range for attribute"), 1237 nullptr; 1238 1239 // Otherwise construct an integer attribute. 1240 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1241 return emitError(loc, "integer constant out of range for attribute"), 1242 nullptr; 1243 1244 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1245 } 1246 1247 /// Parse an opaque elements attribute. 
1248 Attribute Parser::parseOpaqueElementsAttr() { 1249 consumeToken(Token::kw_opaque); 1250 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1251 return nullptr; 1252 1253 if (getToken().isNot(Token::string)) 1254 return (emitError("expected dialect namespace"), nullptr); 1255 1256 auto name = getToken().getStringValue(); 1257 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1258 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1259 // attribute. Otherwise, it can't be roundtripped without having the dialect 1260 // registered. 1261 if (!dialect) 1262 return (emitError("no registered dialect with namespace '" + name + "'"), 1263 nullptr); 1264 1265 consumeToken(Token::string); 1266 if (parseToken(Token::comma, "expected ','")) 1267 return nullptr; 1268 1269 if (getToken().getKind() != Token::string) 1270 return (emitError("opaque string should start with '0x'"), nullptr); 1271 1272 auto val = getToken().getStringValue(); 1273 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1274 return (emitError("opaque string should start with '0x'"), nullptr); 1275 1276 val = val.substr(2); 1277 if (!llvm::all_of(val, llvm::isHexDigit)) 1278 return (emitError("opaque string only contains hex digits"), nullptr); 1279 1280 consumeToken(Token::string); 1281 if (parseToken(Token::greater, "expected '>'") || 1282 parseToken(Token::colon, "expected ':'")) 1283 return nullptr; 1284 1285 auto type = parseElementsLiteralType(); 1286 if (!type) 1287 return nullptr; 1288 1289 return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val)); 1290 } 1291 1292 namespace { 1293 class TensorLiteralParser { 1294 public: 1295 TensorLiteralParser(Parser &p) : p(p) {} 1296 1297 ParseResult parse() { 1298 if (p.getToken().is(Token::l_square)) 1299 return parseList(shape); 1300 return parseElement(); 1301 } 1302 1303 /// Build a dense attribute instance with the parsed elements and the given 1304 /// shaped type. 
1305 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1306 1307 ArrayRef<int64_t> getShape() const { return shape; } 1308 1309 private: 1310 enum class ElementKind { Boolean, Integer, Float }; 1311 1312 /// Return a string to represent the given element kind. 1313 const char *getElementKindStr(ElementKind kind) { 1314 switch (kind) { 1315 case ElementKind::Boolean: 1316 return "'boolean'"; 1317 case ElementKind::Integer: 1318 return "'integer'"; 1319 case ElementKind::Float: 1320 return "'float'"; 1321 } 1322 llvm_unreachable("unknown element kind"); 1323 } 1324 1325 /// Build a Dense Integer attribute for the given type. 1326 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1327 IntegerType eltTy); 1328 1329 /// Build a Dense Float attribute for the given type. 1330 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1331 FloatType eltTy); 1332 1333 /// Parse a single element, returning failure if it isn't a valid element 1334 /// literal. For example: 1335 /// parseElement(1) -> Success, 1 1336 /// parseElement([1]) -> Failure 1337 ParseResult parseElement(); 1338 1339 /// Parse a list of either lists or elements, returning the dimensions of the 1340 /// parsed sub-tensors in dims. For example: 1341 /// parseList([1, 2, 3]) -> Success, [3] 1342 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1343 /// parseList([[1, 2], 3]) -> Failure 1344 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1345 ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims); 1346 1347 Parser &p; 1348 1349 /// The shape inferred from the parsed elements. 1350 SmallVector<int64_t, 4> shape; 1351 1352 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1353 std::vector<std::pair<bool, Token>> storage; 1354 1355 /// A flag that indicates the type of elements that have been parsed. 
1356 llvm::Optional<ElementKind> knownEltKind; 1357 }; 1358 } // namespace 1359 1360 /// Build a dense attribute instance with the parsed elements and the given 1361 /// shaped type. 1362 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1363 ShapedType type) { 1364 // Check that the parsed storage size has the same number of elements to the 1365 // type, or is a known splat. 1366 if (!shape.empty() && getShape() != type.getShape()) { 1367 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1368 << "]) does not match type ([" << type.getShape() << "])"; 1369 return nullptr; 1370 } 1371 1372 // If the type is an integer, build a set of APInt values from the storage 1373 // with the correct bitwidth. 1374 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1375 return getIntAttr(loc, type, intTy); 1376 1377 // Otherwise, this must be a floating point type. 1378 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1379 if (!floatTy) { 1380 p.emitError(loc) << "expected floating-point or integer element type, got " 1381 << type.getElementType(); 1382 return nullptr; 1383 } 1384 return getFloatAttr(loc, type, floatTy); 1385 } 1386 1387 /// Build a Dense Integer attribute for the given type. 1388 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1389 ShapedType type, 1390 IntegerType eltTy) { 1391 std::vector<APInt> intElements; 1392 intElements.reserve(storage.size()); 1393 for (const auto &signAndToken : storage) { 1394 bool isNegative = signAndToken.first; 1395 const Token &token = signAndToken.second; 1396 1397 // Check to see if floating point values were parsed. 
1398 if (token.is(Token::floatliteral)) { 1399 p.emitError() << "expected integer elements, but parsed floating-point"; 1400 return nullptr; 1401 } 1402 1403 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1404 "unexpected token type"); 1405 if (token.isAny(Token::kw_true, Token::kw_false)) { 1406 if (!eltTy.isInteger(1)) 1407 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1408 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1409 /*isSigned=*/false); 1410 intElements.push_back(apInt); 1411 continue; 1412 } 1413 1414 // Create APInt values for each element with the correct bitwidth. 1415 auto val = token.getUInt64IntegerValue(); 1416 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1417 : (int64_t)val.getValue() < 0)) { 1418 p.emitError(token.getLoc(), 1419 "integer constant out of range for attribute"); 1420 return nullptr; 1421 } 1422 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1423 if (apInt != val.getValue()) 1424 return (p.emitError("integer constant out of range for type"), nullptr); 1425 intElements.push_back(isNegative ? -apInt : apInt); 1426 } 1427 1428 return DenseElementsAttr::get(type, intElements); 1429 } 1430 1431 /// Build a Dense Float attribute for the given type. 1432 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1433 ShapedType type, 1434 FloatType eltTy) { 1435 std::vector<Attribute> floatValues; 1436 floatValues.reserve(storage.size()); 1437 for (const auto &signAndToken : storage) { 1438 bool isNegative = signAndToken.first; 1439 const Token &token = signAndToken.second; 1440 1441 // Handle hexadecimal float literals. 
1442 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1443 if (isNegative) { 1444 p.emitError(token.getLoc()) 1445 << "hexadecimal float literal should not have a leading minus"; 1446 return nullptr; 1447 } 1448 auto val = token.getUInt64IntegerValue(); 1449 if (!val.hasValue()) { 1450 p.emitError("hexadecimal float constant out of range for attribute"); 1451 return nullptr; 1452 } 1453 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1454 if (!attr) 1455 return nullptr; 1456 floatValues.push_back(attr); 1457 continue; 1458 } 1459 1460 // Check to see if any decimal integers or booleans were parsed. 1461 if (!token.is(Token::floatliteral)) { 1462 p.emitError() << "expected floating-point elements, but parsed integer"; 1463 return nullptr; 1464 } 1465 1466 // Build the float values from tokens. 1467 auto val = token.getFloatingPointValue(); 1468 if (!val.hasValue()) { 1469 p.emitError("floating point value too large for attribute"); 1470 return nullptr; 1471 } 1472 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 1473 } 1474 1475 return DenseElementsAttr::get(type, floatValues); 1476 } 1477 1478 ParseResult TensorLiteralParser::parseElement() { 1479 switch (p.getToken().getKind()) { 1480 // Parse a boolean element. 1481 case Token::kw_true: 1482 case Token::kw_false: 1483 case Token::floatliteral: 1484 case Token::integer: 1485 storage.emplace_back(/*isNegative=*/false, p.getToken()); 1486 p.consumeToken(); 1487 break; 1488 1489 // Parse a signed integer or a negative floating-point element. 
1490 case Token::minus: 1491 p.consumeToken(Token::minus); 1492 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 1493 return p.emitError("expected integer or floating point literal"); 1494 storage.emplace_back(/*isNegative=*/true, p.getToken()); 1495 p.consumeToken(); 1496 break; 1497 1498 default: 1499 return p.emitError("expected element literal of primitive type"); 1500 } 1501 1502 return success(); 1503 } 1504 1505 /// Parse a list of either lists or elements, returning the dimensions of the 1506 /// parsed sub-tensors in dims. For example: 1507 /// parseList([1, 2, 3]) -> Success, [3] 1508 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1509 /// parseList([[1, 2], 3]) -> Failure 1510 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1511 ParseResult 1512 TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) { 1513 p.consumeToken(Token::l_square); 1514 1515 auto checkDims = 1516 [&](const llvm::SmallVectorImpl<int64_t> &prevDims, 1517 const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult { 1518 if (prevDims == newDims) 1519 return success(); 1520 return p.emitError("tensor literal is invalid; ranks are not consistent " 1521 "between elements"); 1522 }; 1523 1524 bool first = true; 1525 llvm::SmallVector<int64_t, 4> newDims; 1526 unsigned size = 0; 1527 auto parseCommaSeparatedList = [&]() -> ParseResult { 1528 llvm::SmallVector<int64_t, 4> thisDims; 1529 if (p.getToken().getKind() == Token::l_square) { 1530 if (parseList(thisDims)) 1531 return failure(); 1532 } else if (parseElement()) { 1533 return failure(); 1534 } 1535 ++size; 1536 if (!first) 1537 return checkDims(newDims, thisDims); 1538 newDims = thisDims; 1539 first = false; 1540 return success(); 1541 }; 1542 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 1543 return failure(); 1544 1545 // Return the sublists' dimensions with 'size' prepended. 
1546 dims.clear(); 1547 dims.push_back(size); 1548 dims.append(newDims.begin(), newDims.end()); 1549 return success(); 1550 } 1551 1552 /// Parse a dense elements attribute. 1553 Attribute Parser::parseDenseElementsAttr() { 1554 consumeToken(Token::kw_dense); 1555 if (parseToken(Token::less, "expected '<' after 'dense'")) 1556 return nullptr; 1557 1558 // Parse the literal data. 1559 TensorLiteralParser literalParser(*this); 1560 if (literalParser.parse()) 1561 return nullptr; 1562 1563 if (parseToken(Token::greater, "expected '>'") || 1564 parseToken(Token::colon, "expected ':'")) 1565 return nullptr; 1566 1567 auto typeLoc = getToken().getLoc(); 1568 auto type = parseElementsLiteralType(); 1569 if (!type) 1570 return nullptr; 1571 return literalParser.getAttr(typeLoc, type); 1572 } 1573 1574 /// Shaped type for elements attribute. 1575 /// 1576 /// elements-literal-type ::= vector-type | ranked-tensor-type 1577 /// 1578 /// This method also checks the type has static shape. 1579 ShapedType Parser::parseElementsLiteralType() { 1580 auto type = parseType(); 1581 if (!type) 1582 return nullptr; 1583 1584 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 1585 emitError("elements literal must be a ranked tensor or vector type"); 1586 return nullptr; 1587 } 1588 1589 auto sType = type.cast<ShapedType>(); 1590 if (!sType.hasStaticShape()) 1591 return (emitError("elements literal type must have static shape"), nullptr); 1592 1593 return sType; 1594 } 1595 1596 /// Parse a sparse elements attribute. 1597 Attribute Parser::parseSparseElementsAttr() { 1598 consumeToken(Token::kw_sparse); 1599 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 1600 return nullptr; 1601 1602 /// Parse indices 1603 auto indicesLoc = getToken().getLoc(); 1604 TensorLiteralParser indiceParser(*this); 1605 if (indiceParser.parse()) 1606 return nullptr; 1607 1608 if (parseToken(Token::comma, "expected ','")) 1609 return nullptr; 1610 1611 /// Parse values. 
1612 auto valuesLoc = getToken().getLoc(); 1613 TensorLiteralParser valuesParser(*this); 1614 if (valuesParser.parse()) 1615 return nullptr; 1616 1617 if (parseToken(Token::greater, "expected '>'") || 1618 parseToken(Token::colon, "expected ':'")) 1619 return nullptr; 1620 1621 auto type = parseElementsLiteralType(); 1622 if (!type) 1623 return nullptr; 1624 1625 // If the indices are a splat, i.e. the literal parser parsed an element and 1626 // not a list, we set the shape explicitly. The indices are represented by a 1627 // 2-dimensional shape where the second dimension is the rank of the type. 1628 // Given that the parsed indices is a splat, we know that we only have one 1629 // indice and thus one for the first dimension. 1630 auto indiceEltType = builder.getIntegerType(64); 1631 ShapedType indicesType; 1632 if (indiceParser.getShape().empty()) { 1633 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1634 } else { 1635 // Otherwise, set the shape to the one parsed by the literal parser. 1636 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1637 } 1638 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1639 1640 // If the values are a splat, set the shape explicitly based on the number of 1641 // indices. The number of indices is encoded in the first dimension of the 1642 // indice shape type. 1643 auto valuesEltType = type.getElementType(); 1644 ShapedType valuesType = 1645 valuesParser.getShape().empty() 1646 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1647 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1648 auto values = valuesParser.getAttr(valuesLoc, valuesType); 1649 1650 /// Sanity check. 
1651 if (valuesType.getRank() != 1) 1652 return (emitError("expected 1-d tensor for values"), nullptr); 1653 1654 auto sameShape = (indicesType.getRank() == 1) || 1655 (type.getRank() == indicesType.getDimSize(1)); 1656 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 1657 if (!sameShape || !sameElementNum) { 1658 emitError() << "expected shape ([" << type.getShape() 1659 << "]); inferred shape of indices literal ([" 1660 << indicesType.getShape() 1661 << "]); inferred shape of values literal ([" 1662 << valuesType.getShape() << "])"; 1663 return nullptr; 1664 } 1665 1666 // Build the sparse elements attribute by the indices and values. 1667 return SparseElementsAttr::get(type, indices, values); 1668 } 1669 1670 //===----------------------------------------------------------------------===// 1671 // Location parsing. 1672 //===----------------------------------------------------------------------===// 1673 1674 /// Parse a location. 1675 /// 1676 /// location ::= `loc` inline-location 1677 /// inline-location ::= '(' location-inst ')' 1678 /// 1679 ParseResult Parser::parseLocation(LocationAttr &loc) { 1680 // Check for 'loc' identifier. 1681 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 1682 return emitError(); 1683 1684 // Parse the inline-location. 1685 if (parseToken(Token::l_paren, "expected '(' in inline location") || 1686 parseLocationInstance(loc) || 1687 parseToken(Token::r_paren, "expected ')' in inline location")) 1688 return failure(); 1689 return success(); 1690 } 1691 1692 /// Specific location instances. 
1693 /// 1694 /// location-inst ::= filelinecol-location | 1695 /// name-location | 1696 /// callsite-location | 1697 /// fused-location | 1698 /// unknown-location 1699 /// filelinecol-location ::= string-literal ':' integer-literal 1700 /// ':' integer-literal 1701 /// name-location ::= string-literal 1702 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 1703 /// fused-location ::= fused ('<' attribute-value '>')? 1704 /// '[' location-inst (location-inst ',')* ']' 1705 /// unknown-location ::= 'unknown' 1706 /// 1707 ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) { 1708 consumeToken(Token::bare_identifier); 1709 1710 // Parse the '('. 1711 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 1712 return failure(); 1713 1714 // Parse the callee location. 1715 LocationAttr calleeLoc; 1716 if (parseLocationInstance(calleeLoc)) 1717 return failure(); 1718 1719 // Parse the 'at'. 1720 if (getToken().isNot(Token::bare_identifier) || 1721 getToken().getSpelling() != "at") 1722 return emitError("expected 'at' in callsite location"); 1723 consumeToken(Token::bare_identifier); 1724 1725 // Parse the caller location. 1726 LocationAttr callerLoc; 1727 if (parseLocationInstance(callerLoc)) 1728 return failure(); 1729 1730 // Parse the ')'. 1731 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 1732 return failure(); 1733 1734 // Return the callsite location. 1735 loc = CallSiteLoc::get(calleeLoc, callerLoc); 1736 return success(); 1737 } 1738 1739 ParseResult Parser::parseFusedLocation(LocationAttr &loc) { 1740 consumeToken(Token::bare_identifier); 1741 1742 // Try to parse the optional metadata. 1743 Attribute metadata; 1744 if (consumeIf(Token::less)) { 1745 metadata = parseAttribute(); 1746 if (!metadata) 1747 return emitError("expected valid attribute metadata"); 1748 // Parse the '>' token. 
1749 if (parseToken(Token::greater, 1750 "expected '>' after fused location metadata")) 1751 return failure(); 1752 } 1753 1754 llvm::SmallVector<Location, 4> locations; 1755 auto parseElt = [&] { 1756 LocationAttr newLoc; 1757 if (parseLocationInstance(newLoc)) 1758 return failure(); 1759 locations.push_back(newLoc); 1760 return success(); 1761 }; 1762 1763 if (parseToken(Token::l_square, "expected '[' in fused location") || 1764 parseCommaSeparatedList(parseElt) || 1765 parseToken(Token::r_square, "expected ']' in fused location")) 1766 return failure(); 1767 1768 // Return the fused location. 1769 loc = FusedLoc::get(locations, metadata, getContext()); 1770 return success(); 1771 } 1772 1773 ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) { 1774 auto *ctx = getContext(); 1775 auto str = getToken().getStringValue(); 1776 consumeToken(Token::string); 1777 1778 // If the next token is ':' this is a filelinecol location. 1779 if (consumeIf(Token::colon)) { 1780 // Parse the line number. 1781 if (getToken().isNot(Token::integer)) 1782 return emitError("expected integer line number in FileLineColLoc"); 1783 auto line = getToken().getUnsignedIntegerValue(); 1784 if (!line.hasValue()) 1785 return emitError("expected integer line number in FileLineColLoc"); 1786 consumeToken(Token::integer); 1787 1788 // Parse the ':'. 1789 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 1790 return failure(); 1791 1792 // Parse the column number. 1793 if (getToken().isNot(Token::integer)) 1794 return emitError("expected integer column number in FileLineColLoc"); 1795 auto column = getToken().getUnsignedIntegerValue(); 1796 if (!column.hasValue()) 1797 return emitError("expected integer column number in FileLineColLoc"); 1798 consumeToken(Token::integer); 1799 1800 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 1801 return success(); 1802 } 1803 1804 // Otherwise, this is a NameLoc. 1805 1806 // Check for a child location. 
1807 if (consumeIf(Token::l_paren)) { 1808 auto childSourceLoc = getToken().getLoc(); 1809 1810 // Parse the child location. 1811 LocationAttr childLoc; 1812 if (parseLocationInstance(childLoc)) 1813 return failure(); 1814 1815 // The child must not be another NameLoc. 1816 if (childLoc.isa<NameLoc>()) 1817 return emitError(childSourceLoc, 1818 "child of NameLoc cannot be another NameLoc"); 1819 loc = NameLoc::get(Identifier::get(str, ctx), childLoc); 1820 1821 // Parse the closing ')'. 1822 if (parseToken(Token::r_paren, 1823 "expected ')' after child location of NameLoc")) 1824 return failure(); 1825 } else { 1826 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 1827 } 1828 1829 return success(); 1830 } 1831 1832 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 1833 // Handle either name or filelinecol locations. 1834 if (getToken().is(Token::string)) 1835 return parseNameOrFileLineColLocation(loc); 1836 1837 // Bare tokens required for other cases. 1838 if (!getToken().is(Token::bare_identifier)) 1839 return emitError("expected location instance"); 1840 1841 // Check for the 'callsite' signifying a callsite location. 1842 if (getToken().getSpelling() == "callsite") 1843 return parseCallSiteLocation(loc); 1844 1845 // If the token is 'fused', then this is a fused location. 1846 if (getToken().getSpelling() == "fused") 1847 return parseFusedLocation(loc); 1848 1849 // Check for a 'unknown' for an unknown location. 1850 if (getToken().getSpelling() == "unknown") { 1851 consumeToken(Token::bare_identifier); 1852 loc = UnknownLoc::get(getContext()); 1853 return success(); 1854 } 1855 1856 return emitError("expected location instance"); 1857 } 1858 1859 //===----------------------------------------------------------------------===// 1860 // Affine parsing. 1861 //===----------------------------------------------------------------------===// 1862 1863 /// Lower precedence ops (all at the same precedence level). 
LNoOp is false in 1864 /// the boolean sense. 1865 enum AffineLowPrecOp { 1866 /// Null value. 1867 LNoOp, 1868 Add, 1869 Sub 1870 }; 1871 1872 /// Higher precedence ops - all at the same precedence level. HNoOp is false 1873 /// in the boolean sense. 1874 enum AffineHighPrecOp { 1875 /// Null value. 1876 HNoOp, 1877 Mul, 1878 FloorDiv, 1879 CeilDiv, 1880 Mod 1881 }; 1882 1883 namespace { 1884 /// This is a specialized parser for affine structures (affine maps, affine 1885 /// expressions, and integer sets), maintaining the state transient to their 1886 /// bodies. 1887 class AffineParser : public Parser { 1888 public: 1889 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 1890 llvm::function_ref<ParseResult(bool)> parseElement = nullptr) 1891 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 1892 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 1893 1894 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 1895 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 1896 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 1897 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 1898 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 1899 unsigned &numDims); 1900 1901 private: 1902 // Binary affine op parsing. 1903 AffineLowPrecOp consumeIfLowPrecOp(); 1904 AffineHighPrecOp consumeIfHighPrecOp(); 1905 1906 // Identifier lists for polyhedral structures. 
1907 ParseResult parseDimIdList(unsigned &numDims); 1908 ParseResult parseSymbolIdList(unsigned &numSymbols); 1909 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 1910 unsigned &numSymbols); 1911 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 1912 1913 AffineExpr parseAffineExpr(); 1914 AffineExpr parseParentheticalExpr(); 1915 AffineExpr parseNegateExpression(AffineExpr lhs); 1916 AffineExpr parseIntegerExpr(); 1917 AffineExpr parseBareIdExpr(); 1918 AffineExpr parseSSAIdExpr(bool isSymbol); 1919 AffineExpr parseSymbolSSAIdExpr(); 1920 1921 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 1922 AffineExpr rhs, SMLoc opLoc); 1923 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 1924 AffineExpr rhs); 1925 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 1926 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 1927 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 1928 SMLoc llhsOpLoc); 1929 AffineExpr parseAffineConstraint(bool *isEq); 1930 1931 private: 1932 bool allowParsingSSAIds; 1933 llvm::function_ref<ParseResult(bool)> parseElement; 1934 unsigned numDimOperands; 1935 unsigned numSymbolOperands; 1936 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 1937 }; 1938 } // end anonymous namespace 1939 1940 /// Create an affine binary high precedence op expression (mul's, div's, mod). 1941 /// opLoc is the location of the op token to be used to report errors 1942 /// for non-conforming expressions. 1943 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 1944 AffineExpr lhs, AffineExpr rhs, 1945 SMLoc opLoc) { 1946 // TODO: make the error location info accurate. 
1947 switch (op) { 1948 case Mul: 1949 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 1950 emitError(opLoc, "non-affine expression: at least one of the multiply " 1951 "operands has to be either a constant or symbolic"); 1952 return nullptr; 1953 } 1954 return lhs * rhs; 1955 case FloorDiv: 1956 if (!rhs.isSymbolicOrConstant()) { 1957 emitError(opLoc, "non-affine expression: right operand of floordiv " 1958 "has to be either a constant or symbolic"); 1959 return nullptr; 1960 } 1961 return lhs.floorDiv(rhs); 1962 case CeilDiv: 1963 if (!rhs.isSymbolicOrConstant()) { 1964 emitError(opLoc, "non-affine expression: right operand of ceildiv " 1965 "has to be either a constant or symbolic"); 1966 return nullptr; 1967 } 1968 return lhs.ceilDiv(rhs); 1969 case Mod: 1970 if (!rhs.isSymbolicOrConstant()) { 1971 emitError(opLoc, "non-affine expression: right operand of mod " 1972 "has to be either a constant or symbolic"); 1973 return nullptr; 1974 } 1975 return lhs % rhs; 1976 case HNoOp: 1977 llvm_unreachable("can't create affine expression for null high prec op"); 1978 return nullptr; 1979 } 1980 llvm_unreachable("Unknown AffineHighPrecOp"); 1981 } 1982 1983 /// Create an affine binary low precedence op expression (add, sub). 1984 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 1985 AffineExpr lhs, AffineExpr rhs) { 1986 switch (op) { 1987 case AffineLowPrecOp::Add: 1988 return lhs + rhs; 1989 case AffineLowPrecOp::Sub: 1990 return lhs - rhs; 1991 case AffineLowPrecOp::LNoOp: 1992 llvm_unreachable("can't create affine expression for null low prec op"); 1993 return nullptr; 1994 } 1995 llvm_unreachable("Unknown AffineLowPrecOp"); 1996 } 1997 1998 /// Consume this token if it is a lower precedence affine op (there are only 1999 /// two precedence levels). 
2000 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 2001 switch (getToken().getKind()) { 2002 case Token::plus: 2003 consumeToken(Token::plus); 2004 return AffineLowPrecOp::Add; 2005 case Token::minus: 2006 consumeToken(Token::minus); 2007 return AffineLowPrecOp::Sub; 2008 default: 2009 return AffineLowPrecOp::LNoOp; 2010 } 2011 } 2012 2013 /// Consume this token if it is a higher precedence affine op (there are only 2014 /// two precedence levels) 2015 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 2016 switch (getToken().getKind()) { 2017 case Token::star: 2018 consumeToken(Token::star); 2019 return Mul; 2020 case Token::kw_floordiv: 2021 consumeToken(Token::kw_floordiv); 2022 return FloorDiv; 2023 case Token::kw_ceildiv: 2024 consumeToken(Token::kw_ceildiv); 2025 return CeilDiv; 2026 case Token::kw_mod: 2027 consumeToken(Token::kw_mod); 2028 return Mod; 2029 default: 2030 return HNoOp; 2031 } 2032 } 2033 2034 /// Parse a high precedence op expression list: mul, div, and mod are high 2035 /// precedence binary ops, i.e., parse a 2036 /// expr_1 op_1 expr_2 op_2 ... expr_n 2037 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2038 /// All affine binary ops are left associative. 2039 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2040 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2041 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2042 /// report an error for non-conforming expressions. 2043 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2044 AffineHighPrecOp llhsOp, 2045 SMLoc llhsOpLoc) { 2046 AffineExpr lhs = parseAffineOperandExpr(llhs); 2047 if (!lhs) 2048 return nullptr; 2049 2050 // Found an LHS. Parse the remaining expression. 
2051 auto opLoc = getToken().getLoc(); 2052 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2053 if (llhs) { 2054 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2055 if (!expr) 2056 return nullptr; 2057 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2058 } 2059 // No LLHS, get RHS 2060 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2061 } 2062 2063 // This is the last operand in this expression. 2064 if (llhs) 2065 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2066 2067 // No llhs, 'lhs' itself is the expression. 2068 return lhs; 2069 } 2070 2071 /// Parse an affine expression inside parentheses. 2072 /// 2073 /// affine-expr ::= `(` affine-expr `)` 2074 AffineExpr AffineParser::parseParentheticalExpr() { 2075 if (parseToken(Token::l_paren, "expected '('")) 2076 return nullptr; 2077 if (getToken().is(Token::r_paren)) 2078 return (emitError("no expression inside parentheses"), nullptr); 2079 2080 auto expr = parseAffineExpr(); 2081 if (!expr) 2082 return nullptr; 2083 if (parseToken(Token::r_paren, "expected ')'")) 2084 return nullptr; 2085 2086 return expr; 2087 } 2088 2089 /// Parse the negation expression. 2090 /// 2091 /// affine-expr ::= `-` affine-expr 2092 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2093 if (parseToken(Token::minus, "expected '-'")) 2094 return nullptr; 2095 2096 AffineExpr operand = parseAffineOperandExpr(lhs); 2097 // Since negation has the highest precedence of all ops (including high 2098 // precedence ops) but lower than parentheses, we are only going to use 2099 // parseAffineOperandExpr instead of parseAffineExpr here. 2100 if (!operand) 2101 // Extra error message although parseAffineOperandExpr would have 2102 // complained. Leads to a better diagnostic. 2103 return (emitError("missing operand of negation"), nullptr); 2104 return (-1) * operand; 2105 } 2106 2107 /// Parse a bare id that may appear in an affine expression. 
2108 /// 2109 /// affine-expr ::= bare-id 2110 AffineExpr AffineParser::parseBareIdExpr() { 2111 if (getToken().isNot(Token::bare_identifier)) 2112 return (emitError("expected bare identifier"), nullptr); 2113 2114 StringRef sRef = getTokenSpelling(); 2115 for (auto entry : dimsAndSymbols) { 2116 if (entry.first == sRef) { 2117 consumeToken(Token::bare_identifier); 2118 return entry.second; 2119 } 2120 } 2121 2122 return (emitError("use of undeclared identifier"), nullptr); 2123 } 2124 2125 /// Parse an SSA id which may appear in an affine expression. 2126 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2127 if (!allowParsingSSAIds) 2128 return (emitError("unexpected ssa identifier"), nullptr); 2129 if (getToken().isNot(Token::percent_identifier)) 2130 return (emitError("expected ssa identifier"), nullptr); 2131 auto name = getTokenSpelling(); 2132 // Check if we already parsed this SSA id. 2133 for (auto entry : dimsAndSymbols) { 2134 if (entry.first == name) { 2135 consumeToken(Token::percent_identifier); 2136 return entry.second; 2137 } 2138 } 2139 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2140 if (parseElement(isSymbol)) 2141 return (emitError("failed to parse ssa identifier"), nullptr); 2142 auto idExpr = isSymbol 2143 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2144 : getAffineDimExpr(numDimOperands++, getContext()); 2145 dimsAndSymbols.push_back({name, idExpr}); 2146 return idExpr; 2147 } 2148 2149 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2150 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2151 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2152 return nullptr; 2153 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2154 if (!symbolExpr) 2155 return nullptr; 2156 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2157 return nullptr; 2158 return symbolExpr; 2159 } 2160 2161 /// Parse a positive integral constant appearing in an affine expression. 2162 /// 2163 /// affine-expr ::= integer-literal 2164 AffineExpr AffineParser::parseIntegerExpr() { 2165 auto val = getToken().getUInt64IntegerValue(); 2166 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2167 return (emitError("constant too large for index"), nullptr); 2168 2169 consumeToken(Token::integer); 2170 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2171 } 2172 2173 /// Parses an expression that can be a valid operand of an affine expression. 2174 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2175 /// operator, the rhs of which is being parsed. This is used to determine 2176 /// whether an error should be emitted for a missing right operand. 2177 // Eg: for an expression without parentheses (like i + j + k + l), each 2178 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2179 // operand expression, it's an op expression and will be parsed via 2180 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2181 // -l are valid operands that will be parsed by this function. 
2182 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2183 switch (getToken().getKind()) { 2184 case Token::bare_identifier: 2185 return parseBareIdExpr(); 2186 case Token::kw_symbol: 2187 return parseSymbolSSAIdExpr(); 2188 case Token::percent_identifier: 2189 return parseSSAIdExpr(/*isSymbol=*/false); 2190 case Token::integer: 2191 return parseIntegerExpr(); 2192 case Token::l_paren: 2193 return parseParentheticalExpr(); 2194 case Token::minus: 2195 return parseNegateExpression(lhs); 2196 case Token::kw_ceildiv: 2197 case Token::kw_floordiv: 2198 case Token::kw_mod: 2199 case Token::plus: 2200 case Token::star: 2201 if (lhs) 2202 emitError("missing right operand of binary operator"); 2203 else 2204 emitError("missing left operand of binary operator"); 2205 return nullptr; 2206 default: 2207 if (lhs) 2208 emitError("missing right operand of binary operator"); 2209 else 2210 emitError("expected affine expression"); 2211 return nullptr; 2212 } 2213 } 2214 2215 /// Parse affine expressions that are bare-id's, integer constants, 2216 /// parenthetical affine expressions, and affine op expressions that are a 2217 /// composition of those. 2218 /// 2219 /// All binary op's associate from left to right. 2220 /// 2221 /// {add, sub} have lower precedence than {mul, div, and mod}. 2222 /// 2223 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2224 /// ceildiv, and mod are at the same higher precedence level. Negation has 2225 /// higher precedence than any binary op. 2226 /// 2227 /// llhs: the affine expression appearing on the left of the one being parsed. 2228 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2229 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2230 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2231 /// associativity. 
2232 /// 2233 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2234 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2235 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2236 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2237 AffineLowPrecOp llhsOp) { 2238 AffineExpr lhs; 2239 if (!(lhs = parseAffineOperandExpr(llhs))) 2240 return nullptr; 2241 2242 // Found an LHS. Deal with the ops. 2243 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2244 if (llhs) { 2245 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2246 return parseAffineLowPrecOpExpr(sum, lOp); 2247 } 2248 // No LLHS, get RHS and form the expression. 2249 return parseAffineLowPrecOpExpr(lhs, lOp); 2250 } 2251 auto opLoc = getToken().getLoc(); 2252 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2253 // We have a higher precedence op here. Get the rhs operand for the llhs 2254 // through parseAffineHighPrecOpExpr. 2255 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2256 if (!highRes) 2257 return nullptr; 2258 2259 // If llhs is null, the product forms the first operand of the yet to be 2260 // found expression. If non-null, the op to associate with llhs is llhsOp. 2261 AffineExpr expr = 2262 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2263 2264 // Recurse for subsequent low prec op's after the affine high prec op 2265 // expression. 2266 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2267 return parseAffineLowPrecOpExpr(expr, nextOp); 2268 return expr; 2269 } 2270 // Last operand in the expression list. 2271 if (llhs) 2272 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2273 // No llhs, 'lhs' itself is the expression. 2274 return lhs; 2275 } 2276 2277 /// Parse an affine expression. 
2278 /// affine-expr ::= `(` affine-expr `)` 2279 /// | `-` affine-expr 2280 /// | affine-expr `+` affine-expr 2281 /// | affine-expr `-` affine-expr 2282 /// | affine-expr `*` affine-expr 2283 /// | affine-expr `floordiv` affine-expr 2284 /// | affine-expr `ceildiv` affine-expr 2285 /// | affine-expr `mod` affine-expr 2286 /// | bare-id 2287 /// | integer-literal 2288 /// 2289 /// Additional conditions are checked depending on the production. For eg., 2290 /// one of the operands for `*` has to be either constant/symbolic; the second 2291 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2292 AffineExpr AffineParser::parseAffineExpr() { 2293 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2294 } 2295 2296 /// Parse a dim or symbol from the lists appearing before the actual 2297 /// expressions of the affine map. Update our state to store the 2298 /// dimensional/symbolic identifier. 2299 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2300 if (getToken().isNot(Token::bare_identifier)) 2301 return emitError("expected bare identifier"); 2302 2303 auto name = getTokenSpelling(); 2304 for (auto entry : dimsAndSymbols) { 2305 if (entry.first == name) 2306 return emitError("redefinition of identifier '" + name + "'"); 2307 } 2308 consumeToken(Token::bare_identifier); 2309 2310 dimsAndSymbols.push_back({name, idExpr}); 2311 return success(); 2312 } 2313 2314 /// Parse the list of dimensional identifiers to an affine map. 
2315 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2316 if (parseToken(Token::l_paren, 2317 "expected '(' at start of dimensional identifiers list")) { 2318 return failure(); 2319 } 2320 2321 auto parseElt = [&]() -> ParseResult { 2322 auto dimension = getAffineDimExpr(numDims++, getContext()); 2323 return parseIdentifierDefinition(dimension); 2324 }; 2325 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2326 } 2327 2328 /// Parse the list of symbolic identifiers to an affine map. 2329 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2330 consumeToken(Token::l_square); 2331 auto parseElt = [&]() -> ParseResult { 2332 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2333 return parseIdentifierDefinition(symbol); 2334 }; 2335 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2336 } 2337 2338 /// Parse the list of symbolic identifiers to an affine map. 2339 ParseResult 2340 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2341 unsigned &numSymbols) { 2342 if (parseDimIdList(numDims)) { 2343 return failure(); 2344 } 2345 if (!getToken().is(Token::l_square)) { 2346 numSymbols = 0; 2347 return success(); 2348 } 2349 return parseSymbolIdList(numSymbols); 2350 } 2351 2352 /// Parses an ambiguous affine map or integer set definition inline. 2353 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2354 IntegerSet &set) { 2355 unsigned numDims = 0, numSymbols = 0; 2356 2357 // List of dimensional and optional symbol identifiers. 2358 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2359 return failure(); 2360 } 2361 2362 // This is needed for parsing attributes as we wouldn't know whether we would 2363 // be parsing an integer set attribute or an affine map attribute. 
2364 bool isArrow = getToken().is(Token::arrow); 2365 bool isColon = getToken().is(Token::colon); 2366 if (!isArrow && !isColon) { 2367 return emitError("expected '->' or ':'"); 2368 } else if (isArrow) { 2369 parseToken(Token::arrow, "expected '->' or '['"); 2370 map = parseAffineMapRange(numDims, numSymbols); 2371 return map ? success() : failure(); 2372 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2373 return failure(); 2374 } 2375 2376 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2377 return success(); 2378 2379 return failure(); 2380 } 2381 2382 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2383 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2384 if (parseToken(Token::l_square, "expected '['")) 2385 return failure(); 2386 2387 SmallVector<AffineExpr, 4> exprs; 2388 auto parseElt = [&]() -> ParseResult { 2389 auto elt = parseAffineExpr(); 2390 exprs.push_back(elt); 2391 return elt ? success() : failure(); 2392 }; 2393 2394 // Parse a multi-dimensional affine expression (a comma-separated list of 2395 // 1-d affine expressions); the list cannot be empty. Grammar: 2396 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2397 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2398 /*allowEmptyList=*/true)) 2399 return failure(); 2400 // Parsed a valid affine map. 2401 if (exprs.empty()) 2402 map = AffineMap(); 2403 else 2404 map = builder.getAffineMap(numDimOperands, 2405 dimsAndSymbols.size() - numDimOperands, exprs); 2406 return success(); 2407 } 2408 2409 /// Parse the range and sizes affine map definition inline. 
2410 /// 2411 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2412 /// 2413 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2414 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2415 unsigned numSymbols) { 2416 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2417 2418 SmallVector<AffineExpr, 4> exprs; 2419 auto parseElt = [&]() -> ParseResult { 2420 auto elt = parseAffineExpr(); 2421 ParseResult res = elt ? success() : failure(); 2422 exprs.push_back(elt); 2423 return res; 2424 }; 2425 2426 // Parse a multi-dimensional affine expression (a comma-separated list of 2427 // 1-d affine expressions); the list cannot be empty. Grammar: 2428 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2429 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) 2430 return AffineMap(); 2431 2432 // Parsed a valid affine map. 2433 return builder.getAffineMap(numDims, numSymbols, exprs); 2434 } 2435 2436 /// Parse an affine constraint. 2437 /// affine-constraint ::= affine-expr `>=` `0` 2438 /// | affine-expr `==` `0` 2439 /// 2440 /// isEq is set to true if the parsed constraint is an equality, false if it 2441 /// is an inequality (greater than or equal). 
2442 /// 2443 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2444 AffineExpr expr = parseAffineExpr(); 2445 if (!expr) 2446 return nullptr; 2447 2448 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2449 getToken().is(Token::integer)) { 2450 auto dim = getToken().getUnsignedIntegerValue(); 2451 if (dim.hasValue() && dim.getValue() == 0) { 2452 consumeToken(Token::integer); 2453 *isEq = false; 2454 return expr; 2455 } 2456 return (emitError("expected '0' after '>='"), nullptr); 2457 } 2458 2459 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 2460 getToken().is(Token::integer)) { 2461 auto dim = getToken().getUnsignedIntegerValue(); 2462 if (dim.hasValue() && dim.getValue() == 0) { 2463 consumeToken(Token::integer); 2464 *isEq = true; 2465 return expr; 2466 } 2467 return (emitError("expected '0' after '=='"), nullptr); 2468 } 2469 2470 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 2471 nullptr); 2472 } 2473 2474 /// Parse the constraints that are part of an integer set definition. 2475 /// integer-set-inline 2476 /// ::= dim-and-symbol-id-lists `:` 2477 /// '(' affine-constraint-conjunction? ')' 2478 /// affine-constraint-conjunction ::= affine-constraint (`,` 2479 /// affine-constraint)* 2480 /// 2481 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 2482 unsigned numSymbols) { 2483 if (parseToken(Token::l_paren, 2484 "expected '(' at start of integer set constraint list")) 2485 return IntegerSet(); 2486 2487 SmallVector<AffineExpr, 4> constraints; 2488 SmallVector<bool, 4> isEqs; 2489 auto parseElt = [&]() -> ParseResult { 2490 bool isEq; 2491 auto elt = parseAffineConstraint(&isEq); 2492 ParseResult res = elt ? success() : failure(); 2493 if (elt) { 2494 constraints.push_back(elt); 2495 isEqs.push_back(isEq); 2496 } 2497 return res; 2498 }; 2499 2500 // Parse a list of affine constraints (comma-separated). 
2501 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2502 return IntegerSet(); 2503 2504 // If no constraints were parsed, then treat this as a degenerate 'true' case. 2505 if (constraints.empty()) { 2506 /* 0 == 0 */ 2507 auto zero = getAffineConstantExpr(0, getContext()); 2508 return builder.getIntegerSet(numDims, numSymbols, zero, true); 2509 } 2510 2511 // Parsed a valid integer set. 2512 return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); 2513 } 2514 2515 /// Parse an ambiguous reference to either and affine map or an integer set. 2516 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 2517 IntegerSet &set) { 2518 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 2519 } 2520 2521 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 2522 /// parse SSA value uses encountered while parsing affine expressions. 2523 ParseResult Parser::parseAffineMapOfSSAIds( 2524 AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) { 2525 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 2526 .parseAffineMapOfSSAIds(map); 2527 } 2528 2529 //===----------------------------------------------------------------------===// 2530 // OperationParser 2531 //===----------------------------------------------------------------------===// 2532 2533 namespace { 2534 /// This class provides support for parsing operations and regions of 2535 /// operations. 2536 class OperationParser : public Parser { 2537 public: 2538 OperationParser(ParserState &state, ModuleOp moduleOp) 2539 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 2540 } 2541 2542 ~OperationParser(); 2543 2544 /// After parsing is finished, this function must be called to see if there 2545 /// are any remaining issues. 
2546 ParseResult finalize(); 2547 2548 //===--------------------------------------------------------------------===// 2549 // SSA Value Handling 2550 //===--------------------------------------------------------------------===// 2551 2552 /// This represents a use of an SSA value in the program. The first two 2553 /// entries in the tuple are the name and result number of a reference. The 2554 /// third is the location of the reference, which is used in case this ends 2555 /// up being a use of an undefined value. 2556 struct SSAUseInfo { 2557 StringRef name; // Value name, e.g. %42 or %abc 2558 unsigned number; // Number, specified with #12 2559 SMLoc loc; // Location of first definition or use. 2560 }; 2561 2562 /// Push a new SSA name scope to the parser. 2563 void pushSSANameScope(bool isIsolated); 2564 2565 /// Pop the last SSA name scope from the parser. 2566 ParseResult popSSANameScope(); 2567 2568 /// Register a definition of a value with the symbol table. 2569 ParseResult addDefinition(SSAUseInfo useInfo, Value *value); 2570 2571 /// Parse an optional list of SSA uses into 'results'. 2572 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 2573 2574 /// Parse a single SSA use into 'result'. 2575 ParseResult parseSSAUse(SSAUseInfo &result); 2576 2577 /// Given a reference to an SSA value and its type, return a reference. This 2578 /// returns null on failure. 2579 Value *resolveSSAUse(SSAUseInfo useInfo, Type type); 2580 2581 ParseResult parseSSADefOrUseAndType( 2582 const std::function<ParseResult(SSAUseInfo, Type)> &action); 2583 2584 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results); 2585 2586 /// Return the location of the value identified by its name and number if it 2587 /// has been already reference. 
2588 llvm::Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 2589 auto &values = isolatedNameScopes.back().values; 2590 if (!values.count(name) || number >= values[name].size()) 2591 return {}; 2592 if (values[name][number].first) 2593 return values[name][number].second; 2594 return {}; 2595 } 2596 2597 //===--------------------------------------------------------------------===// 2598 // Operation Parsing 2599 //===--------------------------------------------------------------------===// 2600 2601 /// Parse an operation instance. 2602 ParseResult parseOperation(); 2603 2604 /// Parse a single operation successor and its operand list. 2605 ParseResult parseSuccessorAndUseList(Block *&dest, 2606 SmallVectorImpl<Value *> &operands); 2607 2608 /// Parse a comma-separated list of operation successors in brackets. 2609 ParseResult 2610 parseSuccessors(SmallVectorImpl<Block *> &destinations, 2611 SmallVectorImpl<SmallVector<Value *, 4>> &operands); 2612 2613 /// Parse an operation instance that is in the generic form. 2614 Operation *parseGenericOperation(); 2615 2616 /// Parse an operation instance that is in the generic form and insert it at 2617 /// the provided insertion point. 2618 Operation *parseGenericOperation(Block *insertBlock, 2619 Block::iterator insertPt); 2620 2621 /// Parse an operation instance that is in the op-defined custom form. 2622 Operation *parseCustomOperation(); 2623 2624 //===--------------------------------------------------------------------===// 2625 // Region Parsing 2626 //===--------------------------------------------------------------------===// 2627 2628 /// Parse a region into 'region' with the provided entry block arguments. 2629 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 2630 /// isolated from those above. 
2631 ParseResult parseRegion(Region ®ion, 2632 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 2633 bool isIsolatedNameScope = false); 2634 2635 /// Parse a region body into 'region'. 2636 ParseResult parseRegionBody(Region ®ion); 2637 2638 //===--------------------------------------------------------------------===// 2639 // Block Parsing 2640 //===--------------------------------------------------------------------===// 2641 2642 /// Parse a new block into 'block'. 2643 ParseResult parseBlock(Block *&block); 2644 2645 /// Parse a list of operations into 'block'. 2646 ParseResult parseBlockBody(Block *block); 2647 2648 /// Parse a (possibly empty) list of block arguments. 2649 ParseResult 2650 parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, 2651 Block *owner); 2652 2653 /// Get the block with the specified name, creating it if it doesn't 2654 /// already exist. The location specified is the point of use, which allows 2655 /// us to diagnose references to blocks that are not defined precisely. 2656 Block *getBlockNamed(StringRef name, SMLoc loc); 2657 2658 /// Define the block with the specified name. Returns the Block* or nullptr in 2659 /// the case of redefinition. 2660 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 2661 2662 private: 2663 /// Returns the info for a block at the current scope for the given name. 2664 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 2665 return blocksByName.back()[name]; 2666 } 2667 2668 /// Insert a new forward reference to the given block. 2669 void insertForwardRef(Block *block, SMLoc loc) { 2670 forwardRef.back().try_emplace(block, loc); 2671 } 2672 2673 /// Erase any forward reference to the given block. 2674 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 2675 2676 /// Record that a definition was added at the current scope. 2677 void recordDefinition(StringRef def); 2678 2679 /// Get the value entry for the given SSA name. 
2680 SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name); 2681 2682 /// Create a forward reference placeholder value with the given location and 2683 /// result type. 2684 Value *createForwardRefPlaceholder(SMLoc loc, Type type); 2685 2686 /// Return true if this is a forward reference. 2687 bool isForwardRefPlaceholder(Value *value) { 2688 return forwardRefPlaceholders.count(value); 2689 } 2690 2691 /// This struct represents an isolated SSA name scope. This scope may contain 2692 /// other nested non-isolated scopes. These scopes are used for operations 2693 /// that are known to be isolated to allow for reusing names within their 2694 /// regions, even if those names are used above. 2695 struct IsolatedSSANameScope { 2696 /// Record that a definition was added at the current scope. 2697 void recordDefinition(StringRef def) { 2698 definitionsPerScope.back().insert(def); 2699 } 2700 2701 /// Push a nested name scope. 2702 void pushSSANameScope() { definitionsPerScope.push_back({}); } 2703 2704 /// Pop a nested name scope. 2705 void popSSANameScope() { 2706 for (auto &def : definitionsPerScope.pop_back_val()) 2707 values.erase(def.getKey()); 2708 } 2709 2710 /// This keeps track of all of the SSA values we are tracking for each name 2711 /// scope, indexed by their name. This has one entry per result number. 2712 llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values; 2713 2714 /// This keeps track of all of the values defined by a specific name scope. 2715 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 2716 }; 2717 2718 /// A list of isolated name scopes. 2719 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 2720 2721 /// This keeps track of the block names as well as the location of the first 2722 /// reference for each nested name scope. This is used to diagnose invalid 2723 /// block references and memoize them. 
2724 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 2725 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 2726 2727 /// These are all of the placeholders we've made along with the location of 2728 /// their first reference, to allow checking for use of undefined values. 2729 DenseMap<Value *, SMLoc> forwardRefPlaceholders; 2730 2731 /// The builder used when creating parsed operation instances. 2732 OpBuilder opBuilder; 2733 2734 /// The top level module operation. 2735 ModuleOp moduleOp; 2736 }; 2737 } // end anonymous namespace 2738 2739 OperationParser::~OperationParser() { 2740 for (auto &fwd : forwardRefPlaceholders) { 2741 // Drop all uses of undefined forward declared reference and destroy 2742 // defining operation. 2743 fwd.first->dropAllUses(); 2744 fwd.first->getDefiningOp()->destroy(); 2745 } 2746 } 2747 2748 /// After parsing is finished, this function must be called to see if there are 2749 /// any remaining issues. 2750 ParseResult OperationParser::finalize() { 2751 // Check for any forward references that are left. If we find any, error 2752 // out. 2753 if (!forwardRefPlaceholders.empty()) { 2754 SmallVector<std::pair<const char *, Value *>, 4> errors; 2755 // Iteration over the map isn't deterministic, so sort by source location. 
2756 for (auto entry : forwardRefPlaceholders) 2757 errors.push_back({entry.second.getPointer(), entry.first}); 2758 llvm::array_pod_sort(errors.begin(), errors.end()); 2759 2760 for (auto entry : errors) { 2761 auto loc = SMLoc::getFromPointer(entry.first); 2762 emitError(loc, "use of undeclared SSA value name"); 2763 } 2764 return failure(); 2765 } 2766 2767 return success(); 2768 } 2769 2770 //===----------------------------------------------------------------------===// 2771 // SSA Value Handling 2772 //===----------------------------------------------------------------------===// 2773 2774 void OperationParser::pushSSANameScope(bool isIsolated) { 2775 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 2776 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 2777 2778 // Push back a new name definition scope. 2779 if (isIsolated) 2780 isolatedNameScopes.push_back({}); 2781 isolatedNameScopes.back().pushSSANameScope(); 2782 } 2783 2784 ParseResult OperationParser::popSSANameScope() { 2785 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 2786 2787 // Verify that all referenced blocks were defined. 2788 if (!forwardRefInCurrentScope.empty()) { 2789 SmallVector<std::pair<const char *, Block *>, 4> errors; 2790 // Iteration over the map isn't deterministic, so sort by source location. 2791 for (auto entry : forwardRefInCurrentScope) { 2792 errors.push_back({entry.second.getPointer(), entry.first}); 2793 // Add this block to the top-level region to allow for automatic cleanup. 2794 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 2795 } 2796 llvm::array_pod_sort(errors.begin(), errors.end()); 2797 2798 for (auto entry : errors) { 2799 auto loc = SMLoc::getFromPointer(entry.first); 2800 emitError(loc, "reference to an undefined block"); 2801 } 2802 return failure(); 2803 } 2804 2805 // Pop the next nested namescope. If there is only one internal namescope, 2806 // just pop the isolated scope. 
2807 auto ¤tNameScope = isolatedNameScopes.back(); 2808 if (currentNameScope.definitionsPerScope.size() == 1) 2809 isolatedNameScopes.pop_back(); 2810 else 2811 currentNameScope.popSSANameScope(); 2812 2813 blocksByName.pop_back(); 2814 return success(); 2815 } 2816 2817 /// Register a definition of a value with the symbol table. 2818 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { 2819 auto &entries = getSSAValueEntry(useInfo.name); 2820 2821 // Make sure there is a slot for this value. 2822 if (entries.size() <= useInfo.number) 2823 entries.resize(useInfo.number + 1); 2824 2825 // If we already have an entry for this, check to see if it was a definition 2826 // or a forward reference. 2827 if (auto *existing = entries[useInfo.number].first) { 2828 if (!isForwardRefPlaceholder(existing)) { 2829 return emitError(useInfo.loc) 2830 .append("redefinition of SSA value '", useInfo.name, "'") 2831 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 2832 .append("previously defined here"); 2833 } 2834 2835 // If it was a forward reference, update everything that used it to use 2836 // the actual definition instead, delete the forward ref, and remove it 2837 // from our set of forward references we track. 2838 existing->replaceAllUsesWith(value); 2839 existing->getDefiningOp()->destroy(); 2840 forwardRefPlaceholders.erase(existing); 2841 } 2842 2843 /// Record this definition for the current scope. 2844 entries[useInfo.number] = {value, useInfo.loc}; 2845 recordDefinition(useInfo.name); 2846 return success(); 2847 } 2848 2849 /// Parse a (possibly empty) list of SSA operands. 2850 /// 2851 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 2852 /// ssa-use-list-opt ::= ssa-use-list? 
2853 /// 2854 ParseResult 2855 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 2856 if (getToken().isNot(Token::percent_identifier)) 2857 return success(); 2858 return parseCommaSeparatedList([&]() -> ParseResult { 2859 SSAUseInfo result; 2860 if (parseSSAUse(result)) 2861 return failure(); 2862 results.push_back(result); 2863 return success(); 2864 }); 2865 } 2866 2867 /// Parse a SSA operand for an operation. 2868 /// 2869 /// ssa-use ::= ssa-id 2870 /// 2871 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 2872 result.name = getTokenSpelling(); 2873 result.number = 0; 2874 result.loc = getToken().getLoc(); 2875 if (parseToken(Token::percent_identifier, "expected SSA operand")) 2876 return failure(); 2877 2878 // If we have an attribute ID, it is a result number. 2879 if (getToken().is(Token::hash_identifier)) { 2880 if (auto value = getToken().getHashIdentifierNumber()) 2881 result.number = value.getValue(); 2882 else 2883 return emitError("invalid SSA value result number"); 2884 consumeToken(Token::hash_identifier); 2885 } 2886 2887 return success(); 2888 } 2889 2890 /// Given an unbound reference to an SSA value and its type, return the value 2891 /// it specifies. This returns null on failure. 2892 Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 2893 auto &entries = getSSAValueEntry(useInfo.name); 2894 2895 // If we have already seen a value of this name, return it. 2896 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 2897 auto *result = entries[useInfo.number].first; 2898 // Check that the type matches the other uses. 
2899 if (result->getType() == type) 2900 return result; 2901 2902 emitError(useInfo.loc, "use of value '") 2903 .append(useInfo.name, 2904 "' expects different type than prior uses: ", type, " vs ", 2905 result->getType()) 2906 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 2907 .append("prior use here"); 2908 return nullptr; 2909 } 2910 2911 // Make sure we have enough slots for this. 2912 if (entries.size() <= useInfo.number) 2913 entries.resize(useInfo.number + 1); 2914 2915 // If the value has already been defined and this is an overly large result 2916 // number, diagnose that. 2917 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 2918 return (emitError(useInfo.loc, "reference to invalid result number"), 2919 nullptr); 2920 2921 // Otherwise, this is a forward reference. Create a placeholder and remember 2922 // that we did so. 2923 auto *result = createForwardRefPlaceholder(useInfo.loc, type); 2924 entries[useInfo.number].first = result; 2925 entries[useInfo.number].second = useInfo.loc; 2926 return result; 2927 } 2928 2929 /// Parse an SSA use with an associated type. 2930 /// 2931 /// ssa-use-and-type ::= ssa-use `:` type 2932 ParseResult OperationParser::parseSSADefOrUseAndType( 2933 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 2934 SSAUseInfo useInfo; 2935 if (parseSSAUse(useInfo) || 2936 parseToken(Token::colon, "expected ':' and type for SSA operand")) 2937 return failure(); 2938 2939 auto type = parseType(); 2940 if (!type) 2941 return failure(); 2942 2943 return action(useInfo, type); 2944 } 2945 2946 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 2947 /// followed by a type list. 
2948 /// 2949 /// ssa-use-and-type-list 2950 /// ::= ssa-use-list ':' type-list-no-parens 2951 /// 2952 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 2953 SmallVectorImpl<Value *> &results) { 2954 SmallVector<SSAUseInfo, 4> valueIDs; 2955 if (parseOptionalSSAUseList(valueIDs)) 2956 return failure(); 2957 2958 // If there were no operands, then there is no colon or type lists. 2959 if (valueIDs.empty()) 2960 return success(); 2961 2962 SmallVector<Type, 4> types; 2963 if (parseToken(Token::colon, "expected ':' in operand list") || 2964 parseTypeListNoParens(types)) 2965 return failure(); 2966 2967 if (valueIDs.size() != types.size()) 2968 return emitError("expected ") 2969 << valueIDs.size() << " types to match operand list"; 2970 2971 results.reserve(valueIDs.size()); 2972 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 2973 if (auto *value = resolveSSAUse(valueIDs[i], types[i])) 2974 results.push_back(value); 2975 else 2976 return failure(); 2977 } 2978 2979 return success(); 2980 } 2981 2982 /// Record that a definition was added at the current scope. 2983 void OperationParser::recordDefinition(StringRef def) { 2984 isolatedNameScopes.back().recordDefinition(def); 2985 } 2986 2987 /// Get the value entry for the given SSA name. 2988 SmallVectorImpl<std::pair<Value *, SMLoc>> & 2989 OperationParser::getSSAValueEntry(StringRef name) { 2990 return isolatedNameScopes.back().values[name]; 2991 } 2992 2993 /// Create and remember a new placeholder for a forward reference. 2994 Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 2995 // Forward references are always created as operations, because we just need 2996 // something with a def/use chain. 2997 // 2998 // We create these placeholders as having an empty name, which we know 2999 // cannot be created through normal user input, allowing us to distinguish 3000 // them. 
3001 auto name = OperationName("placeholder", getContext()); 3002 auto *op = Operation::create( 3003 getEncodedSourceLocation(loc), name, /*operands=*/{}, type, 3004 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 3005 /*resizableOperandList=*/false); 3006 forwardRefPlaceholders[op->getResult(0)] = loc; 3007 return op->getResult(0); 3008 } 3009 3010 //===----------------------------------------------------------------------===// 3011 // Operation Parsing 3012 //===----------------------------------------------------------------------===// 3013 3014 /// Parse an operation. 3015 /// 3016 /// operation ::= 3017 /// operation-result? string '(' ssa-use-list? ')' attribute-dict? 3018 /// `:` function-type trailing-location? 3019 /// operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=` 3020 /// 3021 ParseResult OperationParser::parseOperation() { 3022 auto loc = getToken().getLoc(); 3023 SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs; 3024 size_t numExpectedResults; 3025 if (getToken().is(Token::percent_identifier)) { 3026 // Parse the first result id. 3027 resultIDs.emplace_back(getTokenSpelling(), loc); 3028 consumeToken(Token::percent_identifier); 3029 3030 // If the next token is a ':', we parse the expected result count. 3031 if (consumeIf(Token::colon)) { 3032 // Check that the next token is an integer. 3033 if (!getToken().is(Token::integer)) 3034 return emitError("expected integer number of results"); 3035 3036 // Check that number of results is > 0. 3037 auto val = getToken().getUInt64IntegerValue(); 3038 if (!val.hasValue() || val.getValue() < 1) 3039 return emitError("expected named operation to have atleast 1 result"); 3040 consumeToken(Token::integer); 3041 numExpectedResults = *val; 3042 } else { 3043 // Otherwise, this is a comma separated list of result ids. 3044 if (consumeIf(Token::comma)) { 3045 auto parseNextResult = [&]() -> ParseResult { 3046 // Parse the next result id. 
3047 if (!getToken().is(Token::percent_identifier)) 3048 return emitError("expected valid ssa identifier"); 3049 3050 resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc()); 3051 consumeToken(Token::percent_identifier); 3052 return success(); 3053 }; 3054 3055 if (parseCommaSeparatedList(parseNextResult)) 3056 return failure(); 3057 } 3058 numExpectedResults = resultIDs.size(); 3059 } 3060 3061 if (parseToken(Token::equal, "expected '=' after SSA name")) 3062 return failure(); 3063 } 3064 3065 Operation *op; 3066 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3067 op = parseCustomOperation(); 3068 else if (getToken().is(Token::string)) 3069 op = parseGenericOperation(); 3070 else 3071 return emitError("expected operation name in quotes"); 3072 3073 // If parsing of the basic operation failed, then this whole thing fails. 3074 if (!op) 3075 return failure(); 3076 3077 // If the operation had a name, register it. 3078 if (!resultIDs.empty()) { 3079 if (op->getNumResults() == 0) 3080 return emitError(loc, "cannot name an operation with no results"); 3081 if (numExpectedResults != op->getNumResults()) 3082 return emitError(loc, "operation defines ") 3083 << op->getNumResults() << " results but was provided " 3084 << numExpectedResults << " to bind"; 3085 3086 // If the number of result names matches the number of operation results, we 3087 // can directly use the provided names. 3088 if (resultIDs.size() == op->getNumResults()) { 3089 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3090 if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second}, 3091 op->getResult(i))) 3092 return failure(); 3093 } else { 3094 // Otherwise, we use the same name for all results. 3095 StringRef name = resultIDs.front().first; 3096 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3097 if (addDefinition({name, i, loc}, op->getResult(i))) 3098 return failure(); 3099 } 3100 } 3101 3102 // Try to parse the optional trailing location. 
3103 if (parseOptionalTrailingLocation(op)) 3104 return failure(); 3105 3106 return success(); 3107 } 3108 3109 /// Parse a single operation successor and its operand list. 3110 /// 3111 /// successor ::= block-id branch-use-list? 3112 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3113 /// 3114 ParseResult 3115 OperationParser::parseSuccessorAndUseList(Block *&dest, 3116 SmallVectorImpl<Value *> &operands) { 3117 // Verify branch is identifier and get the matching block. 3118 if (!getToken().is(Token::caret_identifier)) 3119 return emitError("expected block name"); 3120 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3121 consumeToken(); 3122 3123 // Handle optional arguments. 3124 if (consumeIf(Token::l_paren) && 3125 (parseOptionalSSAUseAndTypeList(operands) || 3126 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3127 return failure(); 3128 } 3129 3130 return success(); 3131 } 3132 3133 /// Parse a comma-separated list of operation successors in brackets. 3134 /// 3135 /// successor-list ::= `[` successor (`,` successor )* `]` 3136 /// 3137 ParseResult OperationParser::parseSuccessors( 3138 SmallVectorImpl<Block *> &destinations, 3139 SmallVectorImpl<SmallVector<Value *, 4>> &operands) { 3140 if (parseToken(Token::l_square, "expected '['")) 3141 return failure(); 3142 3143 auto parseElt = [this, &destinations, &operands]() { 3144 Block *dest; 3145 SmallVector<Value *, 4> destOperands; 3146 auto res = parseSuccessorAndUseList(dest, destOperands); 3147 destinations.push_back(dest); 3148 operands.push_back(destOperands); 3149 return res; 3150 }; 3151 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3152 /*allowEmptyList=*/false); 3153 } 3154 3155 namespace { 3156 // RAII-style guard for cleaning up the regions in the operation state before 3157 // deleting them. 
Within the parser, regions may get deleted if parsing failed, 3158 // and other errors may be present, in praticular undominated uses. This makes 3159 // sure such uses are deleted. 3160 struct CleanupOpStateRegions { 3161 ~CleanupOpStateRegions() { 3162 SmallVector<Region *, 4> regionsToClean; 3163 regionsToClean.reserve(state.regions.size()); 3164 for (auto ®ion : state.regions) 3165 if (region) 3166 for (auto &block : *region) 3167 block.dropAllDefinedValueUses(); 3168 } 3169 OperationState &state; 3170 }; 3171 } // namespace 3172 3173 Operation *OperationParser::parseGenericOperation() { 3174 // Get location information for the operation. 3175 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3176 3177 auto name = getToken().getStringValue(); 3178 if (name.empty()) 3179 return (emitError("empty operation name is invalid"), nullptr); 3180 if (name.find('\0') != StringRef::npos) 3181 return (emitError("null character not allowed in operation name"), nullptr); 3182 3183 consumeToken(Token::string); 3184 3185 OperationState result(srcLocation, name); 3186 3187 // Generic operations have a resizable operation list. 3188 result.setOperandListToResizable(); 3189 3190 // Parse the operand list. 3191 SmallVector<SSAUseInfo, 8> operandInfos; 3192 3193 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3194 parseOptionalSSAUseList(operandInfos) || 3195 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3196 return nullptr; 3197 } 3198 3199 // Parse the successor list but don't add successors to the result yet to 3200 // avoid messing up with the argument order. 3201 SmallVector<Block *, 2> successors; 3202 SmallVector<SmallVector<Value *, 4>, 2> successorOperands; 3203 if (getToken().is(Token::l_square)) { 3204 // Check if the operation is a known terminator. 
3205 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3206 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3207 return emitError("successors in non-terminator"), nullptr; 3208 if (parseSuccessors(successors, successorOperands)) 3209 return nullptr; 3210 } 3211 3212 // Parse the region list. 3213 CleanupOpStateRegions guard{result}; 3214 if (consumeIf(Token::l_paren)) { 3215 do { 3216 // Create temporary regions with the top level region as parent. 3217 result.regions.emplace_back(new Region(moduleOp)); 3218 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3219 return nullptr; 3220 } while (consumeIf(Token::comma)); 3221 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3222 return nullptr; 3223 } 3224 3225 if (getToken().is(Token::l_brace)) { 3226 if (parseAttributeDict(result.attributes)) 3227 return nullptr; 3228 } 3229 3230 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3231 return nullptr; 3232 3233 auto typeLoc = getToken().getLoc(); 3234 auto type = parseType(); 3235 if (!type) 3236 return nullptr; 3237 auto fnType = type.dyn_cast<FunctionType>(); 3238 if (!fnType) 3239 return (emitError(typeLoc, "expected function type"), nullptr); 3240 3241 result.addTypes(fnType.getResults()); 3242 3243 // Check that we have the right number of types for the operands. 3244 auto operandTypes = fnType.getInputs(); 3245 if (operandTypes.size() != operandInfos.size()) { 3246 auto plural = "s"[operandInfos.size() == 1]; 3247 return (emitError(typeLoc, "expected ") 3248 << operandInfos.size() << " operand type" << plural 3249 << " but had " << operandTypes.size(), 3250 nullptr); 3251 } 3252 3253 // Resolve all of the operands. 
3254 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3255 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3256 if (!result.operands.back()) 3257 return nullptr; 3258 } 3259 3260 // Add the sucessors, and their operands after the proper operands. 3261 for (const auto &succ : llvm::zip(successors, successorOperands)) { 3262 Block *successor = std::get<0>(succ); 3263 const SmallVector<Value *, 4> &operands = std::get<1>(succ); 3264 result.addSuccessor(successor, operands); 3265 } 3266 3267 return opBuilder.createOperation(result); 3268 } 3269 3270 Operation *OperationParser::parseGenericOperation(Block *insertBlock, 3271 Block::iterator insertPt) { 3272 OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); 3273 opBuilder.setInsertionPoint(insertBlock, insertPt); 3274 return parseGenericOperation(); 3275 } 3276 3277 namespace { 3278 class CustomOpAsmParser : public OpAsmParser { 3279 public: 3280 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3281 OperationParser &parser) 3282 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3283 3284 /// Parse an instance of the operation described by 'opDefinition' into the 3285 /// provided operation state. 3286 ParseResult parseOperation(OperationState &opState) { 3287 if (opDefinition->parseAssembly(*this, opState)) 3288 return failure(); 3289 return success(); 3290 } 3291 3292 Operation *parseGenericOperation(Block *insertBlock, 3293 Block::iterator insertPt) final { 3294 return parser.parseGenericOperation(insertBlock, insertPt); 3295 } 3296 3297 //===--------------------------------------------------------------------===// 3298 // Utilities 3299 //===--------------------------------------------------------------------===// 3300 3301 /// Return if any errors were emitted during parsing. 3302 bool didEmitError() const { return emittedError; } 3303 3304 /// Emit a diagnostic at the specified location and return failure. 
3305 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3306 emittedError = true; 3307 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3308 message); 3309 } 3310 3311 llvm::SMLoc getCurrentLocation() override { 3312 return parser.getToken().getLoc(); 3313 } 3314 3315 Builder &getBuilder() const override { return parser.builder; } 3316 3317 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3318 3319 //===--------------------------------------------------------------------===// 3320 // Token Parsing 3321 //===--------------------------------------------------------------------===// 3322 3323 /// Parse a `->` token. 3324 ParseResult parseArrow() override { 3325 return parser.parseToken(Token::arrow, "expected '->'"); 3326 } 3327 3328 /// Parses a `->` if present. 3329 ParseResult parseOptionalArrow() override { 3330 return success(parser.consumeIf(Token::arrow)); 3331 } 3332 3333 /// Parse a `:` token. 3334 ParseResult parseColon() override { 3335 return parser.parseToken(Token::colon, "expected ':'"); 3336 } 3337 3338 /// Parse a `:` token if present. 3339 ParseResult parseOptionalColon() override { 3340 return success(parser.consumeIf(Token::colon)); 3341 } 3342 3343 /// Parse a `,` token. 3344 ParseResult parseComma() override { 3345 return parser.parseToken(Token::comma, "expected ','"); 3346 } 3347 3348 /// Parse a `,` token if present. 3349 ParseResult parseOptionalComma() override { 3350 return success(parser.consumeIf(Token::comma)); 3351 } 3352 3353 /// Parses a `...` if present. 3354 ParseResult parseOptionalEllipsis() override { 3355 return success(parser.consumeIf(Token::ellipsis)); 3356 } 3357 3358 /// Parse a `=` token. 3359 ParseResult parseEqual() override { 3360 return parser.parseToken(Token::equal, "expected '='"); 3361 } 3362 3363 /// Parse a `(` token. 
3364 ParseResult parseLParen() override { 3365 return parser.parseToken(Token::l_paren, "expected '('"); 3366 } 3367 3368 /// Parses a '(' if present. 3369 ParseResult parseOptionalLParen() override { 3370 return success(parser.consumeIf(Token::l_paren)); 3371 } 3372 3373 /// Parse a `)` token. 3374 ParseResult parseRParen() override { 3375 return parser.parseToken(Token::r_paren, "expected ')'"); 3376 } 3377 3378 /// Parses a ')' if present. 3379 ParseResult parseOptionalRParen() override { 3380 return success(parser.consumeIf(Token::r_paren)); 3381 } 3382 3383 /// Parse a `[` token. 3384 ParseResult parseLSquare() override { 3385 return parser.parseToken(Token::l_square, "expected '['"); 3386 } 3387 3388 /// Parses a '[' if present. 3389 ParseResult parseOptionalLSquare() override { 3390 return success(parser.consumeIf(Token::l_square)); 3391 } 3392 3393 /// Parse a `]` token. 3394 ParseResult parseRSquare() override { 3395 return parser.parseToken(Token::r_square, "expected ']'"); 3396 } 3397 3398 /// Parses a ']' if present. 3399 ParseResult parseOptionalRSquare() override { 3400 return success(parser.consumeIf(Token::r_square)); 3401 } 3402 3403 //===--------------------------------------------------------------------===// 3404 // Attribute Parsing 3405 //===--------------------------------------------------------------------===// 3406 3407 /// Parse an arbitrary attribute of a given type and return it in result. This 3408 /// also adds the attribute to the specified attribute list with the specified 3409 /// name. 3410 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3411 SmallVectorImpl<NamedAttribute> &attrs) override { 3412 result = parser.parseAttribute(type); 3413 if (!result) 3414 return failure(); 3415 3416 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3417 return success(); 3418 } 3419 3420 /// Parse a named dictionary into 'result' if it is present. 
3421 ParseResult 3422 parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override { 3423 if (parser.getToken().isNot(Token::l_brace)) 3424 return success(); 3425 return parser.parseAttributeDict(result); 3426 } 3427 3428 //===--------------------------------------------------------------------===// 3429 // Identifier Parsing 3430 //===--------------------------------------------------------------------===// 3431 3432 /// Returns if the current token corresponds to a keyword. 3433 bool isCurrentTokenAKeyword() const { 3434 return parser.getToken().is(Token::bare_identifier) || 3435 parser.getToken().isKeyword(); 3436 } 3437 3438 /// Parse the given keyword if present. 3439 ParseResult parseOptionalKeyword(StringRef keyword) override { 3440 // Check that the current token has the same spelling. 3441 if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) 3442 return failure(); 3443 parser.consumeToken(); 3444 return success(); 3445 } 3446 3447 /// Parse a keyword, if present, into 'keyword'. 3448 ParseResult parseOptionalKeyword(StringRef *keyword) override { 3449 // Check that the current token is a keyword. 3450 if (!isCurrentTokenAKeyword()) 3451 return failure(); 3452 3453 *keyword = parser.getTokenSpelling(); 3454 parser.consumeToken(); 3455 return success(); 3456 } 3457 3458 /// Parse an @-identifier and store it (without the '@' symbol) in a string 3459 /// attribute named 'attrName'. 
ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
                            SmallVectorImpl<NamedAttribute> &attrs) override {
  // Fails without emitting a diagnostic when no @-identifier is present.
  if (parser.getToken().isNot(Token::at_identifier))
    return failure();
  // drop_front() strips the first character of the spelling (the '@').
  result = getBuilder().getStringAttr(parser.getTokenSpelling().drop_front());
  attrs.push_back(getBuilder().getNamedAttr(attrName, result));
  parser.consumeToken();
  return success();
}

//===--------------------------------------------------------------------===//
// Operand Parsing
//===--------------------------------------------------------------------===//

/// Parse a single operand. On success `result` holds the location, name and
/// result number of the parsed SSA use; the operand is not resolved to a
/// value here.
ParseResult parseOperand(OperandType &result) override {
  OperationParser::SSAUseInfo useInfo;
  if (parser.parseSSAUse(useInfo))
    return failure();

  result = {useInfo.loc, useInfo.name, useInfo.number};
  return success();
}

/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
                             int requiredOperandCount = -1,
                             Delimiter delimiter = Delimiter::None) override {
  return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
                                     requiredOperandCount, delimiter);
}

/// Parse zero or more SSA comma-separated operand or region arguments with
///  optional surrounding delimiter and required operand count.
///
/// `isOperandList` selects whether each element is parsed with parseOperand
/// or parseRegionArgument. A `requiredOperandCount` of -1 means "any number".
/// Parsed elements are appended to `result`.
ParseResult
parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
                            bool isOperandList, int requiredOperandCount = -1,
                            Delimiter delimiter = Delimiter::None) {
  // Remembered up front so count/delimiter errors point at the list start.
  auto startLoc = parser.getToken().getLoc();

  // Handle delimiters.
  switch (delimiter) {
  case Delimiter::None:
    // Don't check for the absence of a delimiter if the number of operands
    // is unknown (and hence the operand list could be empty).
    if (requiredOperandCount == -1)
      break;
    // Token already matches an identifier and so can't be a delimiter.
    if (parser.getToken().is(Token::percent_identifier))
      break;
    // Test against known delimiters.
    if (parser.getToken().is(Token::l_paren) ||
        parser.getToken().is(Token::l_square))
      return emitError(startLoc, "unexpected delimiter");
    return emitError(startLoc, "invalid operand");
  case Delimiter::OptionalParen:
    // An absent optional delimiter means an empty list; note that this
    // returns before the required-count check at the end of the function.
    if (parser.getToken().isNot(Token::l_paren))
      return success();
    LLVM_FALLTHROUGH;
  case Delimiter::Paren:
    if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
      return failure();
    break;
  case Delimiter::OptionalSquare:
    // Same early exit as for OptionalParen above.
    if (parser.getToken().isNot(Token::l_square))
      return success();
    LLVM_FALLTHROUGH;
  case Delimiter::Square:
    if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
      return failure();
    break;
  }

  // Check for zero operands.
  if (parser.getToken().is(Token::percent_identifier)) {
    do {
      OperandType operandOrArg;
      if (isOperandList ? parseOperand(operandOrArg)
                        : parseRegionArgument(operandOrArg))
        return failure();
      result.push_back(operandOrArg);
    } while (parser.consumeIf(Token::comma));
  }

  // Handle delimiters.  If we reach here, the optional delimiters were
  // present, so we need to parse their closing one.
  switch (delimiter) {
  case Delimiter::None:
    break;
  case Delimiter::OptionalParen:
  case Delimiter::Paren:
    if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
      return failure();
    break;
  case Delimiter::OptionalSquare:
  case Delimiter::Square:
    if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
      return failure();
    break;
  }

  // The count is checked against the total size of `result`, which includes
  // any elements that were already in it when this function was called.
  if (requiredOperandCount != -1 &&
      result.size() != static_cast<size_t>(requiredOperandCount))
    return emitError(startLoc, "expected ")
           << requiredOperandCount << " operands";
  return success();
}

/// Parse zero or more trailing SSA comma-separated trailing operand
/// references with a specified surrounding delimiter, and an optional
/// required operand count. A leading comma is expected before the operands.
ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
                                     int requiredOperandCount,
                                     Delimiter delimiter) override {
  if (parser.getToken().is(Token::comma)) {
    // The current token was just checked to be a comma, so the result of
    // parseComma() is not inspected here.
    parseComma();
    return parseOperandList(result, requiredOperandCount, delimiter);
  }
  // No leading comma: that is only acceptable when no count is required.
  if (requiredOperandCount != -1)
    return emitError(parser.getToken().getLoc(), "expected ")
           << requiredOperandCount << " operands";
  return success();
}

/// Resolve an operand to an SSA value, emitting an error on failure.
/// On success the resolved value is appended to `result`.
ParseResult resolveOperand(const OperandType &operand, Type type,
                           SmallVectorImpl<Value *> &result) override {
  OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                             operand.location};
  if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
    result.push_back(value);
    return success();
  }
  return failure();
}

/// Parse an AffineMap of SSA ids.
3598 ParseResult 3599 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 3600 Attribute &mapAttr, StringRef attrName, 3601 SmallVectorImpl<NamedAttribute> &attrs) override { 3602 SmallVector<OperandType, 2> dimOperands; 3603 SmallVector<OperandType, 1> symOperands; 3604 3605 auto parseElement = [&](bool isSymbol) -> ParseResult { 3606 OperandType operand; 3607 if (parseOperand(operand)) 3608 return failure(); 3609 if (isSymbol) 3610 symOperands.push_back(operand); 3611 else 3612 dimOperands.push_back(operand); 3613 return success(); 3614 }; 3615 3616 AffineMap map; 3617 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 3618 return failure(); 3619 // Add AffineMap attribute. 3620 if (map) { 3621 mapAttr = parser.builder.getAffineMapAttr(map); 3622 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 3623 } 3624 3625 // Add dim operands before symbol operands in 'operands'. 3626 operands.assign(dimOperands.begin(), dimOperands.end()); 3627 operands.append(symOperands.begin(), symOperands.end()); 3628 return success(); 3629 } 3630 3631 //===--------------------------------------------------------------------===// 3632 // Region Parsing 3633 //===--------------------------------------------------------------------===// 3634 3635 /// Parse a region that takes `arguments` of `argTypes` types. This 3636 /// effectively defines the SSA values of `arguments` and assignes their type. 
3637 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 3638 ArrayRef<Type> argTypes, 3639 bool enableNameShadowing) override { 3640 assert(arguments.size() == argTypes.size() && 3641 "mismatching number of arguments and types"); 3642 3643 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 3644 regionArguments; 3645 for (const auto &pair : llvm::zip(arguments, argTypes)) { 3646 const OperandType &operand = std::get<0>(pair); 3647 Type type = std::get<1>(pair); 3648 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 3649 operand.location}; 3650 regionArguments.emplace_back(operandInfo, type); 3651 } 3652 3653 // Try to parse the region. 3654 assert((!enableNameShadowing || 3655 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 3656 "name shadowing is only allowed on isolated regions"); 3657 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 3658 return failure(); 3659 return success(); 3660 } 3661 3662 /// Parses a region if present. 3663 ParseResult parseOptionalRegion(Region ®ion, 3664 ArrayRef<OperandType> arguments, 3665 ArrayRef<Type> argTypes, 3666 bool enableNameShadowing) override { 3667 if (parser.getToken().isNot(Token::l_brace)) 3668 return success(); 3669 return parseRegion(region, arguments, argTypes, enableNameShadowing); 3670 } 3671 3672 /// Parse a region argument. The type of the argument will be resolved later 3673 /// by a call to `parseRegion`. 3674 ParseResult parseRegionArgument(OperandType &argument) override { 3675 return parseOperand(argument); 3676 } 3677 3678 /// Parse a region argument if present. 
3679 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 3680 if (parser.getToken().isNot(Token::percent_identifier)) 3681 return success(); 3682 return parseRegionArgument(argument); 3683 } 3684 3685 ParseResult 3686 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 3687 int requiredOperandCount = -1, 3688 Delimiter delimiter = Delimiter::None) override { 3689 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 3690 requiredOperandCount, delimiter); 3691 } 3692 3693 //===--------------------------------------------------------------------===// 3694 // Successor Parsing 3695 //===--------------------------------------------------------------------===// 3696 3697 /// Parse a single operation successor and its operand list. 3698 ParseResult 3699 parseSuccessorAndUseList(Block *&dest, 3700 SmallVectorImpl<Value *> &operands) override { 3701 return parser.parseSuccessorAndUseList(dest, operands); 3702 } 3703 3704 //===--------------------------------------------------------------------===// 3705 // Type Parsing 3706 //===--------------------------------------------------------------------===// 3707 3708 /// Parse a type. 3709 ParseResult parseType(Type &result) override { 3710 return failure(!(result = parser.parseType())); 3711 } 3712 3713 /// Parse an optional arrow followed by a type list. 3714 ParseResult 3715 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 3716 if (!parser.consumeIf(Token::arrow)) 3717 return success(); 3718 return parser.parseFunctionResultTypes(result); 3719 } 3720 3721 /// Parse a colon followed by a type. 3722 ParseResult parseColonType(Type &result) override { 3723 return failure(parser.parseToken(Token::colon, "expected ':'") || 3724 !(result = parser.parseType())); 3725 } 3726 3727 /// Parse a colon followed by a type list, which must have at least one type. 
3728 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 3729 if (parser.parseToken(Token::colon, "expected ':'")) 3730 return failure(); 3731 return parser.parseTypeListNoParens(result); 3732 } 3733 3734 /// Parse an optional colon followed by a type list, which if present must 3735 /// have at least one type. 3736 ParseResult 3737 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 3738 if (!parser.consumeIf(Token::colon)) 3739 return success(); 3740 return parser.parseTypeListNoParens(result); 3741 } 3742 3743 private: 3744 /// The source location of the operation name. 3745 SMLoc nameLoc; 3746 3747 /// The abstract information of the operation. 3748 const AbstractOperation *opDefinition; 3749 3750 /// The main operation parser. 3751 OperationParser &parser; 3752 3753 /// A flag that indicates if any errors were emitted during parsing. 3754 bool emittedError = false; 3755 }; 3756 } // end anonymous namespace. 3757 3758 Operation *OperationParser::parseCustomOperation() { 3759 auto opLoc = getToken().getLoc(); 3760 auto opName = getTokenSpelling(); 3761 3762 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 3763 if (!opDefinition && !opName.contains('.')) { 3764 // If the operation name has no namespace prefix we treat it as a standard 3765 // operation and prefix it with "std". 3766 // TODO: Would it be better to just build a mapping of the registered 3767 // operations in the standard dialect? 3768 opDefinition = 3769 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 3770 } 3771 3772 if (!opDefinition) { 3773 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 3774 return nullptr; 3775 } 3776 3777 consumeToken(); 3778 3779 // If the custom op parser crashes, produce some indication to help 3780 // debugging. 
3781 std::string opNameStr = opName.str(); 3782 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 3783 opNameStr.c_str()); 3784 3785 // Get location information for the operation. 3786 auto srcLocation = getEncodedSourceLocation(opLoc); 3787 3788 // Have the op implementation take a crack and parsing this. 3789 OperationState opState(srcLocation, opDefinition->name); 3790 CleanupOpStateRegions guard{opState}; 3791 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 3792 if (opAsmParser.parseOperation(opState)) 3793 return nullptr; 3794 3795 // If it emitted an error, we failed. 3796 if (opAsmParser.didEmitError()) 3797 return nullptr; 3798 3799 // Otherwise, we succeeded. Use the state it parsed as our op information. 3800 return opBuilder.createOperation(opState); 3801 } 3802 3803 //===----------------------------------------------------------------------===// 3804 // Region Parsing 3805 //===----------------------------------------------------------------------===// 3806 3807 /// Region. 3808 /// 3809 /// region ::= '{' region-body 3810 /// 3811 ParseResult OperationParser::parseRegion( 3812 Region ®ion, 3813 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 3814 bool isIsolatedNameScope) { 3815 // Parse the '{'. 3816 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 3817 return failure(); 3818 3819 // Check for an empty region. 3820 if (entryArguments.empty() && consumeIf(Token::r_brace)) 3821 return success(); 3822 auto currentPt = opBuilder.saveInsertionPoint(); 3823 3824 // Push a new named value scope. 3825 pushSSANameScope(isIsolatedNameScope); 3826 3827 // Parse the first block directly to allow for it to be unnamed. 3828 Block *block = new Block(); 3829 3830 // Add arguments to the entry block. 3831 if (!entryArguments.empty()) { 3832 for (auto &placeholderArgPair : entryArguments) { 3833 auto &argInfo = placeholderArgPair.first; 3834 // Ensure that the argument was not already defined. 
3835 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 3836 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 3837 "' is already in use") 3838 .attachNote(getEncodedSourceLocation(*defLoc)) 3839 << "previously referenced here"; 3840 } 3841 if (addDefinition(placeholderArgPair.first, 3842 block->addArgument(placeholderArgPair.second))) { 3843 delete block; 3844 return failure(); 3845 } 3846 } 3847 3848 // If we had named arguments, then don't allow a block name. 3849 if (getToken().is(Token::caret_identifier)) 3850 return emitError("invalid block name in region with named arguments"); 3851 } 3852 3853 if (parseBlock(block)) { 3854 delete block; 3855 return failure(); 3856 } 3857 3858 // Verify that no other arguments were parsed. 3859 if (!entryArguments.empty() && 3860 block->getNumArguments() > entryArguments.size()) { 3861 delete block; 3862 return emitError("entry block arguments were already defined"); 3863 } 3864 3865 // Parse the rest of the region. 3866 region.push_back(block); 3867 if (parseRegionBody(region)) 3868 return failure(); 3869 3870 // Pop the SSA value scope for this region. 3871 if (popSSANameScope()) 3872 return failure(); 3873 3874 // Reset the original insertion point. 3875 opBuilder.restoreInsertionPoint(currentPt); 3876 return success(); 3877 } 3878 3879 /// Region. 3880 /// 3881 /// region-body ::= block* '}' 3882 /// 3883 ParseResult OperationParser::parseRegionBody(Region ®ion) { 3884 // Parse the list of blocks. 3885 while (!consumeIf(Token::r_brace)) { 3886 Block *newBlock = nullptr; 3887 if (parseBlock(newBlock)) 3888 return failure(); 3889 region.push_back(newBlock); 3890 } 3891 return success(); 3892 } 3893 3894 //===----------------------------------------------------------------------===// 3895 // Block Parsing 3896 //===----------------------------------------------------------------------===// 3897 3898 /// Block declaration. 3899 /// 3900 /// block ::= block-label? 
operation* 3901 /// block-label ::= block-id block-arg-list? `:` 3902 /// block-id ::= caret-id 3903 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 3904 /// 3905 ParseResult OperationParser::parseBlock(Block *&block) { 3906 // The first block of a region may already exist, if it does the caret 3907 // identifier is optional. 3908 if (block && getToken().isNot(Token::caret_identifier)) 3909 return parseBlockBody(block); 3910 3911 SMLoc nameLoc = getToken().getLoc(); 3912 auto name = getTokenSpelling(); 3913 if (parseToken(Token::caret_identifier, "expected block name")) 3914 return failure(); 3915 3916 block = defineBlockNamed(name, nameLoc, block); 3917 3918 // Fail if the block was already defined. 3919 if (!block) 3920 return emitError(nameLoc, "redefinition of block '") << name << "'"; 3921 3922 // If an argument list is present, parse it. 3923 if (consumeIf(Token::l_paren)) { 3924 SmallVector<BlockArgument *, 8> bbArgs; 3925 if (parseOptionalBlockArgList(bbArgs, block) || 3926 parseToken(Token::r_paren, "expected ')' to end argument list")) 3927 return failure(); 3928 } 3929 3930 if (parseToken(Token::colon, "expected ':' after block name")) 3931 return failure(); 3932 3933 return parseBlockBody(block); 3934 } 3935 3936 ParseResult OperationParser::parseBlockBody(Block *block) { 3937 // Set the insertion point to the end of the block to parse. 3938 opBuilder.setInsertionPointToEnd(block); 3939 3940 // Parse the list of operations that make up the body of the block. 3941 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 3942 if (parseOperation()) 3943 return failure(); 3944 3945 return success(); 3946 } 3947 3948 /// Get the block with the specified name, creating it if it doesn't already 3949 /// exist. The location specified is the point of use, which allows 3950 /// us to diagnose references to blocks that are not defined precisely. 
3951 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 3952 auto &blockAndLoc = getBlockInfoByName(name); 3953 if (!blockAndLoc.first) { 3954 blockAndLoc = {new Block(), loc}; 3955 insertForwardRef(blockAndLoc.first, loc); 3956 } 3957 3958 return blockAndLoc.first; 3959 } 3960 3961 /// Define the block with the specified name. Returns the Block* or nullptr in 3962 /// the case of redefinition. 3963 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 3964 Block *existing) { 3965 auto &blockAndLoc = getBlockInfoByName(name); 3966 if (!blockAndLoc.first) { 3967 // If the caller provided a block, use it. Otherwise create a new one. 3968 if (!existing) 3969 existing = new Block(); 3970 blockAndLoc.first = existing; 3971 blockAndLoc.second = loc; 3972 return blockAndLoc.first; 3973 } 3974 3975 // Forward declarations are removed once defined, so if we are defining a 3976 // existing block and it is not a forward declaration, then it is a 3977 // redeclaration. 3978 if (!eraseForwardRef(blockAndLoc.first)) 3979 return nullptr; 3980 return blockAndLoc.first; 3981 } 3982 3983 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 3984 /// 3985 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 3986 /// 3987 ParseResult OperationParser::parseOptionalBlockArgList( 3988 SmallVectorImpl<BlockArgument *> &results, Block *owner) { 3989 if (getToken().is(Token::r_brace)) 3990 return success(); 3991 3992 // If the block already has arguments, then we're handling the entry block. 3993 // Parse and register the names for the arguments, but do not add them. 3994 bool definingExistingArgs = owner->getNumArguments() != 0; 3995 unsigned nextArgument = 0; 3996 3997 return parseCommaSeparatedList([&]() -> ParseResult { 3998 return parseSSADefOrUseAndType( 3999 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 4000 // If this block did not have existing arguments, define a new one. 
4001 if (!definingExistingArgs) 4002 return addDefinition(useInfo, owner->addArgument(type)); 4003 4004 // Otherwise, ensure that this argument has already been created. 4005 if (nextArgument >= owner->getNumArguments()) 4006 return emitError("too many arguments specified in argument list"); 4007 4008 // Finally, make sure the existing argument has the correct type. 4009 auto *arg = owner->getArgument(nextArgument++); 4010 if (arg->getType() != type) 4011 return emitError("argument and block argument type mismatch"); 4012 return addDefinition(useInfo, arg); 4013 }); 4014 }); 4015 } 4016 4017 //===----------------------------------------------------------------------===// 4018 // Top-level entity parsing. 4019 //===----------------------------------------------------------------------===// 4020 4021 namespace { 4022 /// This parser handles entities that are only valid at the top level of the 4023 /// file. 4024 class ModuleParser : public Parser { 4025 public: 4026 explicit ModuleParser(ParserState &state) : Parser(state) {} 4027 4028 ParseResult parseModule(ModuleOp module); 4029 4030 private: 4031 /// Parse an attribute alias declaration. 4032 ParseResult parseAttributeAliasDef(); 4033 4034 /// Parse an attribute alias declaration. 4035 ParseResult parseTypeAliasDef(); 4036 }; 4037 } // end anonymous namespace 4038 4039 /// Parses an attribute alias declaration. 4040 /// 4041 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 4042 /// 4043 ParseResult ModuleParser::parseAttributeAliasDef() { 4044 assert(getToken().is(Token::hash_identifier)); 4045 StringRef aliasName = getTokenSpelling().drop_front(); 4046 4047 // Check for redefinitions. 4048 if (getState().attributeAliasDefinitions.count(aliasName) > 0) 4049 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 4050 4051 // Make sure this isn't invading the dialect attribute namespace. 4052 if (aliasName.contains('.')) 4053 return emitError("attribute names with a '.' 
are reserved for " 4054 "dialect-defined names"); 4055 4056 consumeToken(Token::hash_identifier); 4057 4058 // Parse the '='. 4059 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 4060 return failure(); 4061 4062 // Parse the attribute value. 4063 Attribute attr = parseAttribute(); 4064 if (!attr) 4065 return failure(); 4066 4067 getState().attributeAliasDefinitions[aliasName] = attr; 4068 return success(); 4069 } 4070 4071 /// Parse a type alias declaration. 4072 /// 4073 /// type-alias-def ::= '!' alias-name `=` 'type' type 4074 /// 4075 ParseResult ModuleParser::parseTypeAliasDef() { 4076 assert(getToken().is(Token::exclamation_identifier)); 4077 StringRef aliasName = getTokenSpelling().drop_front(); 4078 4079 // Check for redefinitions. 4080 if (getState().typeAliasDefinitions.count(aliasName) > 0) 4081 return emitError("redefinition of type alias id '" + aliasName + "'"); 4082 4083 // Make sure this isn't invading the dialect type namespace. 4084 if (aliasName.contains('.')) 4085 return emitError("type names with a '.' are reserved for " 4086 "dialect-defined names"); 4087 4088 consumeToken(Token::exclamation_identifier); 4089 4090 // Parse the '=' and 'type'. 4091 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4092 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4093 return failure(); 4094 4095 // Parse the type. 4096 Type aliasedType = parseType(); 4097 if (!aliasedType) 4098 return failure(); 4099 4100 // Register this alias with the parser state. 4101 getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4102 return success(); 4103 } 4104 4105 /// This is the top-level module parser. 4106 ParseResult ModuleParser::parseModule(ModuleOp module) { 4107 OperationParser opParser(getState(), module); 4108 4109 // Module itself is a name scope. 
4110 opParser.pushSSANameScope(/*isIsolated=*/true); 4111 4112 while (true) { 4113 switch (getToken().getKind()) { 4114 default: 4115 // Parse a top-level operation. 4116 if (opParser.parseOperation()) 4117 return failure(); 4118 break; 4119 4120 // If we got to the end of the file, then we're done. 4121 case Token::eof: { 4122 if (opParser.finalize()) 4123 return failure(); 4124 4125 // Handle the case where the top level module was explicitly defined. 4126 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4127 auto &operations = bodyBlocks.front().getOperations(); 4128 assert(!operations.empty() && "expected a valid module terminator"); 4129 4130 // Check that the first operation is a module, and it is the only 4131 // non-terminator operation. 4132 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4133 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4134 // Merge the data of the nested module operation into 'module'. 4135 module.setLoc(nested.getLoc()); 4136 module.setAttrs(nested.getOperation()->getAttrList()); 4137 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4138 4139 // Erase the original module body. 4140 bodyBlocks.pop_front(); 4141 } 4142 4143 return opParser.popSSANameScope(); 4144 } 4145 4146 // If we got an error token, then the lexer already emitted an error, just 4147 // stop. Someday we could introduce error recovery if there was demand 4148 // for it. 4149 case Token::error: 4150 return failure(); 4151 4152 // Parse an attribute alias. 4153 case Token::hash_identifier: 4154 if (parseAttributeAliasDef()) 4155 return failure(); 4156 break; 4157 4158 // Parse a type alias. 
4159 case Token::exclamation_identifier: 4160 if (parseTypeAliasDef()) 4161 return failure(); 4162 break; 4163 } 4164 } 4165 } 4166 4167 //===----------------------------------------------------------------------===// 4168 4169 /// This parses the file specified by the indicated SourceMgr and returns an 4170 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4171 /// null. 4172 OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4173 MLIRContext *context) { 4174 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4175 4176 // This is the result module we are parsing into. 4177 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4178 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4179 4180 ParserState state(sourceMgr, context); 4181 if (ModuleParser(state).parseModule(*module)) 4182 return nullptr; 4183 4184 // Make sure the parse module has no other structural problems detected by 4185 // the verifier. 4186 if (failed(verify(*module))) 4187 return nullptr; 4188 4189 return module; 4190 } 4191 4192 /// This parses the file specified by the indicated filename and returns an 4193 /// MLIR module if it was valid. If not, the error message is emitted through 4194 /// the error handler registered in the context, and a null pointer is returned. 4195 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4196 MLIRContext *context) { 4197 llvm::SourceMgr sourceMgr; 4198 return parseSourceFile(filename, sourceMgr, context); 4199 } 4200 4201 /// This parses the file specified by the indicated filename using the provided 4202 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4203 /// message is emitted through the error handler registered in the context, and 4204 /// a null pointer is returned. 
4205 OwningModuleRef mlir::parseSourceFile(StringRef filename, 4206 llvm::SourceMgr &sourceMgr, 4207 MLIRContext *context) { 4208 if (sourceMgr.getNumBuffers() != 0) { 4209 // TODO(b/136086478): Extend to support multiple buffers. 4210 emitError(mlir::UnknownLoc::get(context), 4211 "only main buffer parsed at the moment"); 4212 return nullptr; 4213 } 4214 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4215 if (std::error_code error = file_or_err.getError()) { 4216 emitError(mlir::UnknownLoc::get(context), 4217 "could not open input file " + filename); 4218 return nullptr; 4219 } 4220 4221 // Load the MLIR module. 4222 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4223 return parseSourceFile(sourceMgr, context); 4224 } 4225 4226 /// This parses the program string to a MLIR module if it was valid. If not, 4227 /// it emits diagnostics and returns null. 4228 OwningModuleRef mlir::parseSourceString(StringRef moduleStr, 4229 MLIRContext *context) { 4230 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4231 if (!memBuffer) 4232 return nullptr; 4233 4234 SourceMgr sourceMgr; 4235 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4236 return parseSourceFile(sourceMgr, context); 4237 } 4238 4239 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context, 4240 size_t &numRead) { 4241 SourceMgr sourceMgr; 4242 auto memBuffer = 4243 MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>", 4244 /*RequiresNullTerminator=*/false); 4245 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4246 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context); 4247 ParserState state(sourceMgr, context); 4248 Parser parser(state); 4249 4250 auto start = parser.getToken().getLoc(); 4251 auto ty = parser.parseType(); 4252 if (!ty) 4253 return Type(); 4254 4255 auto end = parser.getToken().getLoc(); 4256 numRead = static_cast<size_t>(end.getPointer() - start.getPointer()); 4257 return ty; 4258 } 
4259 4260 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) { 4261 size_t numRead = 0; 4262 return parseType(typeStr, context, numRead); 4263 } 4264