1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR Types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/TensorEncoding.h" 18 19 using namespace mlir; 20 using namespace mlir::detail; 21 22 /// Optionally parse a type. 23 OptionalParseResult Parser::parseOptionalType(Type &type) { 24 // There are many different starting tokens for a type, check them here. 25 switch (getToken().getKind()) { 26 case Token::l_paren: 27 case Token::kw_memref: 28 case Token::kw_tensor: 29 case Token::kw_complex: 30 case Token::kw_tuple: 31 case Token::kw_vector: 32 case Token::inttype: 33 case Token::kw_bf16: 34 case Token::kw_f16: 35 case Token::kw_f32: 36 case Token::kw_f64: 37 case Token::kw_f80: 38 case Token::kw_f128: 39 case Token::kw_index: 40 case Token::kw_none: 41 case Token::exclamation_identifier: 42 return failure(!(type = parseType())); 43 44 default: 45 return llvm::None; 46 } 47 } 48 49 /// Parse an arbitrary type. 50 /// 51 /// type ::= function-type 52 /// | non-function-type 53 /// 54 Type Parser::parseType() { 55 if (getToken().is(Token::l_paren)) 56 return parseFunctionType(); 57 return parseNonFunctionType(); 58 } 59 60 /// Parse a function result type. 61 /// 62 /// function-result-type ::= type-list-parens 63 /// | non-function-type 64 /// 65 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 66 if (getToken().is(Token::l_paren)) 67 return parseTypeListParens(elements); 68 69 Type t = parseNonFunctionType(); 70 if (!t) 71 return failure(); 72 elements.push_back(t); 73 return success(); 74 } 75 76 /// Parse a list of types without an enclosing parenthesis. The list must have 77 /// at least one member. 78 /// 79 /// type-list-no-parens ::= type (`,` type)* 80 /// 81 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 82 auto parseElt = [&]() -> ParseResult { 83 auto elt = parseType(); 84 elements.push_back(elt); 85 return elt ? success() : failure(); 86 }; 87 88 return parseCommaSeparatedList(parseElt); 89 } 90 91 /// Parse a parenthesized list of types. 92 /// 93 /// type-list-parens ::= `(` `)` 94 /// | `(` type-list-no-parens `)` 95 /// 96 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 97 if (parseToken(Token::l_paren, "expected '('")) 98 return failure(); 99 100 // Handle empty lists. 101 if (getToken().is(Token::r_paren)) 102 return consumeToken(), success(); 103 104 if (parseTypeListNoParens(elements) || 105 parseToken(Token::r_paren, "expected ')'")) 106 return failure(); 107 return success(); 108 } 109 110 /// Parse a complex type. 111 /// 112 /// complex-type ::= `complex` `<` type `>` 113 /// 114 Type Parser::parseComplexType() { 115 consumeToken(Token::kw_complex); 116 117 // Parse the '<'. 118 if (parseToken(Token::less, "expected '<' in complex type")) 119 return nullptr; 120 121 SMLoc elementTypeLoc = getToken().getLoc(); 122 auto elementType = parseType(); 123 if (!elementType || 124 parseToken(Token::greater, "expected '>' in complex type")) 125 return nullptr; 126 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) 127 return emitError(elementTypeLoc, "invalid element type for complex"), 128 nullptr; 129 130 return ComplexType::get(elementType); 131 } 132 133 /// Parse a function type. 134 /// 135 /// function-type ::= type-list-parens `->` function-result-type 136 /// 137 Type Parser::parseFunctionType() { 138 assert(getToken().is(Token::l_paren)); 139 140 SmallVector<Type, 4> arguments, results; 141 if (parseTypeListParens(arguments) || 142 parseToken(Token::arrow, "expected '->' in function type") || 143 parseFunctionResultTypes(results)) 144 return nullptr; 145 146 return builder.getFunctionType(arguments, results); 147 } 148 149 /// Parse the offset and strides from a strided layout specification. 150 /// 151 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 152 /// 153 ParseResult Parser::parseStridedLayout(int64_t &offset, 154 SmallVectorImpl<int64_t> &strides) { 155 // Parse offset. 156 consumeToken(Token::kw_offset); 157 if (parseToken(Token::colon, "expected colon after `offset` keyword")) 158 return failure(); 159 160 auto maybeOffset = getToken().getUnsignedIntegerValue(); 161 bool question = getToken().is(Token::question); 162 if (!maybeOffset && !question) 163 return emitWrongTokenError("invalid offset"); 164 offset = maybeOffset ? static_cast<int64_t>(*maybeOffset) 165 : MemRefType::getDynamicStrideOrOffset(); 166 consumeToken(); 167 168 // Parse stride list. 169 if (parseToken(Token::comma, "expected comma after offset value") || 170 parseToken(Token::kw_strides, 171 "expected `strides` keyword after offset specification") || 172 parseToken(Token::colon, "expected colon after `strides` keyword") || 173 parseStrideList(strides)) 174 return failure(); 175 return success(); 176 } 177 178 /// Parse a memref type. 179 /// 180 /// memref-type ::= ranked-memref-type | unranked-memref-type 181 /// 182 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 183 /// (`,` layout-specification)? (`,` memory-space)? `>` 184 /// 185 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 186 /// 187 /// stride-list ::= `[` (dimension (`,` dimension)*)? `]` 188 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 189 /// layout-specification ::= semi-affine-map | strided-layout | attribute 190 /// memory-space ::= integer-literal | attribute 191 /// 192 Type Parser::parseMemRefType() { 193 SMLoc loc = getToken().getLoc(); 194 consumeToken(Token::kw_memref); 195 196 if (parseToken(Token::less, "expected '<' in memref type")) 197 return nullptr; 198 199 bool isUnranked; 200 SmallVector<int64_t, 4> dimensions; 201 202 if (consumeIf(Token::star)) { 203 // This is an unranked memref type. 204 isUnranked = true; 205 if (parseXInDimensionList()) 206 return nullptr; 207 208 } else { 209 isUnranked = false; 210 if (parseDimensionListRanked(dimensions)) 211 return nullptr; 212 } 213 214 // Parse the element type. 215 auto typeLoc = getToken().getLoc(); 216 auto elementType = parseType(); 217 if (!elementType) 218 return nullptr; 219 220 // Check that memref is formed from allowed types. 221 if (!BaseMemRefType::isValidElementType(elementType)) 222 return emitError(typeLoc, "invalid memref element type"), nullptr; 223 224 MemRefLayoutAttrInterface layout; 225 Attribute memorySpace; 226 227 auto parseElt = [&]() -> ParseResult { 228 // Check for AffineMap as offset/strides. 229 if (getToken().is(Token::kw_offset)) { 230 int64_t offset; 231 SmallVector<int64_t, 4> strides; 232 if (failed(parseStridedLayout(offset, strides))) 233 return failure(); 234 // Construct strided affine map. 235 AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext()); 236 layout = AffineMapAttr::get(map); 237 } else { 238 // Either it is MemRefLayoutAttrInterface or memory space attribute. 239 Attribute attr = parseAttribute(); 240 if (!attr) 241 return failure(); 242 243 if (attr.isa<MemRefLayoutAttrInterface>()) { 244 layout = attr.cast<MemRefLayoutAttrInterface>(); 245 } else if (memorySpace) { 246 return emitError("multiple memory spaces specified in memref type"); 247 } else { 248 memorySpace = attr; 249 return success(); 250 } 251 } 252 253 if (isUnranked) 254 return emitError("cannot have affine map for unranked memref type"); 255 if (memorySpace) 256 return emitError("expected memory space to be last in memref type"); 257 258 return success(); 259 }; 260 261 // Parse a list of mappings and address space if present. 262 if (!consumeIf(Token::greater)) { 263 // Parse comma separated list of affine maps, followed by memory space. 264 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 265 parseCommaSeparatedListUntil(Token::greater, parseElt, 266 /*allowEmptyList=*/false)) { 267 return nullptr; 268 } 269 } 270 271 if (isUnranked) 272 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); 273 274 return getChecked<MemRefType>(loc, dimensions, elementType, layout, 275 memorySpace); 276 } 277 278 /// Parse any type except the function type. 279 /// 280 /// non-function-type ::= integer-type 281 /// | index-type 282 /// | float-type 283 /// | extended-type 284 /// | vector-type 285 /// | tensor-type 286 /// | memref-type 287 /// | complex-type 288 /// | tuple-type 289 /// | none-type 290 /// 291 /// index-type ::= `index` 292 /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` 293 /// none-type ::= `none` 294 /// 295 Type Parser::parseNonFunctionType() { 296 switch (getToken().getKind()) { 297 default: 298 return (emitWrongTokenError("expected non-function type"), nullptr); 299 case Token::kw_memref: 300 return parseMemRefType(); 301 case Token::kw_tensor: 302 return parseTensorType(); 303 case Token::kw_complex: 304 return parseComplexType(); 305 case Token::kw_tuple: 306 return parseTupleType(); 307 case Token::kw_vector: 308 return parseVectorType(); 309 // integer-type 310 case Token::inttype: { 311 auto width = getToken().getIntTypeBitwidth(); 312 if (!width.has_value()) 313 return (emitError("invalid integer width"), nullptr); 314 if (width.value() > IntegerType::kMaxWidth) { 315 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 316 << IntegerType::kMaxWidth << " bits"; 317 return nullptr; 318 } 319 320 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 321 if (Optional<bool> signedness = getToken().getIntTypeSignedness()) 322 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; 323 324 consumeToken(Token::inttype); 325 return IntegerType::get(getContext(), *width, signSemantics); 326 } 327 328 // float-type 329 case Token::kw_bf16: 330 consumeToken(Token::kw_bf16); 331 return builder.getBF16Type(); 332 case Token::kw_f16: 333 consumeToken(Token::kw_f16); 334 return builder.getF16Type(); 335 case Token::kw_f32: 336 consumeToken(Token::kw_f32); 337 return builder.getF32Type(); 338 case Token::kw_f64: 339 consumeToken(Token::kw_f64); 340 return builder.getF64Type(); 341 case Token::kw_f80: 342 consumeToken(Token::kw_f80); 343 return builder.getF80Type(); 344 case Token::kw_f128: 345 consumeToken(Token::kw_f128); 346 return builder.getF128Type(); 347 348 // index-type 349 case Token::kw_index: 350 consumeToken(Token::kw_index); 351 return builder.getIndexType(); 352 353 // none-type 354 case Token::kw_none: 355 consumeToken(Token::kw_none); 356 return builder.getNoneType(); 357 358 // extended type 359 case Token::exclamation_identifier: 360 return parseExtendedType(); 361 362 // Handle completion of a dialect type. 363 case Token::code_complete: 364 if (getToken().isCodeCompletionFor(Token::exclamation_identifier)) 365 return parseExtendedType(); 366 return codeCompleteType(); 367 } 368 } 369 370 /// Parse a tensor type. 371 /// 372 /// tensor-type ::= `tensor` `<` dimension-list type `>` 373 /// dimension-list ::= dimension-list-ranked | `*x` 374 /// 375 Type Parser::parseTensorType() { 376 consumeToken(Token::kw_tensor); 377 378 if (parseToken(Token::less, "expected '<' in tensor type")) 379 return nullptr; 380 381 bool isUnranked; 382 SmallVector<int64_t, 4> dimensions; 383 384 if (consumeIf(Token::star)) { 385 // This is an unranked tensor type. 386 isUnranked = true; 387 388 if (parseXInDimensionList()) 389 return nullptr; 390 391 } else { 392 isUnranked = false; 393 if (parseDimensionListRanked(dimensions)) 394 return nullptr; 395 } 396 397 // Parse the element type. 398 auto elementTypeLoc = getToken().getLoc(); 399 auto elementType = parseType(); 400 401 // Parse an optional encoding attribute. 402 Attribute encoding; 403 if (consumeIf(Token::comma)) { 404 encoding = parseAttribute(); 405 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) { 406 if (failed(v.verifyEncoding(dimensions, elementType, 407 [&] { return emitError(); }))) 408 return nullptr; 409 } 410 } 411 412 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 413 return nullptr; 414 if (!TensorType::isValidElementType(elementType)) 415 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 416 417 if (isUnranked) { 418 if (encoding) 419 return emitError("cannot apply encoding to unranked tensor"), nullptr; 420 return UnrankedTensorType::get(elementType); 421 } 422 return RankedTensorType::get(dimensions, elementType, encoding); 423 } 424 425 /// Parse a tuple type. 426 /// 427 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 428 /// 429 Type Parser::parseTupleType() { 430 consumeToken(Token::kw_tuple); 431 432 // Parse the '<'. 433 if (parseToken(Token::less, "expected '<' in tuple type")) 434 return nullptr; 435 436 // Check for an empty tuple by directly parsing '>'. 437 if (consumeIf(Token::greater)) 438 return TupleType::get(getContext()); 439 440 // Parse the element types and the '>'. 441 SmallVector<Type, 4> types; 442 if (parseTypeListNoParens(types) || 443 parseToken(Token::greater, "expected '>' in tuple type")) 444 return nullptr; 445 446 return TupleType::get(getContext(), types); 447 } 448 449 /// Parse a vector type. 450 /// 451 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` 452 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? 453 /// static-dim-list ::= decimal-literal (`x` decimal-literal)* 454 /// 455 VectorType Parser::parseVectorType() { 456 consumeToken(Token::kw_vector); 457 458 if (parseToken(Token::less, "expected '<' in vector type")) 459 return nullptr; 460 461 SmallVector<int64_t, 4> dimensions; 462 unsigned numScalableDims; 463 if (parseVectorDimensionList(dimensions, numScalableDims)) 464 return nullptr; 465 if (any_of(dimensions, [](int64_t i) { return i <= 0; })) 466 return emitError(getToken().getLoc(), 467 "vector types must have positive constant sizes"), 468 nullptr; 469 470 // Parse the element type. 471 auto typeLoc = getToken().getLoc(); 472 auto elementType = parseType(); 473 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 474 return nullptr; 475 476 if (!VectorType::isValidElementType(elementType)) 477 return emitError(typeLoc, "vector elements must be int/index/float type"), 478 nullptr; 479 480 return VectorType::get(dimensions, elementType, numScalableDims); 481 } 482 483 /// Parse a dimension list in a vector type. This populates the dimension list, 484 /// and returns the number of scalable dimensions in `numScalableDims`. 485 /// 486 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? 487 /// static-dim-list ::= decimal-literal (`x` decimal-literal)* 488 /// 489 ParseResult 490 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions, 491 unsigned &numScalableDims) { 492 numScalableDims = 0; 493 // If there is a set of fixed-length dimensions, consume it 494 while (getToken().is(Token::integer)) { 495 int64_t value; 496 if (parseIntegerInDimensionList(value)) 497 return failure(); 498 dimensions.push_back(value); 499 // Make sure we have an 'x' or something like 'xbf32'. 500 if (parseXInDimensionList()) 501 return failure(); 502 } 503 // If there is a set of scalable dimensions, consume it 504 if (consumeIf(Token::l_square)) { 505 while (getToken().is(Token::integer)) { 506 int64_t value; 507 if (parseIntegerInDimensionList(value)) 508 return failure(); 509 dimensions.push_back(value); 510 numScalableDims++; 511 // Check if we have reached the end of the scalable dimension list 512 if (consumeIf(Token::r_square)) { 513 // Make sure we have something like 'xbf32'. 514 return parseXInDimensionList(); 515 } 516 // Make sure we have an 'x' 517 if (parseXInDimensionList()) 518 return failure(); 519 } 520 // If we make it here, we've finished parsing the dimension list 521 // without finding ']' closing the set of scalable dimensions 522 return emitWrongTokenError( 523 "missing ']' closing set of scalable dimensions"); 524 } 525 526 return success(); 527 } 528 529 /// Parse a dimension list of a tensor or memref type. This populates the 530 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and 531 /// errors out on `?` otherwise. Parsing the trailing `x` is configurable. 532 /// 533 /// dimension-list ::= eps | dimension (`x` dimension)* 534 /// dimension-list-with-trailing-x ::= (dimension `x`)* 535 /// dimension ::= `?` | decimal-literal 536 /// 537 /// When `allowDynamic` is not set, this is used to parse: 538 /// 539 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* 540 /// static-dimension-list-with-trailing-x ::= (dimension `x`)* 541 ParseResult 542 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 543 bool allowDynamic, bool withTrailingX) { 544 auto parseDim = [&]() -> LogicalResult { 545 auto loc = getToken().getLoc(); 546 if (consumeIf(Token::question)) { 547 if (!allowDynamic) 548 return emitError(loc, "expected static shape"); 549 dimensions.push_back(-1); 550 } else { 551 int64_t value; 552 if (failed(parseIntegerInDimensionList(value))) 553 return failure(); 554 dimensions.push_back(value); 555 } 556 return success(); 557 }; 558 559 if (withTrailingX) { 560 while (getToken().isAny(Token::integer, Token::question)) { 561 if (failed(parseDim()) || failed(parseXInDimensionList())) 562 return failure(); 563 } 564 return success(); 565 } 566 567 if (getToken().isAny(Token::integer, Token::question)) { 568 if (failed(parseDim())) 569 return failure(); 570 while (getToken().is(Token::bare_identifier) && 571 getTokenSpelling()[0] == 'x') { 572 if (failed(parseXInDimensionList()) || failed(parseDim())) 573 return failure(); 574 } 575 } 576 return success(); 577 } 578 579 ParseResult Parser::parseIntegerInDimensionList(int64_t &value) { 580 // Hexadecimal integer literals (starting with `0x`) are not allowed in 581 // aggregate type declarations. Therefore, `0xf32` should be processed as 582 // a sequence of separate elements `0`, `x`, `f32`. 583 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 584 // We can get here only if the token is an integer literal. Hexadecimal 585 // integer literals can only start with `0x` (`1x` wouldn't lex as a 586 // literal, just `1` would, at which point we don't get into this 587 // branch). 588 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 589 value = 0; 590 state.lex.resetPointer(getTokenSpelling().data() + 1); 591 consumeToken(); 592 } else { 593 // Make sure this integer value is in bound and valid. 594 Optional<uint64_t> dimension = getToken().getUInt64IntegerValue(); 595 if (!dimension || 596 *dimension > (uint64_t)std::numeric_limits<int64_t>::max()) 597 return emitError("invalid dimension"); 598 value = (int64_t)*dimension; 599 consumeToken(Token::integer); 600 } 601 return success(); 602 } 603 604 /// Parse an 'x' token in a dimension list, handling the case where the x is 605 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 606 /// token. 607 ParseResult Parser::parseXInDimensionList() { 608 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 609 return emitWrongTokenError("expected 'x' in dimension list"); 610 611 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 612 if (getTokenSpelling().size() != 1) 613 state.lex.resetPointer(getTokenSpelling().data() + 1); 614 615 // Consume the 'x'. 616 consumeToken(Token::bare_identifier); 617 618 return success(); 619 } 620 621 // Parse a comma-separated list of dimensions, possibly empty: 622 // stride-list ::= `[` (dimension (`,` dimension)*)? `]` 623 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { 624 return parseCommaSeparatedList( 625 Delimiter::Square, 626 [&]() -> ParseResult { 627 if (consumeIf(Token::question)) { 628 dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); 629 } else { 630 // This must be an integer value. 631 int64_t val; 632 if (getToken().getSpelling().getAsInteger(10, val)) 633 return emitError("invalid integer value: ") 634 << getToken().getSpelling(); 635 // Make sure it is not the one value for `?`. 636 if (ShapedType::isDynamic(val)) 637 return emitError("invalid integer value: ") 638 << getToken().getSpelling() 639 << ", use `?` to specify a dynamic dimension"; 640 641 if (val == 0) 642 return emitError("invalid memref stride"); 643 644 dimensions.push_back(val); 645 consumeToken(Token::integer); 646 } 647 return success(); 648 }, 649 " in stride list"); 650 } 651