1 //===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V dialect in MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/Parser.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/DenseMap.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringMap.h" 29 #include "llvm/ADT/StringSwitch.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 using namespace mlir; 34 using namespace mlir::spirv; 35 36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc" 37 38 //===----------------------------------------------------------------------===// 39 // InlinerInterface 40 //===----------------------------------------------------------------------===// 41 42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops. 43 static inline bool containsReturn(Region ®ion) { 44 return llvm::any_of(region, [](Block &block) { 45 Operation *terminator = block.getTerminator(); 46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); 47 }); 48 } 49 50 namespace { 51 /// This class defines the interface for inlining within the SPIR-V dialect. 52 struct SPIRVInlinerInterface : public DialectInlinerInterface { 53 using DialectInlinerInterface::DialectInlinerInterface; 54 55 /// All call operations within SPIRV can be inlined. 56 bool isLegalToInline(Operation *call, Operation *callable, 57 bool wouldBeCloned) const final { 58 return true; 59 } 60 61 /// Returns true if the given region 'src' can be inlined into the region 62 /// 'dest' that is attached to an operation registered to the current dialect. 63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 64 BlockAndValueMapping &) const final { 65 // Return true here when inlining into spv.func, spv.mlir.selection, and 66 // spv.mlir.loop operations. 67 auto *op = dest->getParentOp(); 68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); 69 } 70 71 /// Returns true if the given operation 'op', that is registered to this 72 /// dialect, can be inlined into the region 'dest' that is attached to an 73 /// operation registered to the current dialect. 74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 75 BlockAndValueMapping &) const final { 76 // TODO: Enable inlining structured control flows with return. 77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && 78 containsReturn(op->getRegion(0))) 79 return false; 80 // TODO: we need to filter OpKill here to avoid inlining it to 81 // a loop continue construct: 82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 83 // However OpKill is fragment shader specific and we don't support it yet. 84 return true; 85 } 86 87 /// Handle the given inlined terminator by replacing it with a new operation 88 /// as necessary. 89 void handleTerminator(Operation *op, Block *newDest) const final { 90 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { 91 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); 92 op->erase(); 93 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { 94 llvm_unreachable("unimplemented spv.ReturnValue in inliner"); 95 } 96 } 97 98 /// Handle the given inlined terminator by replacing it with a new operation 99 /// as necessary. 100 void handleTerminator(Operation *op, 101 ArrayRef<Value> valuesToRepl) const final { 102 // Only spv.ReturnValue needs to be handled here. 103 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); 104 if (!retValOp) 105 return; 106 107 // Replace the values directly with the return operands. 108 assert(valuesToRepl.size() == 1 && 109 "spv.ReturnValue expected to only handle one result"); 110 valuesToRepl.front().replaceAllUsesWith(retValOp.value()); 111 } 112 }; 113 } // namespace 114 115 //===----------------------------------------------------------------------===// 116 // SPIR-V Dialect 117 //===----------------------------------------------------------------------===// 118 119 void SPIRVDialect::initialize() { 120 registerAttributes(); 121 registerTypes(); 122 123 // Add SPIR-V ops. 124 addOperations< 125 #define GET_OP_LIST 126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" 127 >(); 128 129 addInterfaces<SPIRVInlinerInterface>(); 130 131 // Allow unknown operations because SPIR-V is extensible. 132 allowUnknownOperations(); 133 } 134 135 std::string SPIRVDialect::getAttributeName(Decoration decoration) { 136 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // Type Parsing 141 //===----------------------------------------------------------------------===// 142 143 // Forward declarations. 144 template <typename ValTy> 145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 146 DialectAsmParser &parser); 147 template <> 148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 149 DialectAsmParser &parser); 150 151 template <> 152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 153 DialectAsmParser &parser); 154 155 static Type parseAndVerifyType(SPIRVDialect const &dialect, 156 DialectAsmParser &parser) { 157 Type type; 158 llvm::SMLoc typeLoc = parser.getCurrentLocation(); 159 if (parser.parseType(type)) 160 return Type(); 161 162 // Allow SPIR-V dialect types 163 if (&type.getDialect() == &dialect) 164 return type; 165 166 // Check other allowed types 167 if (auto t = type.dyn_cast<FloatType>()) { 168 if (type.isBF16()) { 169 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); 170 return Type(); 171 } 172 } else if (auto t = type.dyn_cast<IntegerType>()) { 173 if (!ScalarType::isValid(t)) { 174 parser.emitError(typeLoc, 175 "only 1/8/16/32/64-bit integer type allowed but found ") 176 << type; 177 return Type(); 178 } 179 } else if (auto t = type.dyn_cast<VectorType>()) { 180 if (t.getRank() != 1) { 181 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 182 return Type(); 183 } 184 if (t.getNumElements() > 4) { 185 parser.emitError( 186 typeLoc, "vector length has to be less than or equal to 4 but found ") 187 << t.getNumElements(); 188 return Type(); 189 } 190 } else { 191 parser.emitError(typeLoc, "cannot use ") 192 << type << " to compose SPIR-V types"; 193 return Type(); 194 } 195 196 return type; 197 } 198 199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, 200 DialectAsmParser &parser) { 201 Type type; 202 llvm::SMLoc typeLoc = parser.getCurrentLocation(); 203 if (parser.parseType(type)) 204 return Type(); 205 206 if (auto t = type.dyn_cast<VectorType>()) { 207 if (t.getRank() != 1) { 208 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; 209 return Type(); 210 } 211 if (t.getNumElements() > 4 || t.getNumElements() < 2) { 212 parser.emitError(typeLoc, 213 "matrix columns size has to be less than or equal " 214 "to 4 and greater than or equal 2, but found ") 215 << t.getNumElements(); 216 return Type(); 217 } 218 219 if (!t.getElementType().isa<FloatType>()) { 220 parser.emitError(typeLoc, "matrix columns' elements must be of " 221 "Float type, got ") 222 << t.getElementType(); 223 return Type(); 224 } 225 } else { 226 parser.emitError(typeLoc, "matrix must be composed using vector " 227 "type, got ") 228 << type; 229 return Type(); 230 } 231 232 return type; 233 } 234 235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, 236 DialectAsmParser &parser) { 237 Type type; 238 llvm::SMLoc typeLoc = parser.getCurrentLocation(); 239 if (parser.parseType(type)) 240 return Type(); 241 242 if (!type.isa<ImageType>()) { 243 parser.emitError(typeLoc, 244 "sampled image must be composed using image type, got ") 245 << type; 246 return Type(); 247 } 248 249 return type; 250 } 251 252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure 253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if 254 /// missing. 255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, 256 DialectAsmParser &parser, 257 unsigned &stride) { 258 if (failed(parser.parseOptionalComma())) { 259 stride = 0; 260 return success(); 261 } 262 263 if (parser.parseKeyword("stride") || parser.parseEqual()) 264 return failure(); 265 266 llvm::SMLoc strideLoc = parser.getCurrentLocation(); 267 Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); 268 if (!optStride) 269 return failure(); 270 271 if (!(stride = optStride.getValue())) { 272 parser.emitError(strideLoc, "ArrayStride must be greater than zero"); 273 return failure(); 274 } 275 return success(); 276 } 277 278 // element-type ::= integer-type 279 // | floating-point-type 280 // | vector-type 281 // | spirv-type 282 // 283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type 284 // (`,` `stride` `=` integer-literal)? `>` 285 static Type parseArrayType(SPIRVDialect const &dialect, 286 DialectAsmParser &parser) { 287 if (parser.parseLess()) 288 return Type(); 289 290 SmallVector<int64_t, 1> countDims; 291 llvm::SMLoc countLoc = parser.getCurrentLocation(); 292 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 293 return Type(); 294 if (countDims.size() != 1) { 295 parser.emitError(countLoc, 296 "expected single integer for array element count"); 297 return Type(); 298 } 299 300 // According to the SPIR-V spec: 301 // "Length is the number of elements in the array. It must be at least 1." 302 int64_t count = countDims[0]; 303 if (count == 0) { 304 parser.emitError(countLoc, "expected array length greater than 0"); 305 return Type(); 306 } 307 308 Type elementType = parseAndVerifyType(dialect, parser); 309 if (!elementType) 310 return Type(); 311 312 unsigned stride = 0; 313 if (failed(parseOptionalArrayStride(dialect, parser, stride))) 314 return Type(); 315 316 if (parser.parseGreater()) 317 return Type(); 318 return ArrayType::get(elementType, count, stride); 319 } 320 321 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ',' 322 // rows ',' columns>` 323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, 324 DialectAsmParser &parser) { 325 if (parser.parseLess()) 326 return Type(); 327 328 SmallVector<int64_t, 2> dims; 329 llvm::SMLoc countLoc = parser.getCurrentLocation(); 330 if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) 331 return Type(); 332 333 if (dims.size() != 2) { 334 parser.emitError(countLoc, "expected rows and columns size"); 335 return Type(); 336 } 337 338 auto elementTy = parseAndVerifyType(dialect, parser); 339 if (!elementTy) 340 return Type(); 341 342 Scope scope; 343 if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>")) 344 return Type(); 345 346 if (parser.parseGreater()) 347 return Type(); 348 return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); 349 } 350 351 // TODO: Reorder methods to be utilities first and parse*Type 352 // methods in alphabetical order 353 // 354 // storage-class ::= `UniformConstant` 355 // | `Uniform` 356 // | `Workgroup` 357 // | <and other storage classes...> 358 // 359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>` 360 static Type parsePointerType(SPIRVDialect const &dialect, 361 DialectAsmParser &parser) { 362 if (parser.parseLess()) 363 return Type(); 364 365 auto pointeeType = parseAndVerifyType(dialect, parser); 366 if (!pointeeType) 367 return Type(); 368 369 StringRef storageClassSpec; 370 llvm::SMLoc storageClassLoc = parser.getCurrentLocation(); 371 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec)) 372 return Type(); 373 374 auto storageClass = symbolizeStorageClass(storageClassSpec); 375 if (!storageClass) { 376 parser.emitError(storageClassLoc, "unknown storage class: ") 377 << storageClassSpec; 378 return Type(); 379 } 380 if (parser.parseGreater()) 381 return Type(); 382 return PointerType::get(pointeeType, *storageClass); 383 } 384 385 // runtime-array-type ::= `!spv.rtarray` `<` element-type 386 // (`,` `stride` `=` integer-literal)? `>` 387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect, 388 DialectAsmParser &parser) { 389 if (parser.parseLess()) 390 return Type(); 391 392 Type elementType = parseAndVerifyType(dialect, parser); 393 if (!elementType) 394 return Type(); 395 396 unsigned stride = 0; 397 if (failed(parseOptionalArrayStride(dialect, parser, stride))) 398 return Type(); 399 400 if (parser.parseGreater()) 401 return Type(); 402 return RuntimeArrayType::get(elementType, stride); 403 } 404 405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>` 406 static Type parseMatrixType(SPIRVDialect const &dialect, 407 DialectAsmParser &parser) { 408 if (parser.parseLess()) 409 return Type(); 410 411 SmallVector<int64_t, 1> countDims; 412 llvm::SMLoc countLoc = parser.getCurrentLocation(); 413 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) 414 return Type(); 415 if (countDims.size() != 1) { 416 parser.emitError(countLoc, "expected single unsigned " 417 "integer for number of columns"); 418 return Type(); 419 } 420 421 int64_t columnCount = countDims[0]; 422 // According to the specification, Matrices can have 2, 3, or 4 columns 423 if (columnCount < 2 || columnCount > 4) { 424 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 " 425 "columns"); 426 return Type(); 427 } 428 429 Type columnType = parseAndVerifyMatrixType(dialect, parser); 430 if (!columnType) 431 return Type(); 432 433 if (parser.parseGreater()) 434 return Type(); 435 436 return MatrixType::get(columnType, columnCount); 437 } 438 439 // Specialize this function to parse each of the parameters that define an 440 // ImageType. By default it assumes this is an enum type. 441 template <typename ValTy> 442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, 443 DialectAsmParser &parser) { 444 StringRef enumSpec; 445 llvm::SMLoc enumLoc = parser.getCurrentLocation(); 446 if (parser.parseKeyword(&enumSpec)) { 447 return llvm::None; 448 } 449 450 auto val = spirv::symbolizeEnum<ValTy>(enumSpec); 451 if (!val) 452 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; 453 return val; 454 } 455 456 template <> 457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, 458 DialectAsmParser &parser) { 459 // TODO: Further verify that the element type can be sampled 460 auto ty = parseAndVerifyType(dialect, parser); 461 if (!ty) 462 return llvm::None; 463 return ty; 464 } 465 466 template <typename IntTy> 467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, 468 DialectAsmParser &parser) { 469 IntTy offsetVal = std::numeric_limits<IntTy>::max(); 470 if (parser.parseInteger(offsetVal)) 471 return llvm::None; 472 return offsetVal; 473 } 474 475 template <> 476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, 477 DialectAsmParser &parser) { 478 return parseAndVerifyInteger<unsigned>(dialect, parser); 479 } 480 481 namespace { 482 // Functor object to parse a comma separated list of specs. The function 483 // parseAndVerify does the actual parsing and verification of individual 484 // elements. This is a functor since parsing the last element of the list 485 // (termination condition) needs partial specialization. 486 template <typename ParseType, typename... Args> struct ParseCommaSeparatedList { 487 Optional<std::tuple<ParseType, Args...>> 488 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { 489 auto parseVal = parseAndVerify<ParseType>(dialect, parser); 490 if (!parseVal) 491 return llvm::None; 492 493 auto numArgs = std::tuple_size<std::tuple<Args...>>::value; 494 if (numArgs != 0 && failed(parser.parseComma())) 495 return llvm::None; 496 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); 497 if (!remainingValues) 498 return llvm::None; 499 return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()), 500 remainingValues.getValue()); 501 } 502 }; 503 504 // Partial specialization of the function to parse a comma separated list of 505 // specs to parse the last element of the list. 506 template <typename ParseType> struct ParseCommaSeparatedList<ParseType> { 507 Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect, 508 DialectAsmParser &parser) const { 509 if (auto value = parseAndVerify<ParseType>(dialect, parser)) 510 return std::tuple<ParseType>(value.getValue()); 511 return llvm::None; 512 } 513 }; 514 } // namespace 515 516 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> 517 // 518 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` 519 // 520 // arrayed-info ::= `NonArrayed` | `Arrayed` 521 // 522 // sampling-info ::= `SingleSampled` | `MultiSampled` 523 // 524 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` 525 // 526 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> 527 // 528 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` 529 // arrayed-info `,` sampling-info `,` 530 // sampler-use-info `,` format `>` 531 static Type parseImageType(SPIRVDialect const &dialect, 532 DialectAsmParser &parser) { 533 if (parser.parseLess()) 534 return Type(); 535 536 auto value = 537 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 538 ImageSamplingInfo, ImageSamplerUseInfo, 539 ImageFormat>{}(dialect, parser); 540 if (!value) 541 return Type(); 542 543 if (parser.parseGreater()) 544 return Type(); 545 return ImageType::get(value.getValue()); 546 } 547 548 // sampledImage-type :: = `!spv.sampledImage<` image-type `>` 549 static Type parseSampledImageType(SPIRVDialect const &dialect, 550 DialectAsmParser &parser) { 551 if (parser.parseLess()) 552 return Type(); 553 554 Type parsedType = parseAndVerifySampledImageType(dialect, parser); 555 if (!parsedType) 556 return Type(); 557 558 if (parser.parseGreater()) 559 return Type(); 560 return SampledImageType::get(parsedType); 561 } 562 563 // Parse decorations associated with a member. 564 static ParseResult parseStructMemberDecorations( 565 SPIRVDialect const &dialect, DialectAsmParser &parser, 566 ArrayRef<Type> memberTypes, 567 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, 568 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { 569 570 // Check if the first element is offset. 571 llvm::SMLoc offsetLoc = parser.getCurrentLocation(); 572 StructType::OffsetInfo offset = 0; 573 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); 574 if (offsetParseResult.hasValue()) { 575 if (failed(*offsetParseResult)) 576 return failure(); 577 578 if (offsetInfo.size() != memberTypes.size() - 1) { 579 return parser.emitError(offsetLoc, 580 "offset specification must be given for " 581 "all members"); 582 } 583 offsetInfo.push_back(offset); 584 } 585 586 // Check for no spirv::Decorations. 587 if (succeeded(parser.parseOptionalRSquare())) 588 return success(); 589 590 // If there was an offset, make sure to parse the comma. 591 if (offsetParseResult.hasValue() && parser.parseComma()) 592 return failure(); 593 594 // Check for spirv::Decorations. 595 do { 596 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); 597 if (!memberDecoration) 598 return failure(); 599 600 // Parse member decoration value if it exists. 601 if (succeeded(parser.parseOptionalEqual())) { 602 auto memberDecorationValue = 603 parseAndVerifyInteger<uint32_t>(dialect, parser); 604 605 if (!memberDecorationValue) 606 return failure(); 607 608 memberDecorationInfo.emplace_back( 609 static_cast<uint32_t>(memberTypes.size() - 1), 1, 610 memberDecoration.getValue(), memberDecorationValue.getValue()); 611 } else { 612 memberDecorationInfo.emplace_back( 613 static_cast<uint32_t>(memberTypes.size() - 1), 0, 614 memberDecoration.getValue(), 0); 615 } 616 617 } while (succeeded(parser.parseOptionalComma())); 618 619 return parser.parseRSquare(); 620 } 621 622 // struct-member-decoration ::= integer-literal? spirv-decoration* 623 // struct-type ::= 624 // `!spv.struct<` (id `,`)? 625 // `(` 626 // (spirv-type (`[` struct-member-decoration `]`)?)* 627 // `)>` 628 static Type parseStructType(SPIRVDialect const &dialect, 629 DialectAsmParser &parser) { 630 // TODO: This function is quite lengthy. Break it down into smaller chunks. 631 632 // To properly resolve recursive references while parsing recursive struct 633 // types, we need to maintain a list of enclosing struct type names. This set 634 // maintains the names of struct types in which the type we are about to parse 635 // is nested. 636 // 637 // Note: This has to be thread_local to enable multiple threads to safely 638 // parse concurrently. 639 thread_local SetVector<StringRef> structContext; 640 641 static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext, 642 StringRef identifier) { 643 if (!identifier.empty()) 644 structContext.remove(identifier); 645 646 return Type(); 647 }; 648 649 if (parser.parseLess()) 650 return Type(); 651 652 StringRef identifier; 653 654 // Check if this is an identified struct type. 655 if (succeeded(parser.parseOptionalKeyword(&identifier))) { 656 // Check if this is a possible recursive reference. 657 if (succeeded(parser.parseOptionalGreater())) { 658 if (structContext.count(identifier) == 0) { 659 parser.emitError( 660 parser.getNameLoc(), 661 "recursive struct reference not nested in struct definition"); 662 663 return Type(); 664 } 665 666 return StructType::getIdentified(dialect.getContext(), identifier); 667 } 668 669 if (failed(parser.parseComma())) 670 return Type(); 671 672 if (structContext.count(identifier) != 0) { 673 parser.emitError(parser.getNameLoc(), 674 "identifier already used for an enclosing struct"); 675 676 return removeIdentifierAndFail(structContext, identifier); 677 } 678 679 structContext.insert(identifier); 680 } 681 682 if (failed(parser.parseLParen())) 683 return removeIdentifierAndFail(structContext, identifier); 684 685 if (succeeded(parser.parseOptionalRParen()) && 686 succeeded(parser.parseOptionalGreater())) { 687 if (!identifier.empty()) 688 structContext.remove(identifier); 689 690 return StructType::getEmpty(dialect.getContext(), identifier); 691 } 692 693 StructType idStructTy; 694 695 if (!identifier.empty()) 696 idStructTy = StructType::getIdentified(dialect.getContext(), identifier); 697 698 SmallVector<Type, 4> memberTypes; 699 SmallVector<StructType::OffsetInfo, 4> offsetInfo; 700 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; 701 702 do { 703 Type memberType; 704 if (parser.parseType(memberType)) 705 return removeIdentifierAndFail(structContext, identifier); 706 memberTypes.push_back(memberType); 707 708 if (succeeded(parser.parseOptionalLSquare())) 709 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, 710 memberDecorationInfo)) 711 return removeIdentifierAndFail(structContext, identifier); 712 } while (succeeded(parser.parseOptionalComma())); 713 714 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { 715 parser.emitError(parser.getNameLoc(), 716 "offset specification must be given for all members"); 717 return removeIdentifierAndFail(structContext, identifier); 718 } 719 720 if (failed(parser.parseRParen()) || failed(parser.parseGreater())) 721 return removeIdentifierAndFail(structContext, identifier); 722 723 if (!identifier.empty()) { 724 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, 725 memberDecorationInfo))) 726 return Type(); 727 728 structContext.remove(identifier); 729 return idStructTy; 730 } 731 732 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); 733 } 734 735 // spirv-type ::= array-type 736 // | element-type 737 // | image-type 738 // | pointer-type 739 // | runtime-array-type 740 // | sampled-image-type 741 // | struct-type 742 Type SPIRVDialect::parseType(DialectAsmParser &parser) const { 743 StringRef keyword; 744 if (parser.parseKeyword(&keyword)) 745 return Type(); 746 747 if (keyword == "array") 748 return parseArrayType(*this, parser); 749 if (keyword == "coopmatrix") 750 return parseCooperativeMatrixType(*this, parser); 751 if (keyword == "image") 752 return parseImageType(*this, parser); 753 if (keyword == "ptr") 754 return parsePointerType(*this, parser); 755 if (keyword == "rtarray") 756 return parseRuntimeArrayType(*this, parser); 757 if (keyword == "sampled_image") 758 return parseSampledImageType(*this, parser); 759 if (keyword == "struct") 760 return parseStructType(*this, parser); 761 if (keyword == "matrix") 762 return parseMatrixType(*this, parser); 763 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; 764 return Type(); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // Type Printing 769 //===----------------------------------------------------------------------===// 770 771 static void print(ArrayType type, DialectAsmPrinter &os) { 772 os << "array<" << type.getNumElements() << " x " << type.getElementType(); 773 if (unsigned stride = type.getArrayStride()) 774 os << ", stride=" << stride; 775 os << ">"; 776 } 777 778 static void print(RuntimeArrayType type, DialectAsmPrinter &os) { 779 os << "rtarray<" << type.getElementType(); 780 if (unsigned stride = type.getArrayStride()) 781 os << ", stride=" << stride; 782 os << ">"; 783 } 784 785 static void print(PointerType type, DialectAsmPrinter &os) { 786 os << "ptr<" << type.getPointeeType() << ", " 787 << stringifyStorageClass(type.getStorageClass()) << ">"; 788 } 789 790 static void print(ImageType type, DialectAsmPrinter &os) { 791 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) 792 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " 793 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " 794 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " 795 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " 796 << stringifyImageFormat(type.getImageFormat()) << ">"; 797 } 798 799 static void print(SampledImageType type, DialectAsmPrinter &os) { 800 os << "sampled_image<" << type.getImageType() << ">"; 801 } 802 803 static void print(StructType type, DialectAsmPrinter &os) { 804 thread_local SetVector<StringRef> structContext; 805 806 os << "struct<"; 807 808 if (type.isIdentified()) { 809 os << type.getIdentifier(); 810 811 if (structContext.count(type.getIdentifier())) { 812 os << ">"; 813 return; 814 } 815 816 os << ", "; 817 structContext.insert(type.getIdentifier()); 818 } 819 820 os << "("; 821 822 auto printMember = [&](unsigned i) { 823 os << type.getElementType(i); 824 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; 825 type.getMemberDecorations(i, decorations); 826 if (type.hasOffset() || !decorations.empty()) { 827 os << " ["; 828 if (type.hasOffset()) { 829 os << type.getMemberOffset(i); 830 if (!decorations.empty()) 831 os << ", "; 832 } 833 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { 834 os << stringifyDecoration(decoration.decoration); 835 if (decoration.hasValue) { 836 os << "=" << decoration.decorationValue; 837 } 838 }; 839 llvm::interleaveComma(decorations, os, eachFn); 840 os << "]"; 841 } 842 }; 843 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, 844 printMember); 845 os << ")>"; 846 847 if (type.isIdentified()) 848 structContext.remove(type.getIdentifier()); 849 } 850 851 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { 852 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; 853 os << type.getElementType() << ", " << stringifyScope(type.getScope()); 854 os << ">"; 855 } 856 857 static void print(MatrixType type, DialectAsmPrinter &os) { 858 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); 859 os << ">"; 860 } 861 862 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { 863 TypeSwitch<Type>(type) 864 .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType, 865 ImageType, SampledImageType, StructType, MatrixType>( 866 [&](auto type) { print(type, os); }) 867 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); 868 } 869 870 //===----------------------------------------------------------------------===// 871 // Attribute Parsing 872 //===----------------------------------------------------------------------===// 873 874 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each 875 /// of the parsed keyword, and returns failure if any error occurs. 876 static ParseResult parseKeywordList( 877 DialectAsmParser &parser, 878 function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) { 879 if (parser.parseLSquare()) 880 return failure(); 881 882 // Special case for empty list. 883 if (succeeded(parser.parseOptionalRSquare())) 884 return success(); 885 886 // Keep parsing the keyword and an optional comma following it. If the comma 887 // is successfully parsed, then we have more keywords to parse. 888 do { 889 auto loc = parser.getCurrentLocation(); 890 StringRef keyword; 891 if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword))) 892 return failure(); 893 } while (succeeded(parser.parseOptionalComma())); 894 895 if (parser.parseRSquare()) 896 return failure(); 897 898 return success(); 899 } 900 901 /// Parses a spirv::InterfaceVarABIAttr. 902 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) { 903 if (parser.parseLess()) 904 return {}; 905 906 Builder &builder = parser.getBuilder(); 907 908 if (parser.parseLParen()) 909 return {}; 910 911 IntegerAttr descriptorSetAttr; 912 { 913 auto loc = parser.getCurrentLocation(); 914 uint32_t descriptorSet = 0; 915 auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet); 916 917 if (!descriptorSetParseResult.hasValue() || 918 failed(*descriptorSetParseResult)) { 919 parser.emitError(loc, "missing descriptor set"); 920 return {}; 921 } 922 descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet); 923 } 924 925 if (parser.parseComma()) 926 return {}; 927 928 IntegerAttr bindingAttr; 929 { 930 auto loc = parser.getCurrentLocation(); 931 uint32_t binding = 0; 932 auto bindingParseResult = parser.parseOptionalInteger(binding); 933 934 if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) { 935 parser.emitError(loc, "missing binding"); 936 return {}; 937 } 938 bindingAttr = builder.getI32IntegerAttr(binding); 939 } 940 941 if (parser.parseRParen()) 942 return {}; 943 944 IntegerAttr storageClassAttr; 945 { 946 if (succeeded(parser.parseOptionalComma())) { 947 auto loc = parser.getCurrentLocation(); 948 StringRef storageClass; 949 if (parser.parseKeyword(&storageClass)) 950 return {}; 951 952 if (auto storageClassSymbol = 953 spirv::symbolizeStorageClass(storageClass)) { 954 storageClassAttr = builder.getI32IntegerAttr( 955 static_cast<uint32_t>(*storageClassSymbol)); 956 } else { 957 parser.emitError(loc, "unknown storage class: ") << storageClass; 958 return {}; 959 } 960 } 961 } 962 963 if (parser.parseGreater()) 964 return {}; 965 966 return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr, 967 storageClassAttr); 968 } 969 970 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) { 971 if (parser.parseLess()) 972 return {}; 973 974 Builder &builder = parser.getBuilder(); 975 976 IntegerAttr versionAttr; 977 { 978 auto loc = parser.getCurrentLocation(); 979 StringRef version; 980 if (parser.parseKeyword(&version) || parser.parseComma()) 981 return {}; 982 983 if (auto versionSymbol = spirv::symbolizeVersion(version)) { 984 versionAttr = 985 builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol)); 986 } else { 987 parser.emitError(loc, "unknown version: ") << version; 988 return {}; 989 } 990 } 991 992 ArrayAttr capabilitiesAttr; 993 { 994 SmallVector<Attribute, 4> capabilities; 995 llvm::SMLoc errorloc; 996 StringRef errorKeyword; 997 998 auto processCapability = [&](llvm::SMLoc loc, StringRef capability) { 999 if (auto capSymbol = spirv::symbolizeCapability(capability)) { 1000 capabilities.push_back( 1001 builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol))); 1002 return success(); 1003 } 1004 return errorloc = loc, errorKeyword = capability, failure(); 1005 }; 1006 if (parseKeywordList(parser, processCapability) || parser.parseComma()) { 1007 if (!errorKeyword.empty()) 1008 parser.emitError(errorloc, "unknown capability: ") << errorKeyword; 1009 return {}; 1010 } 1011 1012 capabilitiesAttr = builder.getArrayAttr(capabilities); 1013 } 1014 1015 ArrayAttr extensionsAttr; 1016 { 1017 SmallVector<Attribute, 1> extensions; 1018 llvm::SMLoc errorloc; 1019 StringRef errorKeyword; 1020 1021 auto processExtension = [&](llvm::SMLoc loc, StringRef extension) { 1022 if (spirv::symbolizeExtension(extension)) { 1023 extensions.push_back(builder.getStringAttr(extension)); 1024 return success(); 1025 } 1026 return errorloc = loc, errorKeyword = extension, failure(); 1027 }; 1028 if (parseKeywordList(parser, processExtension)) { 1029 if (!errorKeyword.empty()) 1030 parser.emitError(errorloc, "unknown extension: ") << errorKeyword; 1031 return {}; 1032 } 1033 1034 extensionsAttr = builder.getArrayAttr(extensions); 1035 } 1036 1037 if (parser.parseGreater()) 1038 return {}; 1039 1040 return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr, 1041 extensionsAttr); 1042 } 1043 1044 /// Parses a spirv::TargetEnvAttr. 1045 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) { 1046 if (parser.parseLess()) 1047 return {}; 1048 1049 spirv::VerCapExtAttr tripleAttr; 1050 if (parser.parseAttribute(tripleAttr) || parser.parseComma()) 1051 return {}; 1052 1053 // Parse [vendor[:device-type[:device-id]]] 1054 Vendor vendorID = Vendor::Unknown; 1055 DeviceType deviceType = DeviceType::Unknown; 1056 uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID; 1057 { 1058 auto loc = parser.getCurrentLocation(); 1059 StringRef vendorStr; 1060 if (succeeded(parser.parseOptionalKeyword(&vendorStr))) { 1061 if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) { 1062 vendorID = *vendorSymbol; 1063 } else { 1064 parser.emitError(loc, "unknown vendor: ") << vendorStr; 1065 } 1066 1067 if (succeeded(parser.parseOptionalColon())) { 1068 loc = parser.getCurrentLocation(); 1069 StringRef deviceTypeStr; 1070 if (parser.parseKeyword(&deviceTypeStr)) 1071 return {}; 1072 if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) { 1073 deviceType = *deviceTypeSymbol; 1074 } else { 1075 parser.emitError(loc, "unknown device type: ") << deviceTypeStr; 1076 } 1077 1078 if (succeeded(parser.parseOptionalColon())) { 1079 loc = parser.getCurrentLocation(); 1080 if (parser.parseInteger(deviceID)) 1081 return {}; 1082 } 1083 } 1084 if (parser.parseComma()) 1085 return {}; 1086 } 1087 } 1088 1089 DictionaryAttr limitsAttr; 1090 { 1091 auto loc = parser.getCurrentLocation(); 1092 if (parser.parseAttribute(limitsAttr)) 1093 return {}; 1094 1095 if (!limitsAttr.isa<spirv::ResourceLimitsAttr>()) { 1096 parser.emitError( 1097 loc, 1098 "limits must be a dictionary attribute containing two 32-bit integer " 1099 "attributes 'max_compute_workgroup_invocations' and " 1100 "'max_compute_workgroup_size'"); 1101 return {}; 1102 } 1103 } 1104 1105 if (parser.parseGreater()) 1106 return {}; 1107 1108 return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID, 1109 limitsAttr); 1110 } 1111 1112 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, 1113 Type type) const { 1114 // SPIR-V attributes are dictionaries so they do not have type. 1115 if (type) { 1116 parser.emitError(parser.getNameLoc(), "unexpected type"); 1117 return {}; 1118 } 1119 1120 // Parse the kind keyword first. 1121 StringRef attrKind; 1122 if (parser.parseKeyword(&attrKind)) 1123 return {}; 1124 1125 if (attrKind == spirv::TargetEnvAttr::getKindName()) 1126 return parseTargetEnvAttr(parser); 1127 if (attrKind == spirv::VerCapExtAttr::getKindName()) 1128 return parseVerCapExtAttr(parser); 1129 if (attrKind == spirv::InterfaceVarABIAttr::getKindName()) 1130 return parseInterfaceVarABIAttr(parser); 1131 1132 parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ") 1133 << attrKind; 1134 return {}; 1135 } 1136 1137 //===----------------------------------------------------------------------===// 1138 // Attribute Printing 1139 //===----------------------------------------------------------------------===// 1140 1141 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) { 1142 auto &os = printer.getStream(); 1143 printer << spirv::VerCapExtAttr::getKindName() << "<" 1144 << spirv::stringifyVersion(triple.getVersion()) << ", ["; 1145 llvm::interleaveComma( 1146 triple.getCapabilities(), os, 1147 [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); 1148 printer << "], ["; 1149 llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { 1150 os << attr.cast<StringAttr>().getValue(); 1151 }); 1152 printer << "]>"; 1153 } 1154 1155 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { 1156 printer << spirv::TargetEnvAttr::getKindName() << "<#spv."; 1157 print(targetEnv.getTripleAttr(), printer); 1158 spirv::Vendor vendorID = targetEnv.getVendorID(); 1159 spirv::DeviceType deviceType = targetEnv.getDeviceType(); 1160 uint32_t deviceID = targetEnv.getDeviceID(); 1161 if (vendorID != spirv::Vendor::Unknown) { 1162 printer << ", " << spirv::stringifyVendor(vendorID); 1163 if (deviceType != spirv::DeviceType::Unknown) { 1164 printer << ":" << spirv::stringifyDeviceType(deviceType); 1165 if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID) 1166 printer << ":" << deviceID; 1167 } 1168 } 1169 printer << ", " << targetEnv.getResourceLimits() << ">"; 1170 } 1171 1172 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr, 1173 DialectAsmPrinter &printer) { 1174 printer << spirv::InterfaceVarABIAttr::getKindName() << "<(" 1175 << interfaceVarABIAttr.getDescriptorSet() << ", " 1176 << interfaceVarABIAttr.getBinding() << ")"; 1177 auto storageClass = interfaceVarABIAttr.getStorageClass(); 1178 if (storageClass) 1179 printer << ", " << spirv::stringifyStorageClass(*storageClass); 1180 printer << ">"; 1181 } 1182 1183 void SPIRVDialect::printAttribute(Attribute attr, 1184 DialectAsmPrinter &printer) const { 1185 if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>()) 1186 print(targetEnv, printer); 1187 else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>()) 1188 print(vceAttr, printer); 1189 else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>()) 1190 print(interfaceVarABIAttr, printer); 1191 else 1192 llvm_unreachable("unhandled SPIR-V attribute kind"); 1193 } 1194 1195 //===----------------------------------------------------------------------===// 1196 // Constant 1197 //===----------------------------------------------------------------------===// 1198 1199 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, 1200 Attribute value, Type type, 1201 Location loc) { 1202 if (!spirv::ConstantOp::isBuildableWith(type)) 1203 return nullptr; 1204 1205 return builder.create<spirv::ConstantOp>(loc, type, value); 1206 } 1207 1208 //===----------------------------------------------------------------------===// 1209 // Shader Interface ABI 1210 //===----------------------------------------------------------------------===// 1211 1212 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, 1213 NamedAttribute attribute) { 1214 StringRef symbol = attribute.getName().strref(); 1215 Attribute attr = attribute.getValue(); 1216 1217 // TODO: figure out a way to generate the description from the 1218 // StructAttr definition. 1219 if (symbol == spirv::getEntryPointABIAttrName()) { 1220 if (!attr.isa<spirv::EntryPointABIAttr>()) 1221 return op->emitError("'") 1222 << symbol 1223 << "' attribute must be a dictionary attribute containing one " 1224 "32-bit integer elements attribute: 'local_size'"; 1225 } else if (symbol == spirv::getTargetEnvAttrName()) { 1226 if (!attr.isa<spirv::TargetEnvAttr>()) 1227 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; 1228 } else { 1229 return op->emitError("found unsupported '") 1230 << symbol << "' attribute on operation"; 1231 } 1232 1233 return success(); 1234 } 1235 1236 /// Verifies the given SPIR-V `attribute` attached to a value of the given 1237 /// `valueType` is valid. 1238 static LogicalResult verifyRegionAttribute(Location loc, Type valueType, 1239 NamedAttribute attribute) { 1240 StringRef symbol = attribute.getName().strref(); 1241 Attribute attr = attribute.getValue(); 1242 1243 if (symbol != spirv::getInterfaceVarABIAttrName()) 1244 return emitError(loc, "found unsupported '") 1245 << symbol << "' attribute on region argument"; 1246 1247 auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>(); 1248 if (!varABIAttr) 1249 return emitError(loc, "'") 1250 << symbol << "' must be a spirv::InterfaceVarABIAttr"; 1251 1252 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) 1253 return emitError(loc, "'") << symbol 1254 << "' attribute cannot specify storage class " 1255 "when attaching to a non-scalar value"; 1256 1257 return success(); 1258 } 1259 1260 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, 1261 unsigned regionIndex, 1262 unsigned argIndex, 1263 NamedAttribute attribute) { 1264 return verifyRegionAttribute( 1265 op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(), 1266 attribute); 1267 } 1268 1269 LogicalResult SPIRVDialect::verifyRegionResultAttribute( 1270 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, 1271 NamedAttribute attribute) { 1272 return op->emitError("cannot attach SPIR-V attributes to region result"); 1273 } 1274