1 //===- Parser.cpp - MLIR Parser Implementation ----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the parser for the MLIR textual form. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Parser.h" 23 #include "Lexer.h" 24 #include "mlir/Analysis/Verifier.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/AffineMap.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/Dialect.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Location.h" 32 #include "mlir/IR/MLIRContext.h" 33 #include "mlir/IR/Module.h" 34 #include "mlir/IR/OpImplementation.h" 35 #include "mlir/IR/StandardTypes.h" 36 #include "mlir/Support/STLExtras.h" 37 #include "llvm/ADT/APInt.h" 38 #include "llvm/ADT/DenseMap.h" 39 #include "llvm/ADT/StringSet.h" 40 #include "llvm/ADT/bit.h" 41 #include "llvm/Support/MemoryBuffer.h" 42 #include "llvm/Support/PrettyStackTrace.h" 43 #include "llvm/Support/SMLoc.h" 44 #include "llvm/Support/SourceMgr.h" 45 #include <algorithm> 46 using namespace mlir; 47 using llvm::MemoryBuffer; 48 using llvm::SMLoc; 49 using llvm::SourceMgr; 50 51 namespace { 52 class Parser; 53 54 
//===----------------------------------------------------------------------===// 55 // ParserState 56 //===----------------------------------------------------------------------===// 57 58 /// This class refers to all of the state maintained globally by the parser, 59 /// such as the current lexer position etc. The Parser base class provides 60 /// methods to access this. 61 class ParserState { 62 public: 63 ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx) 64 : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {} 65 66 // A map from attribute alias identifier to Attribute. 67 llvm::StringMap<Attribute> attributeAliasDefinitions; 68 69 // A map from type alias identifier to Type. 70 llvm::StringMap<Type> typeAliasDefinitions; 71 72 private: 73 ParserState(const ParserState &) = delete; 74 void operator=(const ParserState &) = delete; 75 76 friend class Parser; 77 78 // The context we're parsing into. 79 MLIRContext *const context; 80 81 // The lexer for the source file we're parsing. 82 Lexer lex; 83 84 // This is the next token that hasn't been consumed yet. 85 Token curToken; 86 }; 87 88 //===----------------------------------------------------------------------===// 89 // Parser 90 //===----------------------------------------------------------------------===// 91 92 /// This class implement support for parsing global entities like types and 93 /// shared entities like SSA names. It is intended to be subclassed by 94 /// specialized subparsers that include state, e.g. when a local symbol table. 95 class Parser { 96 public: 97 Builder builder; 98 99 Parser(ParserState &state) : builder(state.context), state(state) {} 100 101 // Helper methods to get stuff from the parser-global state. 
102 ParserState &getState() const { return state; } 103 MLIRContext *getContext() const { return state.context; } 104 const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } 105 106 /// Parse a comma-separated list of elements up until the specified end token. 107 ParseResult 108 parseCommaSeparatedListUntil(Token::Kind rightToken, 109 const std::function<ParseResult()> &parseElement, 110 bool allowEmptyList = true); 111 112 /// Parse a comma separated list of elements that must have at least one entry 113 /// in it. 114 ParseResult 115 parseCommaSeparatedList(const std::function<ParseResult()> &parseElement); 116 117 ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); 118 119 // We have two forms of parsing methods - those that return a non-null 120 // pointer on success, and those that return a ParseResult to indicate whether 121 // they returned a failure. The second class fills in by-reference arguments 122 // as the results of their action. 123 124 //===--------------------------------------------------------------------===// 125 // Error Handling 126 //===--------------------------------------------------------------------===// 127 128 /// Emit an error and return failure. 129 InFlightDiagnostic emitError(const Twine &message = {}) { 130 return emitError(state.curToken.getLoc(), message); 131 } 132 InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); 133 134 /// Encode the specified source location information into an attribute for 135 /// attachment to the IR. 136 Location getEncodedSourceLocation(llvm::SMLoc loc) { 137 return state.lex.getEncodedSourceLocation(loc); 138 } 139 140 //===--------------------------------------------------------------------===// 141 // Token Parsing 142 //===--------------------------------------------------------------------===// 143 144 /// Return the current token the parser is inspecting. 
145 const Token &getToken() const { return state.curToken; } 146 StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } 147 148 /// If the current token has the specified kind, consume it and return true. 149 /// If not, return false. 150 bool consumeIf(Token::Kind kind) { 151 if (state.curToken.isNot(kind)) 152 return false; 153 consumeToken(kind); 154 return true; 155 } 156 157 /// Advance the current lexer onto the next token. 158 void consumeToken() { 159 assert(state.curToken.isNot(Token::eof, Token::error) && 160 "shouldn't advance past EOF or errors"); 161 state.curToken = state.lex.lexToken(); 162 } 163 164 /// Advance the current lexer onto the next token, asserting what the expected 165 /// current token is. This is preferred to the above method because it leads 166 /// to more self-documenting code with better checking. 167 void consumeToken(Token::Kind kind) { 168 assert(state.curToken.is(kind) && "consumed an unexpected token"); 169 consumeToken(); 170 } 171 172 /// Consume the specified token if present and return success. On failure, 173 /// output a diagnostic and return failure. 174 ParseResult parseToken(Token::Kind expectedToken, const Twine &message); 175 176 //===--------------------------------------------------------------------===// 177 // Type Parsing 178 //===--------------------------------------------------------------------===// 179 180 ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements); 181 ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements); 182 ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements); 183 184 /// Parse an arbitrary type. 185 Type parseType(); 186 187 /// Parse a complex type. 188 Type parseComplexType(); 189 190 /// Parse an extended type. 191 Type parseExtendedType(); 192 193 /// Parse a function type. 194 Type parseFunctionType(); 195 196 /// Parse a memref type. 197 Type parseMemRefType(); 198 199 /// Parse a non function type. 
200 Type parseNonFunctionType(); 201 202 /// Parse a tensor type. 203 Type parseTensorType(); 204 205 /// Parse a tuple type. 206 Type parseTupleType(); 207 208 /// Parse a vector type. 209 VectorType parseVectorType(); 210 ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 211 bool allowDynamic = true); 212 ParseResult parseXInDimensionList(); 213 214 //===--------------------------------------------------------------------===// 215 // Attribute Parsing 216 //===--------------------------------------------------------------------===// 217 218 /// Parse an arbitrary attribute with an optional type. 219 Attribute parseAttribute(Type type = {}); 220 221 /// Parse an attribute dictionary. 222 ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); 223 224 /// Parse an extended attribute. 225 Attribute parseExtendedAttr(Type type); 226 227 /// Parse a float attribute. 228 Attribute parseFloatAttr(Type type, bool isNegative); 229 230 /// Parse a decimal or a hexadecimal literal, which can be either an integer 231 /// or a float attribute. 232 Attribute parseDecOrHexAttr(Type type, bool isNegative); 233 234 /// Parse an opaque elements attribute. 235 Attribute parseOpaqueElementsAttr(); 236 237 /// Parse a dense elements attribute. 238 Attribute parseDenseElementsAttr(); 239 ShapedType parseElementsLiteralType(); 240 241 /// Parse a sparse elements attribute. 242 Attribute parseSparseElementsAttr(); 243 244 //===--------------------------------------------------------------------===// 245 // Location Parsing 246 //===--------------------------------------------------------------------===// 247 248 /// Parse an inline location. 249 ParseResult parseLocation(LocationAttr &loc); 250 251 /// Parse a raw location instance. 252 ParseResult parseLocationInstance(LocationAttr &loc); 253 254 /// Parse an optional trailing location. 255 /// 256 /// trailing-location ::= location? 
257 /// 258 template <typename Owner> 259 ParseResult parseOptionalTrailingLocation(Owner *owner) { 260 // If there is a 'loc' we parse a trailing location. 261 if (!getToken().is(Token::kw_loc)) 262 return success(); 263 264 // Parse the location. 265 LocationAttr directLoc; 266 if (parseLocation(directLoc)) 267 return failure(); 268 owner->setLoc(directLoc); 269 return success(); 270 } 271 272 //===--------------------------------------------------------------------===// 273 // Affine Parsing 274 //===--------------------------------------------------------------------===// 275 276 ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, 277 IntegerSet &set); 278 279 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 280 ParseResult 281 parseAffineMapOfSSAIds(AffineMap &map, 282 llvm::function_ref<ParseResult(bool)> parseElement); 283 284 private: 285 /// The Parser is subclassed and reinstantiated. Do not add additional 286 /// non-trivial state here, add it to the ParserState class. 287 ParserState &state; 288 }; 289 } // end anonymous namespace 290 291 //===----------------------------------------------------------------------===// 292 // Helper methods. 293 //===----------------------------------------------------------------------===// 294 295 /// Parse a comma separated list of elements that must have at least one entry 296 /// in it. 297 ParseResult Parser::parseCommaSeparatedList( 298 const std::function<ParseResult()> &parseElement) { 299 // Non-empty case starts with an element. 300 if (parseElement()) 301 return failure(); 302 303 // Otherwise we have a list of comma separated elements. 304 while (consumeIf(Token::comma)) { 305 if (parseElement()) 306 return failure(); 307 } 308 return success(); 309 } 310 311 /// Parse a comma-separated list of elements, terminated with an arbitrary 312 /// token. This allows empty lists if allowEmptyList is true. 
313 /// 314 /// abstract-list ::= rightToken // if allowEmptyList == true 315 /// abstract-list ::= element (',' element)* rightToken 316 /// 317 ParseResult Parser::parseCommaSeparatedListUntil( 318 Token::Kind rightToken, const std::function<ParseResult()> &parseElement, 319 bool allowEmptyList) { 320 // Handle the empty case. 321 if (getToken().is(rightToken)) { 322 if (!allowEmptyList) 323 return emitError("expected list element"); 324 consumeToken(rightToken); 325 return success(); 326 } 327 328 if (parseCommaSeparatedList(parseElement) || 329 parseToken(rightToken, "expected ',' or '" + 330 Token::getTokenSpelling(rightToken) + "'")) 331 return failure(); 332 333 return success(); 334 } 335 336 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, 337 /// and may be recursive. Return with the 'prettyName' StringRef encompasing 338 /// the entire pretty name. 339 /// 340 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 341 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 342 /// | '(' pretty-dialect-sym-contents+ ')' 343 /// | '[' pretty-dialect-sym-contents+ ']' 344 /// | '{' pretty-dialect-sym-contents+ '}' 345 /// | '[^[<({>\])}\0]+' 346 /// 347 ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { 348 // Pretty symbol names are a relatively unstructured format that contains a 349 // series of properly nested punctuation, with anything else in the middle. 350 // Scan ahead to find it and consume it if successful, otherwise emit an 351 // error. 352 auto *curPtr = getTokenSpelling().data(); 353 354 SmallVector<char, 8> nestedPunctuation; 355 356 // Scan over the nested punctuation, bailing out on error and consuming until 357 // we find the end. We know that we're currently looking at the '<', so we 358 // can go until we find the matching '>' character. 
359 assert(*curPtr == '<'); 360 do { 361 char c = *curPtr++; 362 switch (c) { 363 case '\0': 364 // This also handles the EOF case. 365 return emitError("unexpected nul or EOF in pretty dialect name"); 366 case '<': 367 case '[': 368 case '(': 369 case '{': 370 nestedPunctuation.push_back(c); 371 continue; 372 373 case '>': 374 if (nestedPunctuation.pop_back_val() != '<') 375 return emitError("unbalanced '>' character in pretty dialect name"); 376 break; 377 case ']': 378 if (nestedPunctuation.pop_back_val() != '[') 379 return emitError("unbalanced ']' character in pretty dialect name"); 380 break; 381 case ')': 382 if (nestedPunctuation.pop_back_val() != '(') 383 return emitError("unbalanced ')' character in pretty dialect name"); 384 break; 385 case '}': 386 if (nestedPunctuation.pop_back_val() != '{') 387 return emitError("unbalanced '}' character in pretty dialect name"); 388 break; 389 390 default: 391 continue; 392 } 393 } while (!nestedPunctuation.empty()); 394 395 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 396 // consuming all this stuff, and return. 397 state.lex.resetPointer(curPtr); 398 399 unsigned length = curPtr - prettyName.begin(); 400 prettyName = StringRef(prettyName.begin(), length); 401 consumeToken(); 402 return success(); 403 } 404 405 /// Parse an extended dialect symbol. 406 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 407 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, 408 SymbolAliasMap &aliases, 409 CreateFn &&createSymbol) { 410 // Parse the dialect namespace. 411 StringRef identifier = p.getTokenSpelling().drop_front(); 412 auto loc = p.getToken().getLoc(); 413 p.consumeToken(identifierTok); 414 415 // If there is no '<' token following this, and if the typename contains no 416 // dot, then we are parsing a symbol alias. 417 if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { 418 // Check for an alias for this type. 
419 auto aliasIt = aliases.find(identifier); 420 if (aliasIt == aliases.end()) 421 return (p.emitError("undefined symbol alias id '" + identifier + "'"), 422 nullptr); 423 return aliasIt->second; 424 } 425 426 // Otherwise, we are parsing a dialect-specific symbol. If the name contains 427 // a dot, then this is the "pretty" form. If not, it is the verbose form that 428 // looks like <"...">. 429 std::string symbolData; 430 auto dialectName = identifier; 431 432 // Handle the verbose form, where "identifier" is a simple dialect name. 433 if (!identifier.contains('.')) { 434 // Consume the '<'. 435 if (p.parseToken(Token::less, "expected '<' in dialect type")) 436 return nullptr; 437 438 // Parse the symbol specific data. 439 if (p.getToken().isNot(Token::string)) 440 return (p.emitError("expected string literal data in dialect symbol"), 441 nullptr); 442 symbolData = p.getToken().getStringValue(); 443 loc = p.getToken().getLoc(); 444 p.consumeToken(Token::string); 445 446 // Consume the '>'. 447 if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) 448 return nullptr; 449 } else { 450 // Ok, the dialect name is the part of the identifier before the dot, the 451 // part after the dot is the dialect's symbol, or the start thereof. 452 auto dotHalves = identifier.split('.'); 453 dialectName = dotHalves.first; 454 auto prettyName = dotHalves.second; 455 456 // If the dialect's symbol is followed immediately by a <, then lex the body 457 // of it into prettyName. 458 if (p.getToken().is(Token::less) && 459 prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { 460 if (p.parsePrettyDialectSymbolName(prettyName)) 461 return nullptr; 462 } 463 464 symbolData = prettyName.str(); 465 } 466 467 // Call into the provided symbol construction function. 
468 auto encodedLoc = p.getEncodedSourceLocation(loc); 469 return createSymbol(dialectName, symbolData, encodedLoc); 470 } 471 472 //===----------------------------------------------------------------------===// 473 // Error Handling 474 //===----------------------------------------------------------------------===// 475 476 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { 477 auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); 478 479 // If we hit a parse error in response to a lexer error, then the lexer 480 // already reported the error. 481 if (getToken().is(Token::error)) 482 diag.abandon(); 483 return diag; 484 } 485 486 //===----------------------------------------------------------------------===// 487 // Token Parsing 488 //===----------------------------------------------------------------------===// 489 490 /// Consume the specified token if present and return success. On failure, 491 /// output a diagnostic and return failure. 492 ParseResult Parser::parseToken(Token::Kind expectedToken, 493 const Twine &message) { 494 if (consumeIf(expectedToken)) 495 return success(); 496 return emitError(message); 497 } 498 499 //===----------------------------------------------------------------------===// 500 // Type Parsing 501 //===----------------------------------------------------------------------===// 502 503 /// Parse an arbitrary type. 504 /// 505 /// type ::= function-type 506 /// | non-function-type 507 /// 508 Type Parser::parseType() { 509 if (getToken().is(Token::l_paren)) 510 return parseFunctionType(); 511 return parseNonFunctionType(); 512 } 513 514 /// Parse a function result type. 
515 /// 516 /// function-result-type ::= type-list-parens 517 /// | non-function-type 518 /// 519 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 520 if (getToken().is(Token::l_paren)) 521 return parseTypeListParens(elements); 522 523 Type t = parseNonFunctionType(); 524 if (!t) 525 return failure(); 526 elements.push_back(t); 527 return success(); 528 } 529 530 /// Parse a list of types without an enclosing parenthesis. The list must have 531 /// at least one member. 532 /// 533 /// type-list-no-parens ::= type (`,` type)* 534 /// 535 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 536 auto parseElt = [&]() -> ParseResult { 537 auto elt = parseType(); 538 elements.push_back(elt); 539 return elt ? success() : failure(); 540 }; 541 542 return parseCommaSeparatedList(parseElt); 543 } 544 545 /// Parse a parenthesized list of types. 546 /// 547 /// type-list-parens ::= `(` `)` 548 /// | `(` type-list-no-parens `)` 549 /// 550 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 551 if (parseToken(Token::l_paren, "expected '('")) 552 return failure(); 553 554 // Handle empty lists. 555 if (getToken().is(Token::r_paren)) 556 return consumeToken(), success(); 557 558 if (parseTypeListNoParens(elements) || 559 parseToken(Token::r_paren, "expected ')'")) 560 return failure(); 561 return success(); 562 } 563 564 /// Parse a complex type. 565 /// 566 /// complex-type ::= `complex` `<` type `>` 567 /// 568 Type Parser::parseComplexType() { 569 consumeToken(Token::kw_complex); 570 571 // Parse the '<'. 
572 if (parseToken(Token::less, "expected '<' in complex type")) 573 return nullptr; 574 575 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 576 auto elementType = parseType(); 577 if (!elementType || 578 parseToken(Token::greater, "expected '>' in complex type")) 579 return nullptr; 580 581 return ComplexType::getChecked(elementType, typeLocation); 582 } 583 584 /// Parse an extended type. 585 /// 586 /// extended-type ::= (dialect-type | type-alias) 587 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 588 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 589 /// type-alias ::= `!` alias-name 590 /// 591 Type Parser::parseExtendedType() { 592 return parseExtendedSymbol<Type>( 593 *this, Token::exclamation_identifier, state.typeAliasDefinitions, 594 [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type { 595 // If we found a registered dialect, then ask it to parse the type. 596 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 597 return dialect->parseType(symbolData, loc); 598 599 // Otherwise, form a new opaque type. 600 return OpaqueType::getChecked( 601 Identifier::get(dialectName, state.context), symbolData, 602 state.context, loc); 603 }); 604 } 605 606 /// Parse a function type. 607 /// 608 /// function-type ::= type-list-parens `->` type-list 609 /// 610 Type Parser::parseFunctionType() { 611 assert(getToken().is(Token::l_paren)); 612 613 SmallVector<Type, 4> arguments, results; 614 if (parseTypeListParens(arguments) || 615 parseToken(Token::arrow, "expected '->' in function type") || 616 parseFunctionResultTypes(results)) 617 return nullptr; 618 619 return builder.getFunctionType(arguments, results); 620 } 621 622 /// Parse a memref type. 623 /// 624 /// memref-type ::= `memref` `<` dimension-list-ranked element-type 625 /// (`,` semi-affine-map-composition)? (`,` memory-space)? 
`>` 626 /// 627 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map 628 /// memory-space ::= integer-literal /* | TODO: address-space-id */ 629 /// 630 Type Parser::parseMemRefType() { 631 consumeToken(Token::kw_memref); 632 633 if (parseToken(Token::less, "expected '<' in memref type")) 634 return nullptr; 635 636 SmallVector<int64_t, 4> dimensions; 637 if (parseDimensionListRanked(dimensions)) 638 return nullptr; 639 640 // Parse the element type. 641 auto typeLoc = getToken().getLoc(); 642 auto elementType = parseType(); 643 if (!elementType) 644 return nullptr; 645 646 // Parse semi-affine-map-composition. 647 SmallVector<AffineMap, 2> affineMapComposition; 648 unsigned memorySpace = 0; 649 bool parsedMemorySpace = false; 650 651 auto parseElt = [&]() -> ParseResult { 652 if (getToken().is(Token::integer)) { 653 // Parse memory space. 654 if (parsedMemorySpace) 655 return emitError("multiple memory spaces specified in memref type"); 656 auto v = getToken().getUnsignedIntegerValue(); 657 if (!v.hasValue()) 658 return emitError("invalid memory space in memref type"); 659 memorySpace = v.getValue(); 660 consumeToken(Token::integer); 661 parsedMemorySpace = true; 662 } else { 663 // Parse affine map. 664 if (parsedMemorySpace) 665 return emitError("affine map after memory space in memref type"); 666 auto affineMap = parseAttribute(); 667 if (!affineMap) 668 return failure(); 669 670 // Verify that the parsed attribute is an affine map. 671 if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>()) 672 affineMapComposition.push_back(affineMapAttr.getValue()); 673 else 674 return emitError("expected affine map in memref type"); 675 } 676 return success(); 677 }; 678 679 // Parse a list of mappings and address space if present. 680 if (consumeIf(Token::comma)) { 681 // Parse comma separated list of affine maps, followed by memory space. 
682 if (parseCommaSeparatedListUntil(Token::greater, parseElt, 683 /*allowEmptyList=*/false)) { 684 return nullptr; 685 } 686 } else { 687 if (parseToken(Token::greater, "expected ',' or '>' in memref type")) 688 return nullptr; 689 } 690 691 return MemRefType::getChecked(dimensions, elementType, affineMapComposition, 692 memorySpace, getEncodedSourceLocation(typeLoc)); 693 } 694 695 /// Parse any type except the function type. 696 /// 697 /// non-function-type ::= integer-type 698 /// | index-type 699 /// | float-type 700 /// | extended-type 701 /// | vector-type 702 /// | tensor-type 703 /// | memref-type 704 /// | complex-type 705 /// | tuple-type 706 /// | none-type 707 /// 708 /// index-type ::= `index` 709 /// float-type ::= `f16` | `bf16` | `f32` | `f64` 710 /// none-type ::= `none` 711 /// 712 Type Parser::parseNonFunctionType() { 713 switch (getToken().getKind()) { 714 default: 715 return (emitError("expected non-function type"), nullptr); 716 case Token::kw_memref: 717 return parseMemRefType(); 718 case Token::kw_tensor: 719 return parseTensorType(); 720 case Token::kw_complex: 721 return parseComplexType(); 722 case Token::kw_tuple: 723 return parseTupleType(); 724 case Token::kw_vector: 725 return parseVectorType(); 726 // integer-type 727 case Token::inttype: { 728 auto width = getToken().getIntTypeBitwidth(); 729 if (!width.hasValue()) 730 return (emitError("invalid integer width"), nullptr); 731 auto loc = getEncodedSourceLocation(getToken().getLoc()); 732 consumeToken(Token::inttype); 733 return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); 734 } 735 736 // float-type 737 case Token::kw_bf16: 738 consumeToken(Token::kw_bf16); 739 return builder.getBF16Type(); 740 case Token::kw_f16: 741 consumeToken(Token::kw_f16); 742 return builder.getF16Type(); 743 case Token::kw_f32: 744 consumeToken(Token::kw_f32); 745 return builder.getF32Type(); 746 case Token::kw_f64: 747 consumeToken(Token::kw_f64); 748 return builder.getF64Type(); 
749 750 // index-type 751 case Token::kw_index: 752 consumeToken(Token::kw_index); 753 return builder.getIndexType(); 754 755 // none-type 756 case Token::kw_none: 757 consumeToken(Token::kw_none); 758 return builder.getNoneType(); 759 760 // extended type 761 case Token::exclamation_identifier: 762 return parseExtendedType(); 763 } 764 } 765 766 /// Parse a tensor type. 767 /// 768 /// tensor-type ::= `tensor` `<` dimension-list element-type `>` 769 /// dimension-list ::= dimension-list-ranked | `*x` 770 /// 771 Type Parser::parseTensorType() { 772 consumeToken(Token::kw_tensor); 773 774 if (parseToken(Token::less, "expected '<' in tensor type")) 775 return nullptr; 776 777 bool isUnranked; 778 SmallVector<int64_t, 4> dimensions; 779 780 if (consumeIf(Token::star)) { 781 // This is an unranked tensor type. 782 isUnranked = true; 783 784 if (parseXInDimensionList()) 785 return nullptr; 786 787 } else { 788 isUnranked = false; 789 if (parseDimensionListRanked(dimensions)) 790 return nullptr; 791 } 792 793 // Parse the element type. 794 auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); 795 auto elementType = parseType(); 796 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 797 return nullptr; 798 799 if (isUnranked) 800 return UnrankedTensorType::getChecked(elementType, typeLocation); 801 return RankedTensorType::getChecked(dimensions, elementType, typeLocation); 802 } 803 804 /// Parse a tuple type. 805 /// 806 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 807 /// 808 Type Parser::parseTupleType() { 809 consumeToken(Token::kw_tuple); 810 811 // Parse the '<'. 812 if (parseToken(Token::less, "expected '<' in tuple type")) 813 return nullptr; 814 815 // Check for an empty tuple by directly parsing '>'. 816 if (consumeIf(Token::greater)) 817 return TupleType::get(getContext()); 818 819 // Parse the element types and the '>'. 
820 SmallVector<Type, 4> types; 821 if (parseTypeListNoParens(types) || 822 parseToken(Token::greater, "expected '>' in tuple type")) 823 return nullptr; 824 825 return TupleType::get(types, getContext()); 826 } 827 828 /// Parse a vector type. 829 /// 830 /// vector-type ::= `vector` `<` static-dimension-list primitive-type `>` 831 /// static-dimension-list ::= (decimal-literal `x`)+ 832 /// 833 VectorType Parser::parseVectorType() { 834 consumeToken(Token::kw_vector); 835 836 if (parseToken(Token::less, "expected '<' in vector type")) 837 return nullptr; 838 839 SmallVector<int64_t, 4> dimensions; 840 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) 841 return nullptr; 842 if (dimensions.empty()) 843 return (emitError("expected dimension size in vector type"), nullptr); 844 845 // Parse the element type. 846 auto typeLoc = getToken().getLoc(); 847 auto elementType = parseType(); 848 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 849 return nullptr; 850 851 return VectorType::getChecked(dimensions, elementType, 852 getEncodedSourceLocation(typeLoc)); 853 } 854 855 /// Parse a dimension list of a tensor or memref type. This populates the 856 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 857 /// errors out on `?` otherwise. 
858 /// 859 /// dimension-list-ranked ::= (dimension `x`)* 860 /// dimension ::= `?` | decimal-literal 861 /// 862 /// When `allowDynamic` is not set, this can be also used to parse 863 /// 864 /// static-dimension-list ::= (decimal-literal `x`)* 865 ParseResult 866 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 867 bool allowDynamic) { 868 while (getToken().isAny(Token::integer, Token::question)) { 869 if (consumeIf(Token::question)) { 870 if (!allowDynamic) 871 return emitError("expected static shape"); 872 dimensions.push_back(-1); 873 } else { 874 // Hexadecimal integer literals (starting with `0x`) are not allowed in 875 // aggregate type declarations. Therefore, `0xf32` should be processed as 876 // a sequence of separate elements `0`, `x`, `f32`. 877 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 878 // We can get here only if the token is an integer literal. Hexadecimal 879 // integer literals can only start with `0x` (`1x` wouldn't lex as a 880 // literal, just `1` would, at which point we don't get into this 881 // branch). 882 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 883 dimensions.push_back(0); 884 state.lex.resetPointer(getTokenSpelling().data() + 1); 885 consumeToken(); 886 } else { 887 // Make sure this integer value is in bound and valid. 888 auto dimension = getToken().getUnsignedIntegerValue(); 889 if (!dimension.hasValue()) 890 return emitError("invalid dimension"); 891 dimensions.push_back((int64_t)dimension.getValue()); 892 consumeToken(Token::integer); 893 } 894 } 895 896 // Make sure we have an 'x' or something like 'xbf32'. 897 if (parseXInDimensionList()) 898 return failure(); 899 } 900 901 return success(); 902 } 903 904 /// Parse an 'x' token in a dimension list, handling the case where the x is 905 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 906 /// token. 
ParseResult Parser::parseXInDimensionList() {
  // The 'x' is lexed as (the start of) a bare identifier; anything else is an
  // error.
  if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
    return emitError("expected 'x' in dimension list");

  // If we had a prefix of 'x', lex the next token immediately after the 'x'.
  // The lexer is rewound first so that the consumeToken call below picks up
  // the remainder of the identifier as the next token.
  if (getTokenSpelling().size() != 1)
    state.lex.resetPointer(getTokenSpelling().data() + 1);

  // Consume the 'x'.
  consumeToken(Token::bare_identifier);

  return success();
}

//===----------------------------------------------------------------------===//
// Attribute parsing.
//===----------------------------------------------------------------------===//

/// Parse an arbitrary attribute.  The kind of attribute is selected by the
/// current token; a null attribute is returned on failure.  `type`, when
/// non-null, is forwarded to the numeric and extended attribute parsers and
/// used as the type of a string attribute.
///
///  attribute-value ::= `unit`
///                    | bool-literal
///                    | integer-literal (`:` (index-type | integer-type))?
///                    | float-literal (`:` float-type)?
///                    | string-literal (`:` type)?
///                    | type
///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
///                    | symbol-ref-id
///                    | `dense` `<` attribute-value `>` `:`
///                      (tensor-type | vector-type)
///                    | `sparse` `<` attribute-value `,` attribute-value `>`
///                      `:` (tensor-type | vector-type)
///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
///                      `>` `:` (tensor-type | vector-type)
///                    | extended-attribute
///
/// In addition to the forms listed above, the switch below also accepts affine
/// map / integer set references (introduced by `(`) and `loc` attributes.
///
Attribute Parser::parseAttribute(Type type) {
  switch (getToken().getKind()) {
  // Parse an AffineMap or IntegerSet attribute.
  case Token::l_paren: {
    // Try to parse an affine map or an integer set reference.
    AffineMap map;
    IntegerSet set;
    if (parseAffineMapOrIntegerSetReference(map, set))
      return nullptr;
    if (map)
      return builder.getAffineMapAttr(map);
    // The parse succeeded without producing a map, so it is expected to have
    // produced an integer set.
    assert(set);
    return builder.getIntegerSetAttr(set);
  }

  // Parse an array attribute.
  case Token::l_square: {
    consumeToken(Token::l_square);

    SmallVector<Attribute, 4> elements;
    // Each element is itself an arbitrary attribute; a null element signals a
    // parse failure.
    auto parseElt = [&]() -> ParseResult {
      elements.push_back(parseAttribute());
      return elements.back() ? success() : failure();
    };

    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
      return nullptr;
    return builder.getArrayAttr(elements);
  }

  // Parse a boolean attribute.
  case Token::kw_false:
    consumeToken(Token::kw_false);
    return builder.getBoolAttr(false);
  case Token::kw_true:
    consumeToken(Token::kw_true);
    return builder.getBoolAttr(true);

  // Parse a dense elements attribute.
  case Token::kw_dense:
    return parseDenseElementsAttr();

  // Parse a dictionary attribute.
  case Token::l_brace: {
    SmallVector<NamedAttribute, 4> elements;
    if (parseAttributeDict(elements))
      return nullptr;
    return builder.getDictionaryAttr(elements);
  }

  // Parse an extended attribute, i.e. alias or dialect attribute.
  case Token::hash_identifier:
    return parseExtendedAttr(type);

  // Parse floating point and integer attributes.
  case Token::floatliteral:
    return parseFloatAttr(type, /*isNegative=*/false);
  case Token::integer:
    return parseDecOrHexAttr(type, /*isNegative=*/false);
  case Token::minus: {
    // A leading '-' must be followed directly by an integer or float literal;
    // the sign is passed on to the literal parser.
    consumeToken(Token::minus);
    if (getToken().is(Token::integer))
      return parseDecOrHexAttr(type, /*isNegative=*/true);
    if (getToken().is(Token::floatliteral))
      return parseFloatAttr(type, /*isNegative=*/true);

    return (emitError("expected constant integer or floating point value"),
            nullptr);
  }

  // Parse a location attribute.
  case Token::kw_loc: {
    LocationAttr attr;
    return failed(parseLocation(attr)) ? Attribute() : attr;
  }

  // Parse an opaque elements attribute.
  case Token::kw_opaque:
    return parseOpaqueElementsAttr();

  // Parse a sparse elements attribute.
  case Token::kw_sparse:
    return parseSparseElementsAttr();

  // Parse a string attribute.
  case Token::string: {
    auto val = getToken().getStringValue();
    consumeToken(Token::string);
    // Parse the optional trailing colon type if one wasn't explicitly provided.
    // A ':' that is not followed by a valid type is an error.
    if (!type && consumeIf(Token::colon) && !(type = parseType()))
      return Attribute();

    // Without any type, the string attribute is built from the context alone.
    return type ? StringAttr::get(val, type)
                : StringAttr::get(val, getContext());
  }

  // Parse a symbol reference attribute.
  case Token::at_identifier: {
    auto nameStr = getTokenSpelling();
    consumeToken(Token::at_identifier);
    // Drop the leading character of the token spelling to get the symbol name.
    return builder.getSymbolRefAttr(nameStr.drop_front());
  }

  // Parse a 'unit' attribute.
  case Token::kw_unit:
    consumeToken(Token::kw_unit);
    return builder.getUnitAttr();

  default:
    // Parse a type attribute.
    // Note: this local `type` shadows the `type` parameter.
    if (Type type = parseType())
      return builder.getTypeAttr(type);
    return nullptr;
  }
}

/// Attribute dictionary.
///
///   attribute-dict ::= `{` `}`
///                    | `{` attribute-entry (`,` attribute-entry)* `}`
///   attribute-entry ::= bare-id `=` attribute-value
///
ParseResult
Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
  if (!consumeIf(Token::l_brace))
    return failure();

  auto parseElt = [&]() -> ParseResult {
    // We allow keywords as attribute names.
    if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
        !getToken().isKeyword())
      return emitError("expected attribute name");
    Identifier nameId = builder.getIdentifier(getTokenSpelling());
    consumeToken();

    // Try to parse the '=' for the attribute value.
    if (!consumeIf(Token::equal)) {
      // If there is no '=', we treat this as a unit attribute.
1082 attributes.push_back({nameId, builder.getUnitAttr()}); 1083 return success(); 1084 } 1085 1086 auto attr = parseAttribute(); 1087 if (!attr) 1088 return failure(); 1089 1090 attributes.push_back({nameId, attr}); 1091 return success(); 1092 }; 1093 1094 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) 1095 return failure(); 1096 1097 return success(); 1098 } 1099 1100 /// Parse an extended attribute. 1101 /// 1102 /// extended-attribute ::= (dialect-attribute | attribute-alias) 1103 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 1104 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 1105 /// attribute-alias ::= `#` alias-name 1106 /// 1107 Attribute Parser::parseExtendedAttr(Type type) { 1108 Attribute attr = parseExtendedSymbol<Attribute>( 1109 *this, Token::hash_identifier, state.attributeAliasDefinitions, 1110 [&](StringRef dialectName, StringRef symbolData, 1111 Location loc) -> Attribute { 1112 // Parse an optional trailing colon type. 1113 Type attrType = type; 1114 if (consumeIf(Token::colon) && !(attrType = parseType())) 1115 return Attribute(); 1116 1117 // If we found a registered dialect, then ask it to parse the attribute. 1118 if (auto *dialect = state.context->getRegisteredDialect(dialectName)) 1119 return dialect->parseAttribute(symbolData, attrType, loc); 1120 1121 // Otherwise, form a new opaque attribute. 1122 return OpaqueAttr::getChecked( 1123 Identifier::get(dialectName, state.context), symbolData, 1124 attrType ? attrType : NoneType::get(state.context), loc); 1125 }); 1126 1127 // Ensure that the attribute has the same type as requested. 1128 if (attr && type && attr.getType() != type) { 1129 emitError("attribute type different than expected: expected ") 1130 << type << ", but got " << attr.getType(); 1131 return nullptr; 1132 } 1133 return attr; 1134 } 1135 1136 /// Parse a float attribute. 
1137 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 1138 auto val = getToken().getFloatingPointValue(); 1139 if (!val.hasValue()) 1140 return (emitError("floating point value too large for attribute"), nullptr); 1141 consumeToken(Token::floatliteral); 1142 if (!type) { 1143 // Default to F64 when no type is specified. 1144 if (!consumeIf(Token::colon)) 1145 type = builder.getF64Type(); 1146 else if (!(type = parseType())) 1147 return nullptr; 1148 } 1149 if (!type.isa<FloatType>()) 1150 return (emitError("floating point value not valid for specified type"), 1151 nullptr); 1152 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue()); 1153 } 1154 1155 /// Construct a float attribute bitwise equivalent to the integer literal. 1156 static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type, 1157 uint64_t value) { 1158 int width = type.getIntOrFloatBitWidth(); 1159 APInt apInt(width, value); 1160 if (apInt != value) { 1161 p->emitError("hexadecimal float constant out of range for type"); 1162 return nullptr; 1163 } 1164 APFloat apFloat(type.getFloatSemantics(), apInt); 1165 return p->builder.getFloatAttr(type, apFloat); 1166 } 1167 1168 /// Parse a decimal or a hexadecimal literal, which can be either an integer 1169 /// or a float attribute. 1170 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 1171 auto val = getToken().getUInt64IntegerValue(); 1172 if (!val.hasValue()) 1173 return (emitError("integer constant out of range for attribute"), nullptr); 1174 1175 // Remember if the literal is hexadecimal. 1176 StringRef spelling = getToken().getSpelling(); 1177 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 1178 1179 consumeToken(Token::integer); 1180 if (!type) { 1181 // Default to i64 if not type is specified. 
1182 if (!consumeIf(Token::colon)) 1183 type = builder.getIntegerType(64); 1184 else if (!(type = parseType())) 1185 return nullptr; 1186 } 1187 1188 // Hexadecimal representation of float literals is not supported for bfloat16. 1189 // When supported, the literal should be unsigned. 1190 auto floatType = type.dyn_cast<FloatType>(); 1191 if (floatType && !type.isBF16()) { 1192 if (isNegative) { 1193 emitError("hexadecimal float literal should not have a leading minus"); 1194 return nullptr; 1195 } 1196 if (!isHex) { 1197 emitError("unexpected decimal integer literal for a float attribute") 1198 .attachNote() 1199 << "add a trailing dot to make the literal a float"; 1200 return nullptr; 1201 } 1202 1203 // Construct a float attribute bitwise equivalent to the integer literal. 1204 return buildHexadecimalFloatLiteral(this, floatType, *val); 1205 } 1206 1207 if (!type.isIntOrIndex()) 1208 return (emitError("integer literal not valid for specified type"), nullptr); 1209 1210 // Parse the integer literal. 1211 int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth(); 1212 APInt apInt(width, *val, isNegative); 1213 if (apInt != *val) 1214 return (emitError("integer constant out of range for attribute"), nullptr); 1215 1216 // Otherwise construct an integer attribute. 1217 if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0) 1218 return (emitError("integer constant out of range for attribute"), nullptr); 1219 1220 return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); 1221 } 1222 1223 /// Parse an opaque elements attribute. 
1224 Attribute Parser::parseOpaqueElementsAttr() { 1225 consumeToken(Token::kw_opaque); 1226 if (parseToken(Token::less, "expected '<' after 'opaque'")) 1227 return nullptr; 1228 1229 if (getToken().isNot(Token::string)) 1230 return (emitError("expected dialect namespace"), nullptr); 1231 1232 auto name = getToken().getStringValue(); 1233 auto *dialect = builder.getContext()->getRegisteredDialect(name); 1234 // TODO(shpeisman): Allow for having an unknown dialect on an opaque 1235 // attribute. Otherwise, it can't be roundtripped without having the dialect 1236 // registered. 1237 if (!dialect) 1238 return (emitError("no registered dialect with namespace '" + name + "'"), 1239 nullptr); 1240 1241 consumeToken(Token::string); 1242 if (parseToken(Token::comma, "expected ','")) 1243 return nullptr; 1244 1245 if (getToken().getKind() != Token::string) 1246 return (emitError("opaque string should start with '0x'"), nullptr); 1247 1248 auto val = getToken().getStringValue(); 1249 if (val.size() < 2 || val[0] != '0' || val[1] != 'x') 1250 return (emitError("opaque string should start with '0x'"), nullptr); 1251 1252 val = val.substr(2); 1253 if (!llvm::all_of(val, llvm::isHexDigit)) 1254 return (emitError("opaque string only contains hex digits"), nullptr); 1255 1256 consumeToken(Token::string); 1257 if (parseToken(Token::greater, "expected '>'") || 1258 parseToken(Token::colon, "expected ':'")) 1259 return nullptr; 1260 1261 auto type = parseElementsLiteralType(); 1262 if (!type) 1263 return nullptr; 1264 1265 return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val)); 1266 } 1267 1268 namespace { 1269 class TensorLiteralParser { 1270 public: 1271 TensorLiteralParser(Parser &p) : p(p) {} 1272 1273 ParseResult parse() { 1274 if (p.getToken().is(Token::l_square)) 1275 return parseList(shape); 1276 return parseElement(); 1277 } 1278 1279 /// Build a dense attribute instance with the parsed elements and the given 1280 /// shaped type. 
1281 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type); 1282 1283 ArrayRef<int64_t> getShape() const { return shape; } 1284 1285 private: 1286 enum class ElementKind { Boolean, Integer, Float }; 1287 1288 /// Return a string to represent the given element kind. 1289 const char *getElementKindStr(ElementKind kind) { 1290 switch (kind) { 1291 case ElementKind::Boolean: 1292 return "'boolean'"; 1293 case ElementKind::Integer: 1294 return "'integer'"; 1295 case ElementKind::Float: 1296 return "'float'"; 1297 } 1298 llvm_unreachable("unknown element kind"); 1299 } 1300 1301 /// Build a Dense Integer attribute for the given type. 1302 DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, 1303 IntegerType eltTy); 1304 1305 /// Build a Dense Float attribute for the given type. 1306 DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, 1307 FloatType eltTy); 1308 1309 /// Parse a single element, returning failure if it isn't a valid element 1310 /// literal. For example: 1311 /// parseElement(1) -> Success, 1 1312 /// parseElement([1]) -> Failure 1313 ParseResult parseElement(); 1314 1315 /// Parse a list of either lists or elements, returning the dimensions of the 1316 /// parsed sub-tensors in dims. For example: 1317 /// parseList([1, 2, 3]) -> Success, [3] 1318 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1319 /// parseList([[1, 2], 3]) -> Failure 1320 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1321 ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims); 1322 1323 Parser &p; 1324 1325 /// The shape inferred from the parsed elements. 1326 SmallVector<int64_t, 4> shape; 1327 1328 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 1329 std::vector<std::pair<bool, Token>> storage; 1330 1331 /// A flag that indicates the type of elements that have been parsed. 
1332 llvm::Optional<ElementKind> knownEltKind; 1333 }; 1334 } // namespace 1335 1336 /// Build a dense attribute instance with the parsed elements and the given 1337 /// shaped type. 1338 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, 1339 ShapedType type) { 1340 // Check that the parsed storage size has the same number of elements to the 1341 // type, or is a known splat. 1342 if (!shape.empty() && getShape() != type.getShape()) { 1343 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 1344 << "]) does not match type ([" << type.getShape() << "])"; 1345 return nullptr; 1346 } 1347 1348 // If the type is an integer, build a set of APInt values from the storage 1349 // with the correct bitwidth. 1350 if (auto intTy = type.getElementType().dyn_cast<IntegerType>()) 1351 return getIntAttr(loc, type, intTy); 1352 1353 // Otherwise, this must be a floating point type. 1354 auto floatTy = type.getElementType().dyn_cast<FloatType>(); 1355 if (!floatTy) { 1356 p.emitError(loc) << "expected floating-point or integer element type, got " 1357 << type.getElementType(); 1358 return nullptr; 1359 } 1360 return getFloatAttr(loc, type, floatTy); 1361 } 1362 1363 /// Build a Dense Integer attribute for the given type. 1364 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, 1365 ShapedType type, 1366 IntegerType eltTy) { 1367 std::vector<APInt> intElements; 1368 intElements.reserve(storage.size()); 1369 for (const auto &signAndToken : storage) { 1370 bool isNegative = signAndToken.first; 1371 const Token &token = signAndToken.second; 1372 1373 // Check to see if floating point values were parsed. 
1374 if (token.is(Token::floatliteral)) { 1375 p.emitError() << "expected integer elements, but parsed floating-point"; 1376 return nullptr; 1377 } 1378 1379 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 1380 "unexpected token type"); 1381 if (token.isAny(Token::kw_true, Token::kw_false)) { 1382 if (!eltTy.isInteger(1)) 1383 p.emitError() << "expected i1 type for 'true' or 'false' values"; 1384 APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), 1385 /*isSigned=*/false); 1386 intElements.push_back(apInt); 1387 continue; 1388 } 1389 1390 // Create APInt values for each element with the correct bitwidth. 1391 auto val = token.getUInt64IntegerValue(); 1392 if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 1393 : (int64_t)val.getValue() < 0)) { 1394 p.emitError(token.getLoc(), 1395 "integer constant out of range for attribute"); 1396 return nullptr; 1397 } 1398 APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); 1399 if (apInt != val.getValue()) 1400 return (p.emitError("integer constant out of range for type"), nullptr); 1401 intElements.push_back(isNegative ? -apInt : apInt); 1402 } 1403 1404 return DenseElementsAttr::get(type, intElements); 1405 } 1406 1407 /// Build a Dense Float attribute for the given type. 1408 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, 1409 ShapedType type, 1410 FloatType eltTy) { 1411 std::vector<Attribute> floatValues; 1412 floatValues.reserve(storage.size()); 1413 for (const auto &signAndToken : storage) { 1414 bool isNegative = signAndToken.first; 1415 const Token &token = signAndToken.second; 1416 1417 // Handle hexadecimal float literals. 
1418 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { 1419 if (isNegative) { 1420 p.emitError(token.getLoc()) 1421 << "hexadecimal float literal should not have a leading minus"; 1422 return nullptr; 1423 } 1424 auto val = token.getUInt64IntegerValue(); 1425 if (!val.hasValue()) { 1426 p.emitError("hexadecimal float constant out of range for attribute"); 1427 return nullptr; 1428 } 1429 FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val); 1430 if (!attr) 1431 return nullptr; 1432 floatValues.push_back(attr); 1433 continue; 1434 } 1435 1436 // Check to see if any decimal integers or booleans were parsed. 1437 if (!token.is(Token::floatliteral)) { 1438 p.emitError() << "expected floating-point elements, but parsed integer"; 1439 return nullptr; 1440 } 1441 1442 // Build the float values from tokens. 1443 auto val = token.getFloatingPointValue(); 1444 if (!val.hasValue()) { 1445 p.emitError("floating point value too large for attribute"); 1446 return nullptr; 1447 } 1448 floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val)); 1449 } 1450 1451 return DenseElementsAttr::get(type, floatValues); 1452 } 1453 1454 ParseResult TensorLiteralParser::parseElement() { 1455 switch (p.getToken().getKind()) { 1456 // Parse a boolean element. 1457 case Token::kw_true: 1458 case Token::kw_false: 1459 case Token::floatliteral: 1460 case Token::integer: 1461 storage.emplace_back(/*isNegative=*/false, p.getToken()); 1462 p.consumeToken(); 1463 break; 1464 1465 // Parse a signed integer or a negative floating-point element. 
1466 case Token::minus: 1467 p.consumeToken(Token::minus); 1468 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 1469 return p.emitError("expected integer or floating point literal"); 1470 storage.emplace_back(/*isNegative=*/true, p.getToken()); 1471 p.consumeToken(); 1472 break; 1473 1474 default: 1475 return p.emitError("expected element literal of primitive type"); 1476 } 1477 1478 return success(); 1479 } 1480 1481 /// Parse a list of either lists or elements, returning the dimensions of the 1482 /// parsed sub-tensors in dims. For example: 1483 /// parseList([1, 2, 3]) -> Success, [3] 1484 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 1485 /// parseList([[1, 2], 3]) -> Failure 1486 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 1487 ParseResult 1488 TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) { 1489 p.consumeToken(Token::l_square); 1490 1491 auto checkDims = 1492 [&](const llvm::SmallVectorImpl<int64_t> &prevDims, 1493 const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult { 1494 if (prevDims == newDims) 1495 return success(); 1496 return p.emitError("tensor literal is invalid; ranks are not consistent " 1497 "between elements"); 1498 }; 1499 1500 bool first = true; 1501 llvm::SmallVector<int64_t, 4> newDims; 1502 unsigned size = 0; 1503 auto parseCommaSeparatedList = [&]() -> ParseResult { 1504 llvm::SmallVector<int64_t, 4> thisDims; 1505 if (p.getToken().getKind() == Token::l_square) { 1506 if (parseList(thisDims)) 1507 return failure(); 1508 } else if (parseElement()) { 1509 return failure(); 1510 } 1511 ++size; 1512 if (!first) 1513 return checkDims(newDims, thisDims); 1514 newDims = thisDims; 1515 first = false; 1516 return success(); 1517 }; 1518 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) 1519 return failure(); 1520 1521 // Return the sublists' dimensions with 'size' prepended. 
1522 dims.clear(); 1523 dims.push_back(size); 1524 dims.append(newDims.begin(), newDims.end()); 1525 return success(); 1526 } 1527 1528 /// Parse a dense elements attribute. 1529 Attribute Parser::parseDenseElementsAttr() { 1530 consumeToken(Token::kw_dense); 1531 if (parseToken(Token::less, "expected '<' after 'dense'")) 1532 return nullptr; 1533 1534 // Parse the literal data. 1535 TensorLiteralParser literalParser(*this); 1536 if (literalParser.parse()) 1537 return nullptr; 1538 1539 if (parseToken(Token::greater, "expected '>'") || 1540 parseToken(Token::colon, "expected ':'")) 1541 return nullptr; 1542 1543 auto typeLoc = getToken().getLoc(); 1544 auto type = parseElementsLiteralType(); 1545 if (!type) 1546 return nullptr; 1547 return literalParser.getAttr(typeLoc, type); 1548 } 1549 1550 /// Shaped type for elements attribute. 1551 /// 1552 /// elements-literal-type ::= vector-type | ranked-tensor-type 1553 /// 1554 /// This method also checks the type has static shape. 1555 ShapedType Parser::parseElementsLiteralType() { 1556 auto type = parseType(); 1557 if (!type) 1558 return nullptr; 1559 1560 if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) { 1561 emitError("elements literal must be a ranked tensor or vector type"); 1562 return nullptr; 1563 } 1564 1565 auto sType = type.cast<ShapedType>(); 1566 if (!sType.hasStaticShape()) 1567 return (emitError("elements literal type must have static shape"), nullptr); 1568 1569 return sType; 1570 } 1571 1572 /// Parse a sparse elements attribute. 1573 Attribute Parser::parseSparseElementsAttr() { 1574 consumeToken(Token::kw_sparse); 1575 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 1576 return nullptr; 1577 1578 /// Parse indices 1579 auto indicesLoc = getToken().getLoc(); 1580 TensorLiteralParser indiceParser(*this); 1581 if (indiceParser.parse()) 1582 return nullptr; 1583 1584 if (parseToken(Token::comma, "expected ','")) 1585 return nullptr; 1586 1587 /// Parse values. 
1588 auto valuesLoc = getToken().getLoc(); 1589 TensorLiteralParser valuesParser(*this); 1590 if (valuesParser.parse()) 1591 return nullptr; 1592 1593 if (parseToken(Token::greater, "expected '>'") || 1594 parseToken(Token::colon, "expected ':'")) 1595 return nullptr; 1596 1597 auto type = parseElementsLiteralType(); 1598 if (!type) 1599 return nullptr; 1600 1601 // If the indices are a splat, i.e. the literal parser parsed an element and 1602 // not a list, we set the shape explicitly. The indices are represented by a 1603 // 2-dimensional shape where the second dimension is the rank of the type. 1604 // Given that the parsed indices is a splat, we know that we only have one 1605 // indice and thus one for the first dimension. 1606 auto indiceEltType = builder.getIntegerType(64); 1607 ShapedType indicesType; 1608 if (indiceParser.getShape().empty()) { 1609 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1610 } else { 1611 // Otherwise, set the shape to the one parsed by the literal parser. 1612 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1613 } 1614 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1615 1616 // If the values are a splat, set the shape explicitly based on the number of 1617 // indices. The number of indices is encoded in the first dimension of the 1618 // indice shape type. 1619 auto valuesEltType = type.getElementType(); 1620 ShapedType valuesType = 1621 valuesParser.getShape().empty() 1622 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1623 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1624 auto values = valuesParser.getAttr(valuesLoc, valuesType); 1625 1626 /// Sanity check. 
1627 if (valuesType.getRank() != 1) 1628 return (emitError("expected 1-d tensor for values"), nullptr); 1629 1630 auto sameShape = (indicesType.getRank() == 1) || 1631 (type.getRank() == indicesType.getDimSize(1)); 1632 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0); 1633 if (!sameShape || !sameElementNum) { 1634 emitError() << "expected shape ([" << type.getShape() 1635 << "]); inferred shape of indices literal ([" 1636 << indicesType.getShape() 1637 << "]); inferred shape of values literal ([" 1638 << valuesType.getShape() << "])"; 1639 return nullptr; 1640 } 1641 1642 // Build the sparse elements attribute by the indices and values. 1643 return SparseElementsAttr::get(type, indices, values); 1644 } 1645 1646 //===----------------------------------------------------------------------===// 1647 // Location parsing. 1648 //===----------------------------------------------------------------------===// 1649 1650 /// Parse a location. 1651 /// 1652 /// location ::= `loc` inline-location 1653 /// inline-location ::= '(' location-inst ')' 1654 /// 1655 ParseResult Parser::parseLocation(LocationAttr &loc) { 1656 // Check for 'loc' identifier. 1657 if (parseToken(Token::kw_loc, "expected 'loc' keyword")) 1658 return emitError(); 1659 1660 // Parse the inline-location. 1661 if (parseToken(Token::l_paren, "expected '(' in inline location") || 1662 parseLocationInstance(loc) || 1663 parseToken(Token::r_paren, "expected ')' in inline location")) 1664 return failure(); 1665 return success(); 1666 } 1667 1668 /// Specific location instances. 
1669 /// 1670 /// location-inst ::= filelinecol-location | 1671 /// name-location | 1672 /// callsite-location | 1673 /// fused-location | 1674 /// unknown-location 1675 /// filelinecol-location ::= string-literal ':' integer-literal 1676 /// ':' integer-literal 1677 /// name-location ::= string-literal 1678 /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')' 1679 /// fused-location ::= fused ('<' attribute-value '>')? 1680 /// '[' location-inst (location-inst ',')* ']' 1681 /// unknown-location ::= 'unknown' 1682 /// 1683 ParseResult Parser::parseLocationInstance(LocationAttr &loc) { 1684 auto *ctx = getContext(); 1685 1686 // Handle either name or filelinecol locations. 1687 if (getToken().is(Token::string)) { 1688 auto str = getToken().getStringValue(); 1689 consumeToken(Token::string); 1690 1691 // If the next token is ':' this is a filelinecol location. 1692 if (consumeIf(Token::colon)) { 1693 // Parse the line number. 1694 if (getToken().isNot(Token::integer)) 1695 return emitError("expected integer line number in FileLineColLoc"); 1696 auto line = getToken().getUnsignedIntegerValue(); 1697 if (!line.hasValue()) 1698 return emitError("expected integer line number in FileLineColLoc"); 1699 consumeToken(Token::integer); 1700 1701 // Parse the ':'. 1702 if (parseToken(Token::colon, "expected ':' in FileLineColLoc")) 1703 return failure(); 1704 1705 // Parse the column number. 1706 if (getToken().isNot(Token::integer)) 1707 return emitError("expected integer column number in FileLineColLoc"); 1708 auto column = getToken().getUnsignedIntegerValue(); 1709 if (!column.hasValue()) 1710 return emitError("expected integer column number in FileLineColLoc"); 1711 consumeToken(Token::integer); 1712 1713 loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx); 1714 return success(); 1715 } 1716 1717 // Otherwise, this is a NameLoc. 1718 1719 // Check for a child location. 
1720 if (consumeIf(Token::l_paren)) { 1721 auto childSourceLoc = getToken().getLoc(); 1722 1723 // Parse the child location. 1724 LocationAttr childLoc; 1725 if (parseLocationInstance(childLoc)) 1726 return failure(); 1727 1728 // The child must not be another NameLoc. 1729 if (childLoc.isa<NameLoc>()) 1730 return emitError(childSourceLoc, 1731 "child of NameLoc cannot be another NameLoc"); 1732 loc = NameLoc::get(Identifier::get(str, ctx), childLoc, ctx); 1733 1734 // Parse the closing ')'. 1735 if (parseToken(Token::r_paren, 1736 "expected ')' after child location of NameLoc")) 1737 return failure(); 1738 } else { 1739 loc = NameLoc::get(Identifier::get(str, ctx), ctx); 1740 } 1741 1742 return success(); 1743 } 1744 1745 // Check for a 'unknown' for an unknown location. 1746 if (getToken().is(Token::bare_identifier) && 1747 getToken().getSpelling() == "unknown") { 1748 consumeToken(Token::bare_identifier); 1749 loc = UnknownLoc::get(ctx); 1750 return success(); 1751 } 1752 1753 // If the token is 'fused', then this is a fused location. 1754 if (getToken().is(Token::bare_identifier) && 1755 getToken().getSpelling() == "fused") { 1756 consumeToken(Token::bare_identifier); 1757 1758 // Try to parse the optional metadata. 1759 Attribute metadata; 1760 if (consumeIf(Token::less)) { 1761 metadata = parseAttribute(); 1762 if (!metadata) 1763 return emitError("expected valid attribute metadata"); 1764 // Parse the '>' token. 
1765 if (parseToken(Token::greater, 1766 "expected '>' after fused location metadata")) 1767 return failure(); 1768 } 1769 1770 llvm::SmallVector<Location, 4> locations; 1771 auto parseElt = [&] { 1772 LocationAttr newLoc; 1773 if (parseLocationInstance(newLoc)) 1774 return failure(); 1775 locations.push_back(newLoc); 1776 return success(); 1777 }; 1778 1779 if (parseToken(Token::l_square, "expected '[' in fused location") || 1780 parseCommaSeparatedList(parseElt) || 1781 parseToken(Token::r_square, "expected ']' in fused location")) 1782 return failure(); 1783 1784 // Return the fused location. 1785 loc = FusedLoc::get(locations, metadata, getContext()); 1786 return success(); 1787 } 1788 1789 // Check for the 'callsite' signifying a callsite location. 1790 if (getToken().is(Token::bare_identifier) && 1791 getToken().getSpelling() == "callsite") { 1792 consumeToken(Token::bare_identifier); 1793 1794 // Parse the '('. 1795 if (parseToken(Token::l_paren, "expected '(' in callsite location")) 1796 return failure(); 1797 1798 // Parse the callee location. 1799 LocationAttr calleeLoc; 1800 if (parseLocationInstance(calleeLoc)) 1801 return failure(); 1802 1803 // Parse the 'at'. 1804 if (getToken().isNot(Token::bare_identifier) || 1805 getToken().getSpelling() != "at") 1806 return emitError("expected 'at' in callsite location"); 1807 consumeToken(Token::bare_identifier); 1808 1809 // Parse the caller location. 1810 LocationAttr callerLoc; 1811 if (parseLocationInstance(callerLoc)) 1812 return failure(); 1813 1814 // Parse the ')'. 1815 if (parseToken(Token::r_paren, "expected ')' in callsite location")) 1816 return failure(); 1817 1818 // Return the callsite location. 1819 loc = CallSiteLoc::get(calleeLoc, callerLoc, ctx); 1820 return success(); 1821 } 1822 1823 return emitError("expected location instance"); 1824 } 1825 1826 //===----------------------------------------------------------------------===// 1827 // Affine parsing. 
1828 //===----------------------------------------------------------------------===// 1829 1830 /// Lower precedence ops (all at the same precedence level). LNoOp is false in 1831 /// the boolean sense. 1832 enum AffineLowPrecOp { 1833 /// Null value. 1834 LNoOp, 1835 Add, 1836 Sub 1837 }; 1838 1839 /// Higher precedence ops - all at the same precedence level. HNoOp is false 1840 /// in the boolean sense. 1841 enum AffineHighPrecOp { 1842 /// Null value. 1843 HNoOp, 1844 Mul, 1845 FloorDiv, 1846 CeilDiv, 1847 Mod 1848 }; 1849 1850 namespace { 1851 /// This is a specialized parser for affine structures (affine maps, affine 1852 /// expressions, and integer sets), maintaining the state transient to their 1853 /// bodies. 1854 class AffineParser : public Parser { 1855 public: 1856 AffineParser(ParserState &state, bool allowParsingSSAIds = false, 1857 llvm::function_ref<ParseResult(bool)> parseElement = nullptr) 1858 : Parser(state), allowParsingSSAIds(allowParsingSSAIds), 1859 parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} 1860 1861 AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); 1862 ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); 1863 IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); 1864 ParseResult parseAffineMapOfSSAIds(AffineMap &map); 1865 void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds, 1866 unsigned &numDims); 1867 1868 private: 1869 // Binary affine op parsing. 1870 AffineLowPrecOp consumeIfLowPrecOp(); 1871 AffineHighPrecOp consumeIfHighPrecOp(); 1872 1873 // Identifier lists for polyhedral structures. 
1874 ParseResult parseDimIdList(unsigned &numDims); 1875 ParseResult parseSymbolIdList(unsigned &numSymbols); 1876 ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims, 1877 unsigned &numSymbols); 1878 ParseResult parseIdentifierDefinition(AffineExpr idExpr); 1879 1880 AffineExpr parseAffineExpr(); 1881 AffineExpr parseParentheticalExpr(); 1882 AffineExpr parseNegateExpression(AffineExpr lhs); 1883 AffineExpr parseIntegerExpr(); 1884 AffineExpr parseBareIdExpr(); 1885 AffineExpr parseSSAIdExpr(bool isSymbol); 1886 AffineExpr parseSymbolSSAIdExpr(); 1887 1888 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 1889 AffineExpr rhs, SMLoc opLoc); 1890 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 1891 AffineExpr rhs); 1892 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 1893 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 1894 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 1895 SMLoc llhsOpLoc); 1896 AffineExpr parseAffineConstraint(bool *isEq); 1897 1898 private: 1899 bool allowParsingSSAIds; 1900 llvm::function_ref<ParseResult(bool)> parseElement; 1901 unsigned numDimOperands; 1902 unsigned numSymbolOperands; 1903 SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols; 1904 }; 1905 } // end anonymous namespace 1906 1907 /// Create an affine binary high precedence op expression (mul's, div's, mod). 1908 /// opLoc is the location of the op token to be used to report errors 1909 /// for non-conforming expressions. 1910 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 1911 AffineExpr lhs, AffineExpr rhs, 1912 SMLoc opLoc) { 1913 // TODO: make the error location info accurate. 
1914 switch (op) { 1915 case Mul: 1916 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 1917 emitError(opLoc, "non-affine expression: at least one of the multiply " 1918 "operands has to be either a constant or symbolic"); 1919 return nullptr; 1920 } 1921 return lhs * rhs; 1922 case FloorDiv: 1923 if (!rhs.isSymbolicOrConstant()) { 1924 emitError(opLoc, "non-affine expression: right operand of floordiv " 1925 "has to be either a constant or symbolic"); 1926 return nullptr; 1927 } 1928 return lhs.floorDiv(rhs); 1929 case CeilDiv: 1930 if (!rhs.isSymbolicOrConstant()) { 1931 emitError(opLoc, "non-affine expression: right operand of ceildiv " 1932 "has to be either a constant or symbolic"); 1933 return nullptr; 1934 } 1935 return lhs.ceilDiv(rhs); 1936 case Mod: 1937 if (!rhs.isSymbolicOrConstant()) { 1938 emitError(opLoc, "non-affine expression: right operand of mod " 1939 "has to be either a constant or symbolic"); 1940 return nullptr; 1941 } 1942 return lhs % rhs; 1943 case HNoOp: 1944 llvm_unreachable("can't create affine expression for null high prec op"); 1945 return nullptr; 1946 } 1947 llvm_unreachable("Unknown AffineHighPrecOp"); 1948 } 1949 1950 /// Create an affine binary low precedence op expression (add, sub). 1951 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 1952 AffineExpr lhs, AffineExpr rhs) { 1953 switch (op) { 1954 case AffineLowPrecOp::Add: 1955 return lhs + rhs; 1956 case AffineLowPrecOp::Sub: 1957 return lhs - rhs; 1958 case AffineLowPrecOp::LNoOp: 1959 llvm_unreachable("can't create affine expression for null low prec op"); 1960 return nullptr; 1961 } 1962 llvm_unreachable("Unknown AffineLowPrecOp"); 1963 } 1964 1965 /// Consume this token if it is a lower precedence affine op (there are only 1966 /// two precedence levels). 
1967 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 1968 switch (getToken().getKind()) { 1969 case Token::plus: 1970 consumeToken(Token::plus); 1971 return AffineLowPrecOp::Add; 1972 case Token::minus: 1973 consumeToken(Token::minus); 1974 return AffineLowPrecOp::Sub; 1975 default: 1976 return AffineLowPrecOp::LNoOp; 1977 } 1978 } 1979 1980 /// Consume this token if it is a higher precedence affine op (there are only 1981 /// two precedence levels) 1982 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 1983 switch (getToken().getKind()) { 1984 case Token::star: 1985 consumeToken(Token::star); 1986 return Mul; 1987 case Token::kw_floordiv: 1988 consumeToken(Token::kw_floordiv); 1989 return FloorDiv; 1990 case Token::kw_ceildiv: 1991 consumeToken(Token::kw_ceildiv); 1992 return CeilDiv; 1993 case Token::kw_mod: 1994 consumeToken(Token::kw_mod); 1995 return Mod; 1996 default: 1997 return HNoOp; 1998 } 1999 } 2000 2001 /// Parse a high precedence op expression list: mul, div, and mod are high 2002 /// precedence binary ops, i.e., parse a 2003 /// expr_1 op_1 expr_2 op_2 ... expr_n 2004 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 2005 /// All affine binary ops are left associative. 2006 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 2007 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 2008 /// null. llhsOpLoc is the location of the llhsOp token that will be used to 2009 /// report an error for non-conforming expressions. 2010 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 2011 AffineHighPrecOp llhsOp, 2012 SMLoc llhsOpLoc) { 2013 AffineExpr lhs = parseAffineOperandExpr(llhs); 2014 if (!lhs) 2015 return nullptr; 2016 2017 // Found an LHS. Parse the remaining expression. 
2018 auto opLoc = getToken().getLoc(); 2019 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 2020 if (llhs) { 2021 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 2022 if (!expr) 2023 return nullptr; 2024 return parseAffineHighPrecOpExpr(expr, op, opLoc); 2025 } 2026 // No LLHS, get RHS 2027 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 2028 } 2029 2030 // This is the last operand in this expression. 2031 if (llhs) 2032 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 2033 2034 // No llhs, 'lhs' itself is the expression. 2035 return lhs; 2036 } 2037 2038 /// Parse an affine expression inside parentheses. 2039 /// 2040 /// affine-expr ::= `(` affine-expr `)` 2041 AffineExpr AffineParser::parseParentheticalExpr() { 2042 if (parseToken(Token::l_paren, "expected '('")) 2043 return nullptr; 2044 if (getToken().is(Token::r_paren)) 2045 return (emitError("no expression inside parentheses"), nullptr); 2046 2047 auto expr = parseAffineExpr(); 2048 if (!expr) 2049 return nullptr; 2050 if (parseToken(Token::r_paren, "expected ')'")) 2051 return nullptr; 2052 2053 return expr; 2054 } 2055 2056 /// Parse the negation expression. 2057 /// 2058 /// affine-expr ::= `-` affine-expr 2059 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 2060 if (parseToken(Token::minus, "expected '-'")) 2061 return nullptr; 2062 2063 AffineExpr operand = parseAffineOperandExpr(lhs); 2064 // Since negation has the highest precedence of all ops (including high 2065 // precedence ops) but lower than parentheses, we are only going to use 2066 // parseAffineOperandExpr instead of parseAffineExpr here. 2067 if (!operand) 2068 // Extra error message although parseAffineOperandExpr would have 2069 // complained. Leads to a better diagnostic. 2070 return (emitError("missing operand of negation"), nullptr); 2071 return (-1) * operand; 2072 } 2073 2074 /// Parse a bare id that may appear in an affine expression. 
2075 /// 2076 /// affine-expr ::= bare-id 2077 AffineExpr AffineParser::parseBareIdExpr() { 2078 if (getToken().isNot(Token::bare_identifier)) 2079 return (emitError("expected bare identifier"), nullptr); 2080 2081 StringRef sRef = getTokenSpelling(); 2082 for (auto entry : dimsAndSymbols) { 2083 if (entry.first == sRef) { 2084 consumeToken(Token::bare_identifier); 2085 return entry.second; 2086 } 2087 } 2088 2089 return (emitError("use of undeclared identifier"), nullptr); 2090 } 2091 2092 /// Parse an SSA id which may appear in an affine expression. 2093 AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { 2094 if (!allowParsingSSAIds) 2095 return (emitError("unexpected ssa identifier"), nullptr); 2096 if (getToken().isNot(Token::percent_identifier)) 2097 return (emitError("expected ssa identifier"), nullptr); 2098 auto name = getTokenSpelling(); 2099 // Check if we already parsed this SSA id. 2100 for (auto entry : dimsAndSymbols) { 2101 if (entry.first == name) { 2102 consumeToken(Token::percent_identifier); 2103 return entry.second; 2104 } 2105 } 2106 // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. 2107 if (parseElement(isSymbol)) 2108 return (emitError("failed to parse ssa identifier"), nullptr); 2109 auto idExpr = isSymbol 2110 ? 
getAffineSymbolExpr(numSymbolOperands++, getContext()) 2111 : getAffineDimExpr(numDimOperands++, getContext()); 2112 dimsAndSymbols.push_back({name, idExpr}); 2113 return idExpr; 2114 } 2115 2116 AffineExpr AffineParser::parseSymbolSSAIdExpr() { 2117 if (parseToken(Token::kw_symbol, "expected symbol keyword") || 2118 parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) 2119 return nullptr; 2120 AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); 2121 if (!symbolExpr) 2122 return nullptr; 2123 if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) 2124 return nullptr; 2125 return symbolExpr; 2126 } 2127 2128 /// Parse a positive integral constant appearing in an affine expression. 2129 /// 2130 /// affine-expr ::= integer-literal 2131 AffineExpr AffineParser::parseIntegerExpr() { 2132 auto val = getToken().getUInt64IntegerValue(); 2133 if (!val.hasValue() || (int64_t)val.getValue() < 0) 2134 return (emitError("constant too large for index"), nullptr); 2135 2136 consumeToken(Token::integer); 2137 return builder.getAffineConstantExpr((int64_t)val.getValue()); 2138 } 2139 2140 /// Parses an expression that can be a valid operand of an affine expression. 2141 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 2142 /// operator, the rhs of which is being parsed. This is used to determine 2143 /// whether an error should be emitted for a missing right operand. 2144 // Eg: for an expression without parentheses (like i + j + k + l), each 2145 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 2146 // operand expression, it's an op expression and will be parsed via 2147 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 2148 // -l are valid operands that will be parsed by this function. 
2149 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 2150 switch (getToken().getKind()) { 2151 case Token::bare_identifier: 2152 return parseBareIdExpr(); 2153 case Token::kw_symbol: 2154 return parseSymbolSSAIdExpr(); 2155 case Token::percent_identifier: 2156 return parseSSAIdExpr(/*isSymbol=*/false); 2157 case Token::integer: 2158 return parseIntegerExpr(); 2159 case Token::l_paren: 2160 return parseParentheticalExpr(); 2161 case Token::minus: 2162 return parseNegateExpression(lhs); 2163 case Token::kw_ceildiv: 2164 case Token::kw_floordiv: 2165 case Token::kw_mod: 2166 case Token::plus: 2167 case Token::star: 2168 if (lhs) 2169 emitError("missing right operand of binary operator"); 2170 else 2171 emitError("missing left operand of binary operator"); 2172 return nullptr; 2173 default: 2174 if (lhs) 2175 emitError("missing right operand of binary operator"); 2176 else 2177 emitError("expected affine expression"); 2178 return nullptr; 2179 } 2180 } 2181 2182 /// Parse affine expressions that are bare-id's, integer constants, 2183 /// parenthetical affine expressions, and affine op expressions that are a 2184 /// composition of those. 2185 /// 2186 /// All binary op's associate from left to right. 2187 /// 2188 /// {add, sub} have lower precedence than {mul, div, and mod}. 2189 /// 2190 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 2191 /// ceildiv, and mod are at the same higher precedence level. Negation has 2192 /// higher precedence than any binary op. 2193 /// 2194 /// llhs: the affine expression appearing on the left of the one being parsed. 2195 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 2196 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 2197 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 2198 /// associativity. 
2199 /// 2200 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 2201 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 2202 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 2203 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 2204 AffineLowPrecOp llhsOp) { 2205 AffineExpr lhs; 2206 if (!(lhs = parseAffineOperandExpr(llhs))) 2207 return nullptr; 2208 2209 // Found an LHS. Deal with the ops. 2210 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 2211 if (llhs) { 2212 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2213 return parseAffineLowPrecOpExpr(sum, lOp); 2214 } 2215 // No LLHS, get RHS and form the expression. 2216 return parseAffineLowPrecOpExpr(lhs, lOp); 2217 } 2218 auto opLoc = getToken().getLoc(); 2219 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 2220 // We have a higher precedence op here. Get the rhs operand for the llhs 2221 // through parseAffineHighPrecOpExpr. 2222 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 2223 if (!highRes) 2224 return nullptr; 2225 2226 // If llhs is null, the product forms the first operand of the yet to be 2227 // found expression. If non-null, the op to associate with llhs is llhsOp. 2228 AffineExpr expr = 2229 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 2230 2231 // Recurse for subsequent low prec op's after the affine high prec op 2232 // expression. 2233 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 2234 return parseAffineLowPrecOpExpr(expr, nextOp); 2235 return expr; 2236 } 2237 // Last operand in the expression list. 2238 if (llhs) 2239 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 2240 // No llhs, 'lhs' itself is the expression. 2241 return lhs; 2242 } 2243 2244 /// Parse an affine expression. 
2245 /// affine-expr ::= `(` affine-expr `)` 2246 /// | `-` affine-expr 2247 /// | affine-expr `+` affine-expr 2248 /// | affine-expr `-` affine-expr 2249 /// | affine-expr `*` affine-expr 2250 /// | affine-expr `floordiv` affine-expr 2251 /// | affine-expr `ceildiv` affine-expr 2252 /// | affine-expr `mod` affine-expr 2253 /// | bare-id 2254 /// | integer-literal 2255 /// 2256 /// Additional conditions are checked depending on the production. For eg., 2257 /// one of the operands for `*` has to be either constant/symbolic; the second 2258 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 2259 AffineExpr AffineParser::parseAffineExpr() { 2260 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 2261 } 2262 2263 /// Parse a dim or symbol from the lists appearing before the actual 2264 /// expressions of the affine map. Update our state to store the 2265 /// dimensional/symbolic identifier. 2266 ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { 2267 if (getToken().isNot(Token::bare_identifier)) 2268 return emitError("expected bare identifier"); 2269 2270 auto name = getTokenSpelling(); 2271 for (auto entry : dimsAndSymbols) { 2272 if (entry.first == name) 2273 return emitError("redefinition of identifier '" + name + "'"); 2274 } 2275 consumeToken(Token::bare_identifier); 2276 2277 dimsAndSymbols.push_back({name, idExpr}); 2278 return success(); 2279 } 2280 2281 /// Parse the list of dimensional identifiers to an affine map. 
2282 ParseResult AffineParser::parseDimIdList(unsigned &numDims) { 2283 if (parseToken(Token::l_paren, 2284 "expected '(' at start of dimensional identifiers list")) { 2285 return failure(); 2286 } 2287 2288 auto parseElt = [&]() -> ParseResult { 2289 auto dimension = getAffineDimExpr(numDims++, getContext()); 2290 return parseIdentifierDefinition(dimension); 2291 }; 2292 return parseCommaSeparatedListUntil(Token::r_paren, parseElt); 2293 } 2294 2295 /// Parse the list of symbolic identifiers to an affine map. 2296 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) { 2297 consumeToken(Token::l_square); 2298 auto parseElt = [&]() -> ParseResult { 2299 auto symbol = getAffineSymbolExpr(numSymbols++, getContext()); 2300 return parseIdentifierDefinition(symbol); 2301 }; 2302 return parseCommaSeparatedListUntil(Token::r_square, parseElt); 2303 } 2304 2305 /// Parse the list of symbolic identifiers to an affine map. 2306 ParseResult 2307 AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims, 2308 unsigned &numSymbols) { 2309 if (parseDimIdList(numDims)) { 2310 return failure(); 2311 } 2312 if (!getToken().is(Token::l_square)) { 2313 numSymbols = 0; 2314 return success(); 2315 } 2316 return parseSymbolIdList(numSymbols); 2317 } 2318 2319 /// Parses an ambiguous affine map or integer set definition inline. 2320 ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, 2321 IntegerSet &set) { 2322 unsigned numDims = 0, numSymbols = 0; 2323 2324 // List of dimensional and optional symbol identifiers. 2325 if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) { 2326 return failure(); 2327 } 2328 2329 // This is needed for parsing attributes as we wouldn't know whether we would 2330 // be parsing an integer set attribute or an affine map attribute. 
2331 bool isArrow = getToken().is(Token::arrow); 2332 bool isColon = getToken().is(Token::colon); 2333 if (!isArrow && !isColon) { 2334 return emitError("expected '->' or ':'"); 2335 } else if (isArrow) { 2336 parseToken(Token::arrow, "expected '->' or '['"); 2337 map = parseAffineMapRange(numDims, numSymbols); 2338 return map ? success() : failure(); 2339 } else if (parseToken(Token::colon, "expected ':' or '['")) { 2340 return failure(); 2341 } 2342 2343 if ((set = parseIntegerSetConstraints(numDims, numSymbols))) 2344 return success(); 2345 2346 return failure(); 2347 } 2348 2349 /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. 2350 ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { 2351 if (parseToken(Token::l_square, "expected '['")) 2352 return failure(); 2353 2354 SmallVector<AffineExpr, 4> exprs; 2355 auto parseElt = [&]() -> ParseResult { 2356 auto elt = parseAffineExpr(); 2357 exprs.push_back(elt); 2358 return elt ? success() : failure(); 2359 }; 2360 2361 // Parse a multi-dimensional affine expression (a comma-separated list of 2362 // 1-d affine expressions); the list cannot be empty. Grammar: 2363 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2364 if (parseCommaSeparatedListUntil(Token::r_square, parseElt, 2365 /*allowEmptyList=*/true)) 2366 return failure(); 2367 // Parsed a valid affine map. 2368 if (exprs.empty()) 2369 map = AffineMap(); 2370 else 2371 map = builder.getAffineMap(numDimOperands, 2372 dimsAndSymbols.size() - numDimOperands, exprs); 2373 return success(); 2374 } 2375 2376 /// Parse the range and sizes affine map definition inline. 
2377 /// 2378 /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr 2379 /// 2380 /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2381 AffineMap AffineParser::parseAffineMapRange(unsigned numDims, 2382 unsigned numSymbols) { 2383 parseToken(Token::l_paren, "expected '(' at start of affine map range"); 2384 2385 SmallVector<AffineExpr, 4> exprs; 2386 auto parseElt = [&]() -> ParseResult { 2387 auto elt = parseAffineExpr(); 2388 ParseResult res = elt ? success() : failure(); 2389 exprs.push_back(elt); 2390 return res; 2391 }; 2392 2393 // Parse a multi-dimensional affine expression (a comma-separated list of 2394 // 1-d affine expressions); the list cannot be empty. Grammar: 2395 // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) 2396 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) 2397 return AffineMap(); 2398 2399 // Parsed a valid affine map. 2400 return builder.getAffineMap(numDims, numSymbols, exprs); 2401 } 2402 2403 /// Parse an affine constraint. 2404 /// affine-constraint ::= affine-expr `>=` `0` 2405 /// | affine-expr `==` `0` 2406 /// 2407 /// isEq is set to true if the parsed constraint is an equality, false if it 2408 /// is an inequality (greater than or equal). 
2409 /// 2410 AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { 2411 AffineExpr expr = parseAffineExpr(); 2412 if (!expr) 2413 return nullptr; 2414 2415 if (consumeIf(Token::greater) && consumeIf(Token::equal) && 2416 getToken().is(Token::integer)) { 2417 auto dim = getToken().getUnsignedIntegerValue(); 2418 if (dim.hasValue() && dim.getValue() == 0) { 2419 consumeToken(Token::integer); 2420 *isEq = false; 2421 return expr; 2422 } 2423 return (emitError("expected '0' after '>='"), nullptr); 2424 } 2425 2426 if (consumeIf(Token::equal) && consumeIf(Token::equal) && 2427 getToken().is(Token::integer)) { 2428 auto dim = getToken().getUnsignedIntegerValue(); 2429 if (dim.hasValue() && dim.getValue() == 0) { 2430 consumeToken(Token::integer); 2431 *isEq = true; 2432 return expr; 2433 } 2434 return (emitError("expected '0' after '=='"), nullptr); 2435 } 2436 2437 return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), 2438 nullptr); 2439 } 2440 2441 /// Parse the constraints that are part of an integer set definition. 2442 /// integer-set-inline 2443 /// ::= dim-and-symbol-id-lists `:` 2444 /// '(' affine-constraint-conjunction? ')' 2445 /// affine-constraint-conjunction ::= affine-constraint (`,` 2446 /// affine-constraint)* 2447 /// 2448 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, 2449 unsigned numSymbols) { 2450 if (parseToken(Token::l_paren, 2451 "expected '(' at start of integer set constraint list")) 2452 return IntegerSet(); 2453 2454 SmallVector<AffineExpr, 4> constraints; 2455 SmallVector<bool, 4> isEqs; 2456 auto parseElt = [&]() -> ParseResult { 2457 bool isEq; 2458 auto elt = parseAffineConstraint(&isEq); 2459 ParseResult res = elt ? success() : failure(); 2460 if (elt) { 2461 constraints.push_back(elt); 2462 isEqs.push_back(isEq); 2463 } 2464 return res; 2465 }; 2466 2467 // Parse a list of affine constraints (comma-separated). 
2468 if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) 2469 return IntegerSet(); 2470 2471 // If no constraints were parsed, then treat this as a degenerate 'true' case. 2472 if (constraints.empty()) { 2473 /* 0 == 0 */ 2474 auto zero = getAffineConstantExpr(0, getContext()); 2475 return builder.getIntegerSet(numDims, numSymbols, zero, true); 2476 } 2477 2478 // Parsed a valid integer set. 2479 return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); 2480 } 2481 2482 /// Parse an ambiguous reference to either and affine map or an integer set. 2483 ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, 2484 IntegerSet &set) { 2485 return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); 2486 } 2487 2488 /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to 2489 /// parse SSA value uses encountered while parsing affine expressions. 2490 ParseResult Parser::parseAffineMapOfSSAIds( 2491 AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) { 2492 return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) 2493 .parseAffineMapOfSSAIds(map); 2494 } 2495 2496 //===----------------------------------------------------------------------===// 2497 // OperationParser 2498 //===----------------------------------------------------------------------===// 2499 2500 namespace { 2501 /// This class provides support for parsing operations and regions of 2502 /// operations. 2503 class OperationParser : public Parser { 2504 public: 2505 OperationParser(ParserState &state, ModuleOp moduleOp) 2506 : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) { 2507 } 2508 2509 ~OperationParser(); 2510 2511 /// After parsing is finished, this function must be called to see if there 2512 /// are any remaining issues. 
2513 ParseResult finalize(); 2514 2515 //===--------------------------------------------------------------------===// 2516 // SSA Value Handling 2517 //===--------------------------------------------------------------------===// 2518 2519 /// This represents a use of an SSA value in the program. The first two 2520 /// entries in the tuple are the name and result number of a reference. The 2521 /// third is the location of the reference, which is used in case this ends 2522 /// up being a use of an undefined value. 2523 struct SSAUseInfo { 2524 StringRef name; // Value name, e.g. %42 or %abc 2525 unsigned number; // Number, specified with #12 2526 SMLoc loc; // Location of first definition or use. 2527 }; 2528 2529 /// Push a new SSA name scope to the parser. 2530 void pushSSANameScope(bool isIsolated); 2531 2532 /// Pop the last SSA name scope from the parser. 2533 ParseResult popSSANameScope(); 2534 2535 /// Register a definition of a value with the symbol table. 2536 ParseResult addDefinition(SSAUseInfo useInfo, Value *value); 2537 2538 /// Parse an optional list of SSA uses into 'results'. 2539 ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results); 2540 2541 /// Parse a single SSA use into 'result'. 2542 ParseResult parseSSAUse(SSAUseInfo &result); 2543 2544 /// Given a reference to an SSA value and its type, return a reference. This 2545 /// returns null on failure. 2546 Value *resolveSSAUse(SSAUseInfo useInfo, Type type); 2547 2548 ParseResult parseSSADefOrUseAndType( 2549 const std::function<ParseResult(SSAUseInfo, Type)> &action); 2550 2551 ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results); 2552 2553 /// Return the location of the value identified by its name and number if it 2554 /// has been already reference. 
2555 llvm::Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) { 2556 auto &values = isolatedNameScopes.back().values; 2557 if (!values.count(name) || number >= values[name].size()) 2558 return {}; 2559 if (values[name][number].first) 2560 return values[name][number].second; 2561 return {}; 2562 } 2563 2564 //===--------------------------------------------------------------------===// 2565 // Operation Parsing 2566 //===--------------------------------------------------------------------===// 2567 2568 /// Parse an operation instance. 2569 ParseResult parseOperation(); 2570 2571 /// Parse a single operation successor and its operand list. 2572 ParseResult parseSuccessorAndUseList(Block *&dest, 2573 SmallVectorImpl<Value *> &operands); 2574 2575 /// Parse a comma-separated list of operation successors in brackets. 2576 ParseResult 2577 parseSuccessors(SmallVectorImpl<Block *> &destinations, 2578 SmallVectorImpl<SmallVector<Value *, 4>> &operands); 2579 2580 /// Parse an operation instance that is in the generic form. 2581 Operation *parseGenericOperation(); 2582 2583 /// Parse an operation instance that is in the op-defined custom form. 2584 Operation *parseCustomOperation(); 2585 2586 //===--------------------------------------------------------------------===// 2587 // Region Parsing 2588 //===--------------------------------------------------------------------===// 2589 2590 /// Parse a region into 'region' with the provided entry block arguments. 2591 /// 'isIsolatedNameScope' indicates if the naming scope of this region is 2592 /// isolated from those above. 2593 ParseResult parseRegion(Region ®ion, 2594 ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments, 2595 bool isIsolatedNameScope = false); 2596 2597 /// Parse a region body into 'region'. 
2598 ParseResult parseRegionBody(Region ®ion); 2599 2600 //===--------------------------------------------------------------------===// 2601 // Block Parsing 2602 //===--------------------------------------------------------------------===// 2603 2604 /// Parse a new block into 'block'. 2605 ParseResult parseBlock(Block *&block); 2606 2607 /// Parse a list of operations into 'block'. 2608 ParseResult parseBlockBody(Block *block); 2609 2610 /// Parse a (possibly empty) list of block arguments. 2611 ParseResult 2612 parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results, 2613 Block *owner); 2614 2615 /// Get the block with the specified name, creating it if it doesn't 2616 /// already exist. The location specified is the point of use, which allows 2617 /// us to diagnose references to blocks that are not defined precisely. 2618 Block *getBlockNamed(StringRef name, SMLoc loc); 2619 2620 /// Define the block with the specified name. Returns the Block* or nullptr in 2621 /// the case of redefinition. 2622 Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); 2623 2624 private: 2625 /// Returns the info for a block at the current scope for the given name. 2626 std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { 2627 return blocksByName.back()[name]; 2628 } 2629 2630 /// Insert a new forward reference to the given block. 2631 void insertForwardRef(Block *block, SMLoc loc) { 2632 forwardRef.back().try_emplace(block, loc); 2633 } 2634 2635 /// Erase any forward reference to the given block. 2636 bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); } 2637 2638 /// Record that a definition was added at the current scope. 2639 void recordDefinition(StringRef def); 2640 2641 /// Get the value entry for the given SSA name. 2642 SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name); 2643 2644 /// Create a forward reference placeholder value with the given location and 2645 /// result type. 
2646 Value *createForwardRefPlaceholder(SMLoc loc, Type type); 2647 2648 /// Return true if this is a forward reference. 2649 bool isForwardRefPlaceholder(Value *value) { 2650 return forwardRefPlaceholders.count(value); 2651 } 2652 2653 /// This struct represents an isolated SSA name scope. This scope may contain 2654 /// other nested non-isolated scopes. These scopes are used for operations 2655 /// that are known to be isolated to allow for reusing names within their 2656 /// regions, even if those names are used above. 2657 struct IsolatedSSANameScope { 2658 /// Record that a definition was added at the current scope. 2659 void recordDefinition(StringRef def) { 2660 definitionsPerScope.back().insert(def); 2661 } 2662 2663 /// Push a nested name scope. 2664 void pushSSANameScope() { definitionsPerScope.push_back({}); } 2665 2666 /// Pop a nested name scope. 2667 void popSSANameScope() { 2668 for (auto &def : definitionsPerScope.pop_back_val()) 2669 values.erase(def.getKey()); 2670 } 2671 2672 /// This keeps track of all of the SSA values we are tracking for each name 2673 /// scope, indexed by their name. This has one entry per result number. 2674 llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values; 2675 2676 /// This keeps track of all of the values defined by a specific name scope. 2677 SmallVector<llvm::StringSet<>, 2> definitionsPerScope; 2678 }; 2679 2680 /// A list of isolated name scopes. 2681 SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes; 2682 2683 /// This keeps track of the block names as well as the location of the first 2684 /// reference for each nested name scope. This is used to diagnose invalid 2685 /// block references and memoize them. 
2686 SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; 2687 SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; 2688 2689 /// These are all of the placeholders we've made along with the location of 2690 /// their first reference, to allow checking for use of undefined values. 2691 DenseMap<Value *, SMLoc> forwardRefPlaceholders; 2692 2693 /// The builder used when creating parsed operation instances. 2694 OpBuilder opBuilder; 2695 2696 /// The top level module operation. 2697 ModuleOp moduleOp; 2698 }; 2699 } // end anonymous namespace 2700 2701 OperationParser::~OperationParser() { 2702 for (auto &fwd : forwardRefPlaceholders) { 2703 // Drop all uses of undefined forward declared reference and destroy 2704 // defining operation. 2705 fwd.first->dropAllUses(); 2706 fwd.first->getDefiningOp()->destroy(); 2707 } 2708 } 2709 2710 /// After parsing is finished, this function must be called to see if there are 2711 /// any remaining issues. 2712 ParseResult OperationParser::finalize() { 2713 // Check for any forward references that are left. If we find any, error 2714 // out. 2715 if (!forwardRefPlaceholders.empty()) { 2716 SmallVector<std::pair<const char *, Value *>, 4> errors; 2717 // Iteration over the map isn't deterministic, so sort by source location. 
2718 for (auto entry : forwardRefPlaceholders) 2719 errors.push_back({entry.second.getPointer(), entry.first}); 2720 llvm::array_pod_sort(errors.begin(), errors.end()); 2721 2722 for (auto entry : errors) { 2723 auto loc = SMLoc::getFromPointer(entry.first); 2724 emitError(loc, "use of undeclared SSA value name"); 2725 } 2726 return failure(); 2727 } 2728 2729 return success(); 2730 } 2731 2732 //===----------------------------------------------------------------------===// 2733 // SSA Value Handling 2734 //===----------------------------------------------------------------------===// 2735 2736 void OperationParser::pushSSANameScope(bool isIsolated) { 2737 blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); 2738 forwardRef.push_back(DenseMap<Block *, SMLoc>()); 2739 2740 // Push back a new name definition scope. 2741 if (isIsolated) 2742 isolatedNameScopes.push_back({}); 2743 isolatedNameScopes.back().pushSSANameScope(); 2744 } 2745 2746 ParseResult OperationParser::popSSANameScope() { 2747 auto forwardRefInCurrentScope = forwardRef.pop_back_val(); 2748 2749 // Verify that all referenced blocks were defined. 2750 if (!forwardRefInCurrentScope.empty()) { 2751 SmallVector<std::pair<const char *, Block *>, 4> errors; 2752 // Iteration over the map isn't deterministic, so sort by source location. 2753 for (auto entry : forwardRefInCurrentScope) { 2754 errors.push_back({entry.second.getPointer(), entry.first}); 2755 // Add this block to the top-level region to allow for automatic cleanup. 2756 moduleOp.getOperation()->getRegion(0).push_back(entry.first); 2757 } 2758 llvm::array_pod_sort(errors.begin(), errors.end()); 2759 2760 for (auto entry : errors) { 2761 auto loc = SMLoc::getFromPointer(entry.first); 2762 emitError(loc, "reference to an undefined block"); 2763 } 2764 return failure(); 2765 } 2766 2767 // Pop the next nested namescope. If there is only one internal namescope, 2768 // just pop the isolated scope. 
2769 auto ¤tNameScope = isolatedNameScopes.back(); 2770 if (currentNameScope.definitionsPerScope.size() == 1) 2771 isolatedNameScopes.pop_back(); 2772 else 2773 currentNameScope.popSSANameScope(); 2774 2775 blocksByName.pop_back(); 2776 return success(); 2777 } 2778 2779 /// Register a definition of a value with the symbol table. 2780 ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) { 2781 auto &entries = getSSAValueEntry(useInfo.name); 2782 2783 // Make sure there is a slot for this value. 2784 if (entries.size() <= useInfo.number) 2785 entries.resize(useInfo.number + 1); 2786 2787 // If we already have an entry for this, check to see if it was a definition 2788 // or a forward reference. 2789 if (auto *existing = entries[useInfo.number].first) { 2790 if (!isForwardRefPlaceholder(existing)) { 2791 return emitError(useInfo.loc) 2792 .append("redefinition of SSA value '", useInfo.name, "'") 2793 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 2794 .append("previously defined here"); 2795 } 2796 2797 // If it was a forward reference, update everything that used it to use 2798 // the actual definition instead, delete the forward ref, and remove it 2799 // from our set of forward references we track. 2800 existing->replaceAllUsesWith(value); 2801 existing->getDefiningOp()->destroy(); 2802 forwardRefPlaceholders.erase(existing); 2803 } 2804 2805 /// Record this definition for the current scope. 2806 entries[useInfo.number] = {value, useInfo.loc}; 2807 recordDefinition(useInfo.name); 2808 return success(); 2809 } 2810 2811 /// Parse a (possibly empty) list of SSA operands. 2812 /// 2813 /// ssa-use-list ::= ssa-use (`,` ssa-use)* 2814 /// ssa-use-list-opt ::= ssa-use-list? 
2815 /// 2816 ParseResult 2817 OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) { 2818 if (getToken().isNot(Token::percent_identifier)) 2819 return success(); 2820 return parseCommaSeparatedList([&]() -> ParseResult { 2821 SSAUseInfo result; 2822 if (parseSSAUse(result)) 2823 return failure(); 2824 results.push_back(result); 2825 return success(); 2826 }); 2827 } 2828 2829 /// Parse a SSA operand for an operation. 2830 /// 2831 /// ssa-use ::= ssa-id 2832 /// 2833 ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) { 2834 result.name = getTokenSpelling(); 2835 result.number = 0; 2836 result.loc = getToken().getLoc(); 2837 if (parseToken(Token::percent_identifier, "expected SSA operand")) 2838 return failure(); 2839 2840 // If we have an attribute ID, it is a result number. 2841 if (getToken().is(Token::hash_identifier)) { 2842 if (auto value = getToken().getHashIdentifierNumber()) 2843 result.number = value.getValue(); 2844 else 2845 return emitError("invalid SSA value result number"); 2846 consumeToken(Token::hash_identifier); 2847 } 2848 2849 return success(); 2850 } 2851 2852 /// Given an unbound reference to an SSA value and its type, return the value 2853 /// it specifies. This returns null on failure. 2854 Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { 2855 auto &entries = getSSAValueEntry(useInfo.name); 2856 2857 // If we have already seen a value of this name, return it. 2858 if (useInfo.number < entries.size() && entries[useInfo.number].first) { 2859 auto *result = entries[useInfo.number].first; 2860 // Check that the type matches the other uses. 
2861 if (result->getType() == type) 2862 return result; 2863 2864 emitError(useInfo.loc, "use of value '") 2865 .append(useInfo.name, 2866 "' expects different type than prior uses: ", type, " vs ", 2867 result->getType()) 2868 .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) 2869 .append("prior use here"); 2870 return nullptr; 2871 } 2872 2873 // Make sure we have enough slots for this. 2874 if (entries.size() <= useInfo.number) 2875 entries.resize(useInfo.number + 1); 2876 2877 // If the value has already been defined and this is an overly large result 2878 // number, diagnose that. 2879 if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) 2880 return (emitError(useInfo.loc, "reference to invalid result number"), 2881 nullptr); 2882 2883 // Otherwise, this is a forward reference. Create a placeholder and remember 2884 // that we did so. 2885 auto *result = createForwardRefPlaceholder(useInfo.loc, type); 2886 entries[useInfo.number].first = result; 2887 entries[useInfo.number].second = useInfo.loc; 2888 return result; 2889 } 2890 2891 /// Parse an SSA use with an associated type. 2892 /// 2893 /// ssa-use-and-type ::= ssa-use `:` type 2894 ParseResult OperationParser::parseSSADefOrUseAndType( 2895 const std::function<ParseResult(SSAUseInfo, Type)> &action) { 2896 SSAUseInfo useInfo; 2897 if (parseSSAUse(useInfo) || 2898 parseToken(Token::colon, "expected ':' and type for SSA operand")) 2899 return failure(); 2900 2901 auto type = parseType(); 2902 if (!type) 2903 return failure(); 2904 2905 return action(useInfo, type); 2906 } 2907 2908 /// Parse a (possibly empty) list of SSA operands, followed by a colon, then 2909 /// followed by a type list. 
2910 /// 2911 /// ssa-use-and-type-list 2912 /// ::= ssa-use-list ':' type-list-no-parens 2913 /// 2914 ParseResult OperationParser::parseOptionalSSAUseAndTypeList( 2915 SmallVectorImpl<Value *> &results) { 2916 SmallVector<SSAUseInfo, 4> valueIDs; 2917 if (parseOptionalSSAUseList(valueIDs)) 2918 return failure(); 2919 2920 // If there were no operands, then there is no colon or type lists. 2921 if (valueIDs.empty()) 2922 return success(); 2923 2924 SmallVector<Type, 4> types; 2925 if (parseToken(Token::colon, "expected ':' in operand list") || 2926 parseTypeListNoParens(types)) 2927 return failure(); 2928 2929 if (valueIDs.size() != types.size()) 2930 return emitError("expected ") 2931 << valueIDs.size() << " types to match operand list"; 2932 2933 results.reserve(valueIDs.size()); 2934 for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) { 2935 if (auto *value = resolveSSAUse(valueIDs[i], types[i])) 2936 results.push_back(value); 2937 else 2938 return failure(); 2939 } 2940 2941 return success(); 2942 } 2943 2944 /// Record that a definition was added at the current scope. 2945 void OperationParser::recordDefinition(StringRef def) { 2946 isolatedNameScopes.back().recordDefinition(def); 2947 } 2948 2949 /// Get the value entry for the given SSA name. 2950 SmallVectorImpl<std::pair<Value *, SMLoc>> & 2951 OperationParser::getSSAValueEntry(StringRef name) { 2952 return isolatedNameScopes.back().values[name]; 2953 } 2954 2955 /// Create and remember a new placeholder for a forward reference. 2956 Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { 2957 // Forward references are always created as operations, because we just need 2958 // something with a def/use chain. 2959 // 2960 // We create these placeholders as having an empty name, which we know 2961 // cannot be created through normal user input, allowing us to distinguish 2962 // them. 
2963 auto name = OperationName("placeholder", getContext()); 2964 auto *op = Operation::create( 2965 getEncodedSourceLocation(loc), name, /*operands=*/{}, type, 2966 /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, 2967 /*resizableOperandList=*/false, getContext()); 2968 forwardRefPlaceholders[op->getResult(0)] = loc; 2969 return op->getResult(0); 2970 } 2971 2972 //===----------------------------------------------------------------------===// 2973 // Operation Parsing 2974 //===----------------------------------------------------------------------===// 2975 2976 /// Parse an operation. 2977 /// 2978 /// operation ::= 2979 /// operation-result? string '(' ssa-use-list? ')' attribute-dict? 2980 /// `:` function-type trailing-location? 2981 /// operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=` 2982 /// 2983 ParseResult OperationParser::parseOperation() { 2984 auto loc = getToken().getLoc(); 2985 SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs; 2986 size_t numExpectedResults; 2987 if (getToken().is(Token::percent_identifier)) { 2988 // Parse the first result id. 2989 resultIDs.emplace_back(getTokenSpelling(), loc); 2990 consumeToken(Token::percent_identifier); 2991 2992 // If the next token is a ':', we parse the expected result count. 2993 if (consumeIf(Token::colon)) { 2994 // Check that the next token is an integer. 2995 if (!getToken().is(Token::integer)) 2996 return emitError("expected integer number of results"); 2997 2998 // Check that number of results is > 0. 2999 auto val = getToken().getUInt64IntegerValue(); 3000 if (!val.hasValue() || val.getValue() < 1) 3001 return emitError("expected named operation to have atleast 1 result"); 3002 consumeToken(Token::integer); 3003 numExpectedResults = *val; 3004 } else { 3005 // Otherwise, this is a comma separated list of result ids. 3006 if (consumeIf(Token::comma)) { 3007 auto parseNextResult = [&]() -> ParseResult { 3008 // Parse the next result id. 
3009 if (!getToken().is(Token::percent_identifier)) 3010 return emitError("expected valid ssa identifier"); 3011 3012 resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc()); 3013 consumeToken(Token::percent_identifier); 3014 return success(); 3015 }; 3016 3017 if (parseCommaSeparatedList(parseNextResult)) 3018 return failure(); 3019 } 3020 numExpectedResults = resultIDs.size(); 3021 } 3022 3023 if (parseToken(Token::equal, "expected '=' after SSA name")) 3024 return failure(); 3025 } 3026 3027 Operation *op; 3028 if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) 3029 op = parseCustomOperation(); 3030 else if (getToken().is(Token::string)) 3031 op = parseGenericOperation(); 3032 else 3033 return emitError("expected operation name in quotes"); 3034 3035 // If parsing of the basic operation failed, then this whole thing fails. 3036 if (!op) 3037 return failure(); 3038 3039 // If the operation had a name, register it. 3040 if (!resultIDs.empty()) { 3041 if (op->getNumResults() == 0) 3042 return emitError(loc, "cannot name an operation with no results"); 3043 if (numExpectedResults != op->getNumResults()) 3044 return emitError(loc, "operation defines ") 3045 << op->getNumResults() << " results but was provided " 3046 << numExpectedResults << " to bind"; 3047 3048 // If the number of result names matches the number of operation results, we 3049 // can directly use the provided names. 3050 if (resultIDs.size() == op->getNumResults()) { 3051 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3052 if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second}, 3053 op->getResult(i))) 3054 return failure(); 3055 } else { 3056 // Otherwise, we use the same name for all results. 3057 StringRef name = resultIDs.front().first; 3058 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) 3059 if (addDefinition({name, i, loc}, op->getResult(i))) 3060 return failure(); 3061 } 3062 } 3063 3064 // Try to parse the optional trailing location. 
3065 if (parseOptionalTrailingLocation(op)) 3066 return failure(); 3067 3068 return success(); 3069 } 3070 3071 /// Parse a single operation successor and its operand list. 3072 /// 3073 /// successor ::= block-id branch-use-list? 3074 /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` 3075 /// 3076 ParseResult 3077 OperationParser::parseSuccessorAndUseList(Block *&dest, 3078 SmallVectorImpl<Value *> &operands) { 3079 // Verify branch is identifier and get the matching block. 3080 if (!getToken().is(Token::caret_identifier)) 3081 return emitError("expected block name"); 3082 dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); 3083 consumeToken(); 3084 3085 // Handle optional arguments. 3086 if (consumeIf(Token::l_paren) && 3087 (parseOptionalSSAUseAndTypeList(operands) || 3088 parseToken(Token::r_paren, "expected ')' to close argument list"))) { 3089 return failure(); 3090 } 3091 3092 return success(); 3093 } 3094 3095 /// Parse a comma-separated list of operation successors in brackets. 3096 /// 3097 /// successor-list ::= `[` successor (`,` successor )* `]` 3098 /// 3099 ParseResult OperationParser::parseSuccessors( 3100 SmallVectorImpl<Block *> &destinations, 3101 SmallVectorImpl<SmallVector<Value *, 4>> &operands) { 3102 if (parseToken(Token::l_square, "expected '['")) 3103 return failure(); 3104 3105 auto parseElt = [this, &destinations, &operands]() { 3106 Block *dest; 3107 SmallVector<Value *, 4> destOperands; 3108 auto res = parseSuccessorAndUseList(dest, destOperands); 3109 destinations.push_back(dest); 3110 operands.push_back(destOperands); 3111 return res; 3112 }; 3113 return parseCommaSeparatedListUntil(Token::r_square, parseElt, 3114 /*allowEmptyList=*/false); 3115 } 3116 3117 namespace { 3118 // RAII-style guard for cleaning up the regions in the operation state before 3119 // deleting them. 
Within the parser, regions may get deleted if parsing failed, 3120 // and other errors may be present, in praticular undominated uses. This makes 3121 // sure such uses are deleted. 3122 struct CleanupOpStateRegions { 3123 ~CleanupOpStateRegions() { 3124 SmallVector<Region *, 4> regionsToClean; 3125 regionsToClean.reserve(state.regions.size()); 3126 for (auto ®ion : state.regions) 3127 if (region) 3128 for (auto &block : *region) 3129 block.dropAllDefinedValueUses(); 3130 } 3131 OperationState &state; 3132 }; 3133 } // namespace 3134 3135 Operation *OperationParser::parseGenericOperation() { 3136 // Get location information for the operation. 3137 auto srcLocation = getEncodedSourceLocation(getToken().getLoc()); 3138 3139 auto name = getToken().getStringValue(); 3140 if (name.empty()) 3141 return (emitError("empty operation name is invalid"), nullptr); 3142 if (name.find('\0') != StringRef::npos) 3143 return (emitError("null character not allowed in operation name"), nullptr); 3144 3145 consumeToken(Token::string); 3146 3147 OperationState result(srcLocation, name); 3148 3149 // Generic operations have a resizable operation list. 3150 result.setOperandListToResizable(); 3151 3152 // Parse the operand list. 3153 SmallVector<SSAUseInfo, 8> operandInfos; 3154 3155 if (parseToken(Token::l_paren, "expected '(' to start operand list") || 3156 parseOptionalSSAUseList(operandInfos) || 3157 parseToken(Token::r_paren, "expected ')' to end operand list")) { 3158 return nullptr; 3159 } 3160 3161 // Parse the successor list but don't add successors to the result yet to 3162 // avoid messing up with the argument order. 3163 SmallVector<Block *, 2> successors; 3164 SmallVector<SmallVector<Value *, 4>, 2> successorOperands; 3165 if (getToken().is(Token::l_square)) { 3166 // Check if the operation is a known terminator. 
3167 const AbstractOperation *abstractOp = result.name.getAbstractOperation(); 3168 if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) 3169 return emitError("successors in non-terminator"), nullptr; 3170 if (parseSuccessors(successors, successorOperands)) 3171 return nullptr; 3172 } 3173 3174 // Parse the region list. 3175 CleanupOpStateRegions guard{result}; 3176 if (consumeIf(Token::l_paren)) { 3177 do { 3178 // Create temporary regions with the top level region as parent. 3179 result.regions.emplace_back(new Region(moduleOp)); 3180 if (parseRegion(*result.regions.back(), /*entryArguments=*/{})) 3181 return nullptr; 3182 } while (consumeIf(Token::comma)); 3183 if (parseToken(Token::r_paren, "expected ')' to end region list")) 3184 return nullptr; 3185 } 3186 3187 if (getToken().is(Token::l_brace)) { 3188 if (parseAttributeDict(result.attributes)) 3189 return nullptr; 3190 } 3191 3192 if (parseToken(Token::colon, "expected ':' followed by operation type")) 3193 return nullptr; 3194 3195 auto typeLoc = getToken().getLoc(); 3196 auto type = parseType(); 3197 if (!type) 3198 return nullptr; 3199 auto fnType = type.dyn_cast<FunctionType>(); 3200 if (!fnType) 3201 return (emitError(typeLoc, "expected function type"), nullptr); 3202 3203 result.addTypes(fnType.getResults()); 3204 3205 // Check that we have the right number of types for the operands. 3206 auto operandTypes = fnType.getInputs(); 3207 if (operandTypes.size() != operandInfos.size()) { 3208 auto plural = "s"[operandInfos.size() == 1]; 3209 return (emitError(typeLoc, "expected ") 3210 << operandInfos.size() << " operand type" << plural 3211 << " but had " << operandTypes.size(), 3212 nullptr); 3213 } 3214 3215 // Resolve all of the operands. 
3216 for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) { 3217 result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i])); 3218 if (!result.operands.back()) 3219 return nullptr; 3220 } 3221 3222 // Add the sucessors, and their operands after the proper operands. 3223 for (const auto &succ : llvm::zip(successors, successorOperands)) { 3224 Block *successor = std::get<0>(succ); 3225 const SmallVector<Value *, 4> &operands = std::get<1>(succ); 3226 result.addSuccessor(successor, operands); 3227 } 3228 3229 return opBuilder.createOperation(result); 3230 } 3231 3232 namespace { 3233 class CustomOpAsmParser : public OpAsmParser { 3234 public: 3235 CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, 3236 OperationParser &parser) 3237 : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} 3238 3239 /// Parse an instance of the operation described by 'opDefinition' into the 3240 /// provided operation state. 3241 ParseResult parseOperation(OperationState *opState) { 3242 if (opDefinition->parseAssembly(this, opState)) 3243 return failure(); 3244 return success(); 3245 } 3246 3247 //===--------------------------------------------------------------------===// 3248 // Utilities 3249 //===--------------------------------------------------------------------===// 3250 3251 /// Return if any errors were emitted during parsing. 3252 bool didEmitError() const { return emittedError; } 3253 3254 /// Emit a diagnostic at the specified location and return failure. 
3255 InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { 3256 emittedError = true; 3257 return parser.emitError(loc, "custom op '" + opDefinition->name + "' " + 3258 message); 3259 } 3260 3261 llvm::SMLoc getCurrentLocation() override { 3262 return parser.getToken().getLoc(); 3263 } 3264 3265 Builder &getBuilder() const override { return parser.builder; } 3266 3267 llvm::SMLoc getNameLoc() const override { return nameLoc; } 3268 3269 //===--------------------------------------------------------------------===// 3270 // Token Parsing 3271 //===--------------------------------------------------------------------===// 3272 3273 /// Parse a `->` token. 3274 ParseResult parseArrow() override { 3275 return parser.parseToken(Token::arrow, "expected '->'"); 3276 } 3277 3278 /// Parses a `->` if present. 3279 ParseResult parseOptionalArrow() override { 3280 return success(parser.consumeIf(Token::arrow)); 3281 } 3282 3283 /// Parse a `:` token. 3284 ParseResult parseColon() override { 3285 return parser.parseToken(Token::colon, "expected ':'"); 3286 } 3287 3288 /// Parse a `:` token if present. 3289 ParseResult parseOptionalColon() override { 3290 return success(parser.consumeIf(Token::colon)); 3291 } 3292 3293 /// Parse a `,` token. 3294 ParseResult parseComma() override { 3295 return parser.parseToken(Token::comma, "expected ','"); 3296 } 3297 3298 /// Parse a `,` token if present. 3299 ParseResult parseOptionalComma() override { 3300 return success(parser.consumeIf(Token::comma)); 3301 } 3302 3303 /// Parses a `...` if present. 3304 ParseResult parseOptionalEllipsis() override { 3305 return success(parser.consumeIf(Token::ellipsis)); 3306 } 3307 3308 /// Parse a `=` token. 3309 ParseResult parseEqual() override { 3310 return parser.parseToken(Token::equal, "expected '='"); 3311 } 3312 3313 /// Parse a keyword if present. 
3314 ParseResult parseOptionalKeyword(const char *keyword) override { 3315 // Check that the current token is a bare identifier or keyword. 3316 if (parser.getToken().isNot(Token::bare_identifier) && 3317 !parser.getToken().isKeyword()) 3318 return failure(); 3319 3320 if (parser.getTokenSpelling() == keyword) { 3321 parser.consumeToken(); 3322 return success(); 3323 } 3324 return failure(); 3325 } 3326 3327 /// Parse a `(` token. 3328 ParseResult parseLParen() override { 3329 return parser.parseToken(Token::l_paren, "expected '('"); 3330 } 3331 3332 /// Parses a '(' if present. 3333 ParseResult parseOptionalLParen() override { 3334 return success(parser.consumeIf(Token::l_paren)); 3335 } 3336 3337 /// Parse a `)` token. 3338 ParseResult parseRParen() override { 3339 return parser.parseToken(Token::r_paren, "expected ')'"); 3340 } 3341 3342 /// Parses a ')' if present. 3343 ParseResult parseOptionalRParen() override { 3344 return success(parser.consumeIf(Token::r_paren)); 3345 } 3346 3347 /// Parse a `[` token. 3348 ParseResult parseLSquare() override { 3349 return parser.parseToken(Token::l_square, "expected '['"); 3350 } 3351 3352 /// Parses a '[' if present. 3353 ParseResult parseOptionalLSquare() override { 3354 return success(parser.consumeIf(Token::l_square)); 3355 } 3356 3357 /// Parse a `]` token. 3358 ParseResult parseRSquare() override { 3359 return parser.parseToken(Token::r_square, "expected ']'"); 3360 } 3361 3362 /// Parses a ']' if present. 3363 ParseResult parseOptionalRSquare() override { 3364 return success(parser.consumeIf(Token::r_square)); 3365 } 3366 3367 //===--------------------------------------------------------------------===// 3368 // Attribute Parsing 3369 //===--------------------------------------------------------------------===// 3370 3371 /// Parse an arbitrary attribute of a given type and return it in result. This 3372 /// also adds the attribute to the specified attribute list with the specified 3373 /// name. 
3374 ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, 3375 SmallVectorImpl<NamedAttribute> &attrs) override { 3376 result = parser.parseAttribute(type); 3377 if (!result) 3378 return failure(); 3379 3380 attrs.push_back(parser.builder.getNamedAttr(attrName, result)); 3381 return success(); 3382 } 3383 3384 /// Parse a named dictionary into 'result' if it is present. 3385 ParseResult 3386 parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override { 3387 if (parser.getToken().isNot(Token::l_brace)) 3388 return success(); 3389 return parser.parseAttributeDict(result); 3390 } 3391 3392 //===--------------------------------------------------------------------===// 3393 // Identifier Parsing 3394 //===--------------------------------------------------------------------===// 3395 3396 /// Parse an @-identifier and store it (without the '@' symbol) in a string 3397 /// attribute named 'attrName'. 3398 ParseResult parseSymbolName(StringAttr &result, StringRef attrName, 3399 SmallVectorImpl<NamedAttribute> &attrs) override { 3400 if (parser.getToken().isNot(Token::at_identifier)) 3401 return failure(); 3402 result = getBuilder().getStringAttr(parser.getTokenSpelling().drop_front()); 3403 attrs.push_back(getBuilder().getNamedAttr(attrName, result)); 3404 parser.consumeToken(); 3405 return success(); 3406 } 3407 3408 //===--------------------------------------------------------------------===// 3409 // Operand Parsing 3410 //===--------------------------------------------------------------------===// 3411 3412 /// Parse a single operand. 
3413 ParseResult parseOperand(OperandType &result) override { 3414 OperationParser::SSAUseInfo useInfo; 3415 if (parser.parseSSAUse(useInfo)) 3416 return failure(); 3417 3418 result = {useInfo.loc, useInfo.name, useInfo.number}; 3419 return success(); 3420 } 3421 3422 /// Parse zero or more SSA comma-separated operand references with a specified 3423 /// surrounding delimiter, and an optional required operand count. 3424 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 3425 int requiredOperandCount = -1, 3426 Delimiter delimiter = Delimiter::None) override { 3427 return parseOperandOrRegionArgList(result, /*isOperandList=*/true, 3428 requiredOperandCount, delimiter); 3429 } 3430 3431 /// Parse zero or more SSA comma-separated operand or region arguments with 3432 /// optional surrounding delimiter and required operand count. 3433 ParseResult 3434 parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 3435 bool isOperandList, int requiredOperandCount = -1, 3436 Delimiter delimiter = Delimiter::None) { 3437 auto startLoc = parser.getToken().getLoc(); 3438 3439 // Handle delimiters. 3440 switch (delimiter) { 3441 case Delimiter::None: 3442 // Don't check for the absence of a delimiter if the number of operands 3443 // is unknown (and hence the operand list could be empty). 3444 if (requiredOperandCount == -1) 3445 break; 3446 // Token already matches an identifier and so can't be a delimiter. 3447 if (parser.getToken().is(Token::percent_identifier)) 3448 break; 3449 // Test against known delimiters. 
3450 if (parser.getToken().is(Token::l_paren) || 3451 parser.getToken().is(Token::l_square)) 3452 return emitError(startLoc, "unexpected delimiter"); 3453 return emitError(startLoc, "invalid operand"); 3454 case Delimiter::OptionalParen: 3455 if (parser.getToken().isNot(Token::l_paren)) 3456 return success(); 3457 LLVM_FALLTHROUGH; 3458 case Delimiter::Paren: 3459 if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) 3460 return failure(); 3461 break; 3462 case Delimiter::OptionalSquare: 3463 if (parser.getToken().isNot(Token::l_square)) 3464 return success(); 3465 LLVM_FALLTHROUGH; 3466 case Delimiter::Square: 3467 if (parser.parseToken(Token::l_square, "expected '[' in operand list")) 3468 return failure(); 3469 break; 3470 } 3471 3472 // Check for zero operands. 3473 if (parser.getToken().is(Token::percent_identifier)) { 3474 do { 3475 OperandType operandOrArg; 3476 if (isOperandList ? parseOperand(operandOrArg) 3477 : parseRegionArgument(operandOrArg)) 3478 return failure(); 3479 result.push_back(operandOrArg); 3480 } while (parser.consumeIf(Token::comma)); 3481 } 3482 3483 // Handle delimiters. If we reach here, the optional delimiters were 3484 // present, so we need to parse their closing one. 
3485 switch (delimiter) { 3486 case Delimiter::None: 3487 break; 3488 case Delimiter::OptionalParen: 3489 case Delimiter::Paren: 3490 if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) 3491 return failure(); 3492 break; 3493 case Delimiter::OptionalSquare: 3494 case Delimiter::Square: 3495 if (parser.parseToken(Token::r_square, "expected ']' in operand list")) 3496 return failure(); 3497 break; 3498 } 3499 3500 if (requiredOperandCount != -1 && 3501 result.size() != static_cast<size_t>(requiredOperandCount)) 3502 return emitError(startLoc, "expected ") 3503 << requiredOperandCount << " operands"; 3504 return success(); 3505 } 3506 3507 /// Parse zero or more trailing SSA comma-separated trailing operand 3508 /// references with a specified surrounding delimiter, and an optional 3509 /// required operand count. A leading comma is expected before the operands. 3510 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 3511 int requiredOperandCount, 3512 Delimiter delimiter) override { 3513 if (parser.getToken().is(Token::comma)) { 3514 parseComma(); 3515 return parseOperandList(result, requiredOperandCount, delimiter); 3516 } 3517 if (requiredOperandCount != -1) 3518 return emitError(parser.getToken().getLoc(), "expected ") 3519 << requiredOperandCount << " operands"; 3520 return success(); 3521 } 3522 3523 /// Resolve an operand to an SSA value, emitting an error on failure. 3524 ParseResult resolveOperand(const OperandType &operand, Type type, 3525 SmallVectorImpl<Value *> &result) override { 3526 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 3527 operand.location}; 3528 if (auto *value = parser.resolveSSAUse(operandInfo, type)) { 3529 result.push_back(value); 3530 return success(); 3531 } 3532 return failure(); 3533 } 3534 3535 /// Parse an AffineMap of SSA ids. 
3536 ParseResult 3537 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, 3538 Attribute &mapAttr, StringRef attrName, 3539 SmallVectorImpl<NamedAttribute> &attrs) override { 3540 SmallVector<OperandType, 2> dimOperands; 3541 SmallVector<OperandType, 1> symOperands; 3542 3543 auto parseElement = [&](bool isSymbol) -> ParseResult { 3544 OperandType operand; 3545 if (parseOperand(operand)) 3546 return failure(); 3547 if (isSymbol) 3548 symOperands.push_back(operand); 3549 else 3550 dimOperands.push_back(operand); 3551 return success(); 3552 }; 3553 3554 AffineMap map; 3555 if (parser.parseAffineMapOfSSAIds(map, parseElement)) 3556 return failure(); 3557 // Add AffineMap attribute. 3558 if (map) { 3559 mapAttr = parser.builder.getAffineMapAttr(map); 3560 attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); 3561 } 3562 3563 // Add dim operands before symbol operands in 'operands'. 3564 operands.assign(dimOperands.begin(), dimOperands.end()); 3565 operands.append(symOperands.begin(), symOperands.end()); 3566 return success(); 3567 } 3568 3569 //===--------------------------------------------------------------------===// 3570 // Region Parsing 3571 //===--------------------------------------------------------------------===// 3572 3573 /// Parse a region that takes `arguments` of `argTypes` types. This 3574 /// effectively defines the SSA values of `arguments` and assignes their type. 
3575 ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments, 3576 ArrayRef<Type> argTypes, 3577 bool enableNameShadowing) override { 3578 assert(arguments.size() == argTypes.size() && 3579 "mismatching number of arguments and types"); 3580 3581 SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2> 3582 regionArguments; 3583 for (const auto &pair : llvm::zip(arguments, argTypes)) { 3584 const OperandType &operand = std::get<0>(pair); 3585 Type type = std::get<1>(pair); 3586 OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number, 3587 operand.location}; 3588 regionArguments.emplace_back(operandInfo, type); 3589 } 3590 3591 // Try to parse the region. 3592 assert((!enableNameShadowing || 3593 opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) && 3594 "name shadowing is only allowed on isolated regions"); 3595 if (parser.parseRegion(region, regionArguments, enableNameShadowing)) 3596 return failure(); 3597 return success(); 3598 } 3599 3600 /// Parses a region if present. 3601 ParseResult parseOptionalRegion(Region ®ion, 3602 ArrayRef<OperandType> arguments, 3603 ArrayRef<Type> argTypes, 3604 bool enableNameShadowing) override { 3605 if (parser.getToken().isNot(Token::l_brace)) 3606 return success(); 3607 return parseRegion(region, arguments, argTypes, enableNameShadowing); 3608 } 3609 3610 /// Parse a region argument. The type of the argument will be resolved later 3611 /// by a call to `parseRegion`. 3612 ParseResult parseRegionArgument(OperandType &argument) override { 3613 return parseOperand(argument); 3614 } 3615 3616 /// Parse a region argument if present. 
3617 ParseResult parseOptionalRegionArgument(OperandType &argument) override { 3618 if (parser.getToken().isNot(Token::percent_identifier)) 3619 return success(); 3620 return parseRegionArgument(argument); 3621 } 3622 3623 ParseResult 3624 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 3625 int requiredOperandCount = -1, 3626 Delimiter delimiter = Delimiter::None) override { 3627 return parseOperandOrRegionArgList(result, /*isOperandList=*/false, 3628 requiredOperandCount, delimiter); 3629 } 3630 3631 //===--------------------------------------------------------------------===// 3632 // Successor Parsing 3633 //===--------------------------------------------------------------------===// 3634 3635 /// Parse a single operation successor and its operand list. 3636 ParseResult 3637 parseSuccessorAndUseList(Block *&dest, 3638 SmallVectorImpl<Value *> &operands) override { 3639 return parser.parseSuccessorAndUseList(dest, operands); 3640 } 3641 3642 //===--------------------------------------------------------------------===// 3643 // Type Parsing 3644 //===--------------------------------------------------------------------===// 3645 3646 /// Parse a type. 3647 ParseResult parseType(Type &result) override { 3648 return failure(!(result = parser.parseType())); 3649 } 3650 3651 /// Parse an optional arrow followed by a type list. 3652 ParseResult 3653 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override { 3654 if (!parser.consumeIf(Token::arrow)) 3655 return success(); 3656 return parser.parseFunctionResultTypes(result); 3657 } 3658 3659 /// Parse a colon followed by a type. 3660 ParseResult parseColonType(Type &result) override { 3661 return failure(parser.parseToken(Token::colon, "expected ':'") || 3662 !(result = parser.parseType())); 3663 } 3664 3665 /// Parse a colon followed by a type list, which must have at least one type. 
3666 ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override { 3667 if (parser.parseToken(Token::colon, "expected ':'")) 3668 return failure(); 3669 return parser.parseTypeListNoParens(result); 3670 } 3671 3672 /// Parse an optional colon followed by a type list, which if present must 3673 /// have at least one type. 3674 ParseResult 3675 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override { 3676 if (!parser.consumeIf(Token::colon)) 3677 return success(); 3678 return parser.parseTypeListNoParens(result); 3679 } 3680 3681 private: 3682 /// The source location of the operation name. 3683 SMLoc nameLoc; 3684 3685 /// The abstract information of the operation. 3686 const AbstractOperation *opDefinition; 3687 3688 /// The main operation parser. 3689 OperationParser &parser; 3690 3691 /// A flag that indicates if any errors were emitted during parsing. 3692 bool emittedError = false; 3693 }; 3694 } // end anonymous namespace. 3695 3696 Operation *OperationParser::parseCustomOperation() { 3697 auto opLoc = getToken().getLoc(); 3698 auto opName = getTokenSpelling(); 3699 3700 auto *opDefinition = AbstractOperation::lookup(opName, getContext()); 3701 if (!opDefinition && !opName.contains('.')) { 3702 // If the operation name has no namespace prefix we treat it as a standard 3703 // operation and prefix it with "std". 3704 // TODO: Would it be better to just build a mapping of the registered 3705 // operations in the standard dialect? 3706 opDefinition = 3707 AbstractOperation::lookup(Twine("std." + opName).str(), getContext()); 3708 } 3709 3710 if (!opDefinition) { 3711 emitError(opLoc) << "custom op '" << opName << "' is unknown"; 3712 return nullptr; 3713 } 3714 3715 consumeToken(); 3716 3717 // If the custom op parser crashes, produce some indication to help 3718 // debugging. 
3719 std::string opNameStr = opName.str(); 3720 llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", 3721 opNameStr.c_str()); 3722 3723 // Get location information for the operation. 3724 auto srcLocation = getEncodedSourceLocation(opLoc); 3725 3726 // Have the op implementation take a crack and parsing this. 3727 OperationState opState(srcLocation, opDefinition->name); 3728 CleanupOpStateRegions guard{opState}; 3729 CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); 3730 if (opAsmParser.parseOperation(&opState)) 3731 return nullptr; 3732 3733 // If it emitted an error, we failed. 3734 if (opAsmParser.didEmitError()) 3735 return nullptr; 3736 3737 // Otherwise, we succeeded. Use the state it parsed as our op information. 3738 return opBuilder.createOperation(opState); 3739 } 3740 3741 //===----------------------------------------------------------------------===// 3742 // Region Parsing 3743 //===----------------------------------------------------------------------===// 3744 3745 /// Region. 3746 /// 3747 /// region ::= '{' region-body 3748 /// 3749 ParseResult OperationParser::parseRegion( 3750 Region ®ion, 3751 ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments, 3752 bool isIsolatedNameScope) { 3753 // Parse the '{'. 3754 if (parseToken(Token::l_brace, "expected '{' to begin a region")) 3755 return failure(); 3756 3757 // Check for an empty region. 3758 if (entryArguments.empty() && consumeIf(Token::r_brace)) 3759 return success(); 3760 auto currentPt = opBuilder.saveInsertionPoint(); 3761 3762 // Push a new named value scope. 3763 pushSSANameScope(isIsolatedNameScope); 3764 3765 // Parse the first block directly to allow for it to be unnamed. 3766 Block *block = new Block(); 3767 3768 // Add arguments to the entry block. 3769 if (!entryArguments.empty()) { 3770 for (auto &placeholderArgPair : entryArguments) { 3771 auto &argInfo = placeholderArgPair.first; 3772 // Ensure that the argument was not already defined. 
3773 if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { 3774 return emitError(argInfo.loc, "region entry argument '" + argInfo.name + 3775 "' is already in use") 3776 .attachNote(getEncodedSourceLocation(*defLoc)) 3777 << "previously referenced here"; 3778 } 3779 if (addDefinition(placeholderArgPair.first, 3780 block->addArgument(placeholderArgPair.second))) { 3781 delete block; 3782 return failure(); 3783 } 3784 } 3785 3786 // If we had named arguments, then don't allow a block name. 3787 if (getToken().is(Token::caret_identifier)) 3788 return emitError("invalid block name in region with named arguments"); 3789 } 3790 3791 if (parseBlock(block)) { 3792 delete block; 3793 return failure(); 3794 } 3795 3796 // Verify that no other arguments were parsed. 3797 if (!entryArguments.empty() && 3798 block->getNumArguments() > entryArguments.size()) { 3799 delete block; 3800 return emitError("entry block arguments were already defined"); 3801 } 3802 3803 // Parse the rest of the region. 3804 region.push_back(block); 3805 if (parseRegionBody(region)) 3806 return failure(); 3807 3808 // Pop the SSA value scope for this region. 3809 if (popSSANameScope()) 3810 return failure(); 3811 3812 // Reset the original insertion point. 3813 opBuilder.restoreInsertionPoint(currentPt); 3814 return success(); 3815 } 3816 3817 /// Region. 3818 /// 3819 /// region-body ::= block* '}' 3820 /// 3821 ParseResult OperationParser::parseRegionBody(Region ®ion) { 3822 // Parse the list of blocks. 3823 while (!consumeIf(Token::r_brace)) { 3824 Block *newBlock = nullptr; 3825 if (parseBlock(newBlock)) 3826 return failure(); 3827 region.push_back(newBlock); 3828 } 3829 return success(); 3830 } 3831 3832 //===----------------------------------------------------------------------===// 3833 // Block Parsing 3834 //===----------------------------------------------------------------------===// 3835 3836 /// Block declaration. 3837 /// 3838 /// block ::= block-label? 
operation* 3839 /// block-label ::= block-id block-arg-list? `:` 3840 /// block-id ::= caret-id 3841 /// block-arg-list ::= `(` ssa-id-and-type-list? `)` 3842 /// 3843 ParseResult OperationParser::parseBlock(Block *&block) { 3844 // The first block of a region may already exist, if it does the caret 3845 // identifier is optional. 3846 if (block && getToken().isNot(Token::caret_identifier)) 3847 return parseBlockBody(block); 3848 3849 SMLoc nameLoc = getToken().getLoc(); 3850 auto name = getTokenSpelling(); 3851 if (parseToken(Token::caret_identifier, "expected block name")) 3852 return failure(); 3853 3854 block = defineBlockNamed(name, nameLoc, block); 3855 3856 // Fail if the block was already defined. 3857 if (!block) 3858 return emitError(nameLoc, "redefinition of block '") << name << "'"; 3859 3860 // If an argument list is present, parse it. 3861 if (consumeIf(Token::l_paren)) { 3862 SmallVector<BlockArgument *, 8> bbArgs; 3863 if (parseOptionalBlockArgList(bbArgs, block) || 3864 parseToken(Token::r_paren, "expected ')' to end argument list")) 3865 return failure(); 3866 } 3867 3868 if (parseToken(Token::colon, "expected ':' after block name")) 3869 return failure(); 3870 3871 return parseBlockBody(block); 3872 } 3873 3874 ParseResult OperationParser::parseBlockBody(Block *block) { 3875 // Set the insertion point to the end of the block to parse. 3876 opBuilder.setInsertionPointToEnd(block); 3877 3878 // Parse the list of operations that make up the body of the block. 3879 while (getToken().isNot(Token::caret_identifier, Token::r_brace)) 3880 if (parseOperation()) 3881 return failure(); 3882 3883 return success(); 3884 } 3885 3886 /// Get the block with the specified name, creating it if it doesn't already 3887 /// exist. The location specified is the point of use, which allows 3888 /// us to diagnose references to blocks that are not defined precisely. 
3889 Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { 3890 auto &blockAndLoc = getBlockInfoByName(name); 3891 if (!blockAndLoc.first) { 3892 blockAndLoc = {new Block(), loc}; 3893 insertForwardRef(blockAndLoc.first, loc); 3894 } 3895 3896 return blockAndLoc.first; 3897 } 3898 3899 /// Define the block with the specified name. Returns the Block* or nullptr in 3900 /// the case of redefinition. 3901 Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, 3902 Block *existing) { 3903 auto &blockAndLoc = getBlockInfoByName(name); 3904 if (!blockAndLoc.first) { 3905 // If the caller provided a block, use it. Otherwise create a new one. 3906 if (!existing) 3907 existing = new Block(); 3908 blockAndLoc.first = existing; 3909 blockAndLoc.second = loc; 3910 return blockAndLoc.first; 3911 } 3912 3913 // Forward declarations are removed once defined, so if we are defining a 3914 // existing block and it is not a forward declaration, then it is a 3915 // redeclaration. 3916 if (!eraseForwardRef(blockAndLoc.first)) 3917 return nullptr; 3918 return blockAndLoc.first; 3919 } 3920 3921 /// Parse a (possibly empty) list of SSA operands with types as block arguments. 3922 /// 3923 /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* 3924 /// 3925 ParseResult OperationParser::parseOptionalBlockArgList( 3926 SmallVectorImpl<BlockArgument *> &results, Block *owner) { 3927 if (getToken().is(Token::r_brace)) 3928 return success(); 3929 3930 // If the block already has arguments, then we're handling the entry block. 3931 // Parse and register the names for the arguments, but do not add them. 3932 bool definingExistingArgs = owner->getNumArguments() != 0; 3933 unsigned nextArgument = 0; 3934 3935 return parseCommaSeparatedList([&]() -> ParseResult { 3936 return parseSSADefOrUseAndType( 3937 [&](SSAUseInfo useInfo, Type type) -> ParseResult { 3938 // If this block did not have existing arguments, define a new one. 
3939 if (!definingExistingArgs) 3940 return addDefinition(useInfo, owner->addArgument(type)); 3941 3942 // Otherwise, ensure that this argument has already been created. 3943 if (nextArgument >= owner->getNumArguments()) 3944 return emitError("too many arguments specified in argument list"); 3945 3946 // Finally, make sure the existing argument has the correct type. 3947 auto *arg = owner->getArgument(nextArgument++); 3948 if (arg->getType() != type) 3949 return emitError("argument and block argument type mismatch"); 3950 return addDefinition(useInfo, arg); 3951 }); 3952 }); 3953 } 3954 3955 //===----------------------------------------------------------------------===// 3956 // Top-level entity parsing. 3957 //===----------------------------------------------------------------------===// 3958 3959 namespace { 3960 /// This parser handles entities that are only valid at the top level of the 3961 /// file. 3962 class ModuleParser : public Parser { 3963 public: 3964 explicit ModuleParser(ParserState &state) : Parser(state) {} 3965 3966 ParseResult parseModule(ModuleOp module); 3967 3968 private: 3969 /// Parse an attribute alias declaration. 3970 ParseResult parseAttributeAliasDef(); 3971 3972 /// Parse an attribute alias declaration. 3973 ParseResult parseTypeAliasDef(); 3974 }; 3975 } // end anonymous namespace 3976 3977 /// Parses an attribute alias declaration. 3978 /// 3979 /// attribute-alias-def ::= '#' alias-name `=` attribute-value 3980 /// 3981 ParseResult ModuleParser::parseAttributeAliasDef() { 3982 assert(getToken().is(Token::hash_identifier)); 3983 StringRef aliasName = getTokenSpelling().drop_front(); 3984 3985 // Check for redefinitions. 3986 if (getState().attributeAliasDefinitions.count(aliasName) > 0) 3987 return emitError("redefinition of attribute alias id '" + aliasName + "'"); 3988 3989 // Make sure this isn't invading the dialect attribute namespace. 3990 if (aliasName.contains('.')) 3991 return emitError("attribute names with a '.' 
are reserved for " 3992 "dialect-defined names"); 3993 3994 consumeToken(Token::hash_identifier); 3995 3996 // Parse the '='. 3997 if (parseToken(Token::equal, "expected '=' in attribute alias definition")) 3998 return failure(); 3999 4000 // Parse the attribute value. 4001 Attribute attr = parseAttribute(); 4002 if (!attr) 4003 return failure(); 4004 4005 getState().attributeAliasDefinitions[aliasName] = attr; 4006 return success(); 4007 } 4008 4009 /// Parse a type alias declaration. 4010 /// 4011 /// type-alias-def ::= '!' alias-name `=` 'type' type 4012 /// 4013 ParseResult ModuleParser::parseTypeAliasDef() { 4014 assert(getToken().is(Token::exclamation_identifier)); 4015 StringRef aliasName = getTokenSpelling().drop_front(); 4016 4017 // Check for redefinitions. 4018 if (getState().typeAliasDefinitions.count(aliasName) > 0) 4019 return emitError("redefinition of type alias id '" + aliasName + "'"); 4020 4021 // Make sure this isn't invading the dialect type namespace. 4022 if (aliasName.contains('.')) 4023 return emitError("type names with a '.' are reserved for " 4024 "dialect-defined names"); 4025 4026 consumeToken(Token::exclamation_identifier); 4027 4028 // Parse the '=' and 'type'. 4029 if (parseToken(Token::equal, "expected '=' in type alias definition") || 4030 parseToken(Token::kw_type, "expected 'type' in type alias definition")) 4031 return failure(); 4032 4033 // Parse the type. 4034 Type aliasedType = parseType(); 4035 if (!aliasedType) 4036 return failure(); 4037 4038 // Register this alias with the parser state. 4039 getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType); 4040 return success(); 4041 } 4042 4043 /// This is the top-level module parser. 4044 ParseResult ModuleParser::parseModule(ModuleOp module) { 4045 OperationParser opParser(getState(), module); 4046 4047 // Module itself is a name scope. 
4048 opParser.pushSSANameScope(/*isIsolated=*/true); 4049 4050 while (1) { 4051 switch (getToken().getKind()) { 4052 default: 4053 // Parse a top-level operation. 4054 if (opParser.parseOperation()) 4055 return failure(); 4056 break; 4057 4058 // If we got to the end of the file, then we're done. 4059 case Token::eof: { 4060 if (opParser.finalize()) 4061 return failure(); 4062 4063 // Handle the case where the top level module was explicitly defined. 4064 auto &bodyBlocks = module.getBodyRegion().getBlocks(); 4065 auto &operations = bodyBlocks.front().getOperations(); 4066 assert(!operations.empty() && "expected a valid module terminator"); 4067 4068 // Check that the first operation is a module, and it is the only 4069 // non-terminator operation. 4070 ModuleOp nested = dyn_cast<ModuleOp>(operations.front()); 4071 if (nested && std::next(operations.begin(), 2) == operations.end()) { 4072 // Merge the data of the nested module operation into 'module'. 4073 module.setLoc(nested.getLoc()); 4074 module.setAttrs(nested.getOperation()->getAttrList()); 4075 bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks()); 4076 4077 // Erase the original module body. 4078 bodyBlocks.pop_front(); 4079 } 4080 4081 return opParser.popSSANameScope(); 4082 } 4083 4084 // If we got an error token, then the lexer already emitted an error, just 4085 // stop. Someday we could introduce error recovery if there was demand 4086 // for it. 4087 case Token::error: 4088 return failure(); 4089 4090 // Parse an attribute alias. 4091 case Token::hash_identifier: 4092 if (parseAttributeAliasDef()) 4093 return failure(); 4094 break; 4095 4096 // Parse a type alias. 
4097 case Token::exclamation_identifier: 4098 if (parseTypeAliasDef()) 4099 return failure(); 4100 break; 4101 } 4102 } 4103 } 4104 4105 //===----------------------------------------------------------------------===// 4106 4107 /// This parses the file specified by the indicated SourceMgr and returns an 4108 /// MLIR module if it was valid. If not, it emits diagnostics and returns 4109 /// null. 4110 ModuleOp mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, 4111 MLIRContext *context) { 4112 auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); 4113 4114 // This is the result module we are parsing into. 4115 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 4116 sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 4117 4118 ParserState state(sourceMgr, context); 4119 if (ModuleParser(state).parseModule(*module)) 4120 return nullptr; 4121 4122 // Make sure the parse module has no other structural problems detected by 4123 // the verifier. 4124 if (failed(verify(*module))) 4125 return nullptr; 4126 4127 return module.release(); 4128 } 4129 4130 /// This parses the file specified by the indicated filename and returns an 4131 /// MLIR module if it was valid. If not, the error message is emitted through 4132 /// the error handler registered in the context, and a null pointer is returned. 4133 ModuleOp mlir::parseSourceFile(StringRef filename, MLIRContext *context) { 4134 llvm::SourceMgr sourceMgr; 4135 return parseSourceFile(filename, sourceMgr, context); 4136 } 4137 4138 /// This parses the file specified by the indicated filename using the provided 4139 /// SourceMgr and returns an MLIR module if it was valid. If not, the error 4140 /// message is emitted through the error handler registered in the context, and 4141 /// a null pointer is returned. 
4142 ModuleOp mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr, 4143 MLIRContext *context) { 4144 if (sourceMgr.getNumBuffers() != 0) { 4145 // TODO(b/136086478): Extend to support multiple buffers. 4146 emitError(mlir::UnknownLoc::get(context), 4147 "only main buffer parsed at the moment"); 4148 return nullptr; 4149 } 4150 auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename); 4151 if (std::error_code error = file_or_err.getError()) { 4152 emitError(mlir::UnknownLoc::get(context), 4153 "could not open input file " + filename); 4154 return nullptr; 4155 } 4156 4157 // Load the MLIR module. 4158 sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); 4159 return parseSourceFile(sourceMgr, context); 4160 } 4161 4162 /// This parses the program string to a MLIR module if it was valid. If not, 4163 /// it emits diagnostics and returns null. 4164 ModuleOp mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) { 4165 auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); 4166 if (!memBuffer) 4167 return nullptr; 4168 4169 SourceMgr sourceMgr; 4170 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4171 return parseSourceFile(sourceMgr, context); 4172 } 4173 4174 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) { 4175 SourceMgr sourceMgr; 4176 auto memBuffer = 4177 MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>", 4178 /*RequiresNullTerminator=*/false); 4179 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 4180 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context); 4181 ParserState state(sourceMgr, context); 4182 Parser parser(state); 4183 auto start = parser.getToken().getLoc(); 4184 auto ty = parser.parseType(); 4185 if (!ty) 4186 return Type(); 4187 4188 auto end = parser.getToken().getLoc(); 4189 auto read = end.getPointer() - start.getPointer(); 4190 // Make sure that the parsing of type consumes the entire string 4191 if 
(static_cast<size_t>(read) < typeStr.size()) { 4192 parser.emitError("unexpected additional tokens: '") 4193 << typeStr.substr(read) << "' after parsing type: " << ty; 4194 return Type(); 4195 } 4196 return ty; 4197 } 4198