1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/Debug.h" 20 21 #include <functional> 22 23 #define DEBUG_TYPE "mlir-spirv-conversion" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Utility functions 29 //===----------------------------------------------------------------------===// 30 31 /// Checks that `candidates` extension requirements are possible to be satisfied 32 /// with the given `targetEnv`. 33 /// 34 /// `candidates` is a vector of vector for extension requirements following 35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 36 /// convention. 37 template <typename LabelT> 38 static LogicalResult checkExtensionRequirements( 39 LabelT label, const spirv::TargetEnv &targetEnv, 40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 41 for (const auto &ors : candidates) { 42 if (targetEnv.allows(ors)) 43 continue; 44 45 LLVM_DEBUG({ 46 SmallVector<StringRef> extStrings; 47 for (spirv::Extension ext : ors) 48 extStrings.push_back(spirv::stringifyExtension(ext)); 49 50 llvm::dbgs() << label << " illegal: requires at least one extension in [" 51 << llvm::join(extStrings, ", ") 52 << "] but none allowed in target environment\n"; 53 }); 54 return failure(); 55 } 56 return success(); 57 } 58 59 /// Checks that `candidates`capability requirements are possible to be satisfied 60 /// with the given `isAllowedFn`. 61 /// 62 /// `candidates` is a vector of vector for capability requirements following 63 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 64 /// convention. 65 template <typename LabelT> 66 static LogicalResult checkCapabilityRequirements( 67 LabelT label, const spirv::TargetEnv &targetEnv, 68 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 69 for (const auto &ors : candidates) { 70 if (targetEnv.allows(ors)) 71 continue; 72 73 LLVM_DEBUG({ 74 SmallVector<StringRef> capStrings; 75 for (spirv::Capability cap : ors) 76 capStrings.push_back(spirv::stringifyCapability(cap)); 77 78 llvm::dbgs() << label << " illegal: requires at least one capability in [" 79 << llvm::join(capStrings, ", ") 80 << "] but none allowed in target environment\n"; 81 }); 82 return failure(); 83 } 84 return success(); 85 } 86 87 /// Returns true if the given `storageClass` needs explicit layout when used in 88 /// Shader environments. 89 static bool needsExplicitLayout(spirv::StorageClass storageClass) { 90 switch (storageClass) { 91 case spirv::StorageClass::PhysicalStorageBuffer: 92 case spirv::StorageClass::PushConstant: 93 case spirv::StorageClass::StorageBuffer: 94 case spirv::StorageClass::Uniform: 95 return true; 96 default: 97 return false; 98 } 99 } 100 101 /// Wraps the given `elementType` in a struct and gets the pointer to the 102 /// struct. This is used to satisfy Vulkan interface requirements. 103 static spirv::PointerType 104 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { 105 auto structType = needsExplicitLayout(storageClass) 106 ? spirv::StructType::get(elementType, /*offsetInfo=*/0) 107 : spirv::StructType::get(elementType); 108 return spirv::PointerType::get(structType, storageClass); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // Type Conversion 113 //===----------------------------------------------------------------------===// 114 115 Type SPIRVTypeConverter::getIndexType() const { 116 return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); 117 } 118 119 /// Mapping between SPIR-V storage classes to memref memory spaces. 120 /// 121 /// Note: memref does not have a defined semantics for each memory space; it 122 /// depends on the context where it is used. There are no particular reasons 123 /// behind the number assignments; we try to follow NVVM conventions and largely 124 /// give common storage classes a smaller number. The hope is use symbolic 125 /// memory space representation eventually after memref supports it. 126 // TODO: swap Generic and StorageBuffer assignment to be more akin 127 // to NVVM. 128 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 129 MAP_FN(spirv::StorageClass::Generic, 1) \ 130 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 131 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 132 MAP_FN(spirv::StorageClass::Uniform, 4) \ 133 MAP_FN(spirv::StorageClass::Private, 5) \ 134 MAP_FN(spirv::StorageClass::Function, 6) \ 135 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 136 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 137 MAP_FN(spirv::StorageClass::Input, 9) \ 138 MAP_FN(spirv::StorageClass::Output, 10) \ 139 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 140 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 141 MAP_FN(spirv::StorageClass::Image, 13) \ 142 MAP_FN(spirv::StorageClass::CallableDataNV, 14) \ 143 MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \ 144 MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \ 145 MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \ 146 MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \ 147 MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \ 148 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) 149 150 unsigned 151 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 152 #define STORAGE_SPACE_MAP_FN(storage, space) \ 153 case storage: \ 154 return space; 155 156 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 157 #undef STORAGE_SPACE_MAP_FN 158 llvm_unreachable("unhandled storage class!"); 159 } 160 161 Optional<spirv::StorageClass> 162 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 163 #define STORAGE_SPACE_MAP_FN(storage, space) \ 164 case space: \ 165 return storage; 166 167 switch (space) { 168 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 169 default: 170 return llvm::None; 171 } 172 #undef STORAGE_SPACE_MAP_FN 173 } 174 175 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { 176 return options; 177 } 178 179 MLIRContext *SPIRVTypeConverter::getContext() const { 180 return targetEnv.getAttr().getContext(); 181 } 182 183 #undef STORAGE_SPACE_MAP_LIST 184 185 // TODO: This is a utility function that should probably be exposed by the 186 // SPIR-V dialect. Keeping it local till the use case arises. 187 static Optional<int64_t> 188 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) { 189 if (type.isa<spirv::ScalarType>()) { 190 auto bitWidth = type.getIntOrFloatBitWidth(); 191 // According to the SPIR-V spec: 192 // "There is no physical size or bit pattern defined for values with boolean 193 // type. If they are stored (in conjunction with OpVariable), they can only 194 // be used with logical addressing operations, not physical, and only with 195 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 196 // Private, Function, Input, and Output." 197 if (bitWidth == 1) 198 return llvm::None; 199 return bitWidth / 8; 200 } 201 202 if (auto vecType = type.dyn_cast<VectorType>()) { 203 auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 204 if (!elementSize) 205 return llvm::None; 206 return vecType.getNumElements() * elementSize.getValue(); 207 } 208 209 if (auto memRefType = type.dyn_cast<MemRefType>()) { 210 // TODO: Layout should also be controlled by the ABI attributes. For now 211 // using the layout from MemRef. 212 int64_t offset; 213 SmallVector<int64_t, 4> strides; 214 if (!memRefType.hasStaticShape() || 215 failed(getStridesAndOffset(memRefType, strides, offset))) 216 return llvm::None; 217 218 // To get the size of the memref object in memory, the total size is the 219 // max(stride * dimension-size) computed for all dimensions times the size 220 // of the element. 221 auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 222 if (!elementSize) 223 return llvm::None; 224 225 if (memRefType.getRank() == 0) 226 return elementSize; 227 228 auto dims = memRefType.getShape(); 229 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 230 offset == MemRefType::getDynamicStrideOrOffset() || 231 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) 232 return llvm::None; 233 234 int64_t memrefSize = -1; 235 for (auto shape : enumerate(dims)) 236 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 237 238 return (offset + memrefSize) * elementSize.getValue(); 239 } 240 241 if (auto tensorType = type.dyn_cast<TensorType>()) { 242 if (!tensorType.hasStaticShape()) 243 return llvm::None; 244 245 auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 246 if (!elementSize) 247 return llvm::None; 248 249 int64_t size = elementSize.getValue(); 250 for (auto shape : tensorType.getShape()) 251 size *= shape; 252 253 return size; 254 } 255 256 // TODO: Add size computation for other types. 257 return llvm::None; 258 } 259 260 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 261 static Type convertScalarType(const spirv::TargetEnv &targetEnv, 262 const SPIRVTypeConverter::Options &options, 263 spirv::ScalarType type, 264 Optional<spirv::StorageClass> storageClass = {}) { 265 // Get extension and capability requirements for the given type. 266 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 267 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 268 type.getExtensions(extensions, storageClass); 269 type.getCapabilities(capabilities, storageClass); 270 271 // If all requirements are met, then we can accept this type as-is. 272 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 273 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 274 return type; 275 276 // Otherwise we need to adjust the type, which really means adjusting the 277 // bitwidth given this is a scalar type. 278 279 if (!options.emulateNon32BitScalarTypes) 280 return nullptr; 281 282 if (auto floatType = type.dyn_cast<FloatType>()) { 283 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 284 return Builder(targetEnv.getContext()).getF32Type(); 285 } 286 287 auto intType = type.cast<IntegerType>(); 288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 289 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 290 intType.getSignedness()); 291 } 292 293 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 294 static Type convertVectorType(const spirv::TargetEnv &targetEnv, 295 const SPIRVTypeConverter::Options &options, 296 VectorType type, 297 Optional<spirv::StorageClass> storageClass = {}) { 298 if (type.getRank() == 1 && type.getNumElements() == 1) 299 return type.getElementType(); 300 301 if (!spirv::CompositeType::isValid(type)) { 302 // TODO: Vector types with more than four elements can be translated into 303 // array types. 304 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); 305 return nullptr; 306 } 307 308 // Get extension and capability requirements for the given type. 309 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 310 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 311 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 312 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 313 314 // If all requirements are met, then we can accept this type as-is. 315 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 316 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 317 return type; 318 319 auto elementType = convertScalarType( 320 targetEnv, options, type.getElementType().cast<spirv::ScalarType>(), 321 storageClass); 322 if (elementType) 323 return VectorType::get(type.getShape(), elementType); 324 return nullptr; 325 } 326 327 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 328 /// 329 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can 330 /// create composite constants with OpConstantComposite to embed relative large 331 /// constant values and use OpCompositeExtract and OpCompositeInsert to 332 /// manipulate, like what we do for vectors. 333 static Type convertTensorType(const spirv::TargetEnv &targetEnv, 334 const SPIRVTypeConverter::Options &options, 335 TensorType type) { 336 // TODO: Handle dynamic shapes. 337 if (!type.hasStaticShape()) { 338 LLVM_DEBUG(llvm::dbgs() 339 << type << " illegal: dynamic shape unimplemented\n"); 340 return nullptr; 341 } 342 343 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 344 if (!scalarType) { 345 LLVM_DEBUG(llvm::dbgs() 346 << type << " illegal: cannot convert non-scalar element type\n"); 347 return nullptr; 348 } 349 350 Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType); 351 Optional<int64_t> tensorSize = getTypeNumBytes(options, type); 352 if (!scalarSize || !tensorSize) { 353 LLVM_DEBUG(llvm::dbgs() 354 << type << " illegal: cannot deduce element count\n"); 355 return nullptr; 356 } 357 358 auto arrayElemCount = *tensorSize / *scalarSize; 359 auto arrayElemType = convertScalarType(targetEnv, options, scalarType); 360 if (!arrayElemType) 361 return nullptr; 362 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 363 if (!arrayElemSize) { 364 LLVM_DEBUG(llvm::dbgs() 365 << type << " illegal: cannot deduce converted element size\n"); 366 return nullptr; 367 } 368 369 return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 370 } 371 372 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, 373 const SPIRVTypeConverter::Options &options, 374 MemRefType type) { 375 Optional<spirv::StorageClass> storageClass = 376 SPIRVTypeConverter::getStorageClassForMemorySpace( 377 type.getMemorySpaceAsInt()); 378 if (!storageClass) { 379 LLVM_DEBUG(llvm::dbgs() 380 << type << " illegal: cannot convert memory space\n"); 381 return nullptr; 382 } 383 384 unsigned numBoolBits = options.boolNumBits; 385 if (numBoolBits != 8) { 386 LLVM_DEBUG(llvm::dbgs() 387 << "using non-8-bit storage for bool types unimplemented"); 388 return nullptr; 389 } 390 auto elementType = IntegerType::get(type.getContext(), numBoolBits) 391 .dyn_cast<spirv::ScalarType>(); 392 if (!elementType) 393 return nullptr; 394 Type arrayElemType = 395 convertScalarType(targetEnv, options, elementType, storageClass); 396 if (!arrayElemType) 397 return nullptr; 398 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 399 if (!arrayElemSize) { 400 LLVM_DEBUG(llvm::dbgs() 401 << type << " illegal: cannot deduce converted element size\n"); 402 return nullptr; 403 } 404 405 if (!type.hasStaticShape()) { 406 auto arrayType = 407 spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); 408 return wrapInStructAndGetPointer(arrayType, *storageClass); 409 } 410 411 int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; 412 auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; 413 auto arrayType = 414 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 415 416 return wrapInStructAndGetPointer(arrayType, *storageClass); 417 } 418 419 static Type convertMemrefType(const spirv::TargetEnv &targetEnv, 420 const SPIRVTypeConverter::Options &options, 421 MemRefType type) { 422 if (type.getElementType().isa<IntegerType>() && 423 type.getElementTypeBitWidth() == 1) { 424 return convertBoolMemrefType(targetEnv, options, type); 425 } 426 427 Optional<spirv::StorageClass> storageClass = 428 SPIRVTypeConverter::getStorageClassForMemorySpace( 429 type.getMemorySpaceAsInt()); 430 if (!storageClass) { 431 LLVM_DEBUG(llvm::dbgs() 432 << type << " illegal: cannot convert memory space\n"); 433 return nullptr; 434 } 435 436 Type arrayElemType; 437 Type elementType = type.getElementType(); 438 if (auto vecType = elementType.dyn_cast<VectorType>()) { 439 arrayElemType = 440 convertVectorType(targetEnv, options, vecType, storageClass); 441 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 442 arrayElemType = 443 convertScalarType(targetEnv, options, scalarType, storageClass); 444 } else { 445 LLVM_DEBUG( 446 llvm::dbgs() 447 << type 448 << " unhandled: can only convert scalar or vector element type\n"); 449 return nullptr; 450 } 451 if (!arrayElemType) 452 return nullptr; 453 454 Optional<int64_t> elementSize = getTypeNumBytes(options, elementType); 455 if (!elementSize) { 456 LLVM_DEBUG(llvm::dbgs() 457 << type << " illegal: cannot deduce element size\n"); 458 return nullptr; 459 } 460 461 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 462 if (!arrayElemSize) { 463 LLVM_DEBUG(llvm::dbgs() 464 << type << " illegal: cannot deduce converted element size\n"); 465 return nullptr; 466 } 467 468 if (!type.hasStaticShape()) { 469 auto arrayType = 470 spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); 471 return wrapInStructAndGetPointer(arrayType, *storageClass); 472 } 473 474 Optional<int64_t> memrefSize = getTypeNumBytes(options, type); 475 if (!memrefSize) { 476 LLVM_DEBUG(llvm::dbgs() 477 << type << " illegal: cannot deduce element count\n"); 478 return nullptr; 479 } 480 481 auto arrayElemCount = *memrefSize / *elementSize; 482 483 484 auto arrayType = 485 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 486 487 return wrapInStructAndGetPointer(arrayType, *storageClass); 488 } 489 490 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 491 Options options) 492 : targetEnv(targetAttr), options(options) { 493 // Add conversions. The order matters here: later ones will be tried earlier. 494 495 // Allow all SPIR-V dialect specific types. This assumes all builtin types 496 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 497 // were tried before. 498 // 499 // TODO: this assumes that the SPIR-V types are valid to use in 500 // the given target environment, which should be the case if the whole 501 // pipeline is driven by the same target environment. Still, we probably still 502 // want to validate and convert to be safe. 503 addConversion([](spirv::SPIRVType type) { return type; }); 504 505 addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); 506 507 addConversion([this](IntegerType intType) -> Optional<Type> { 508 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 509 return convertScalarType(this->targetEnv, this->options, scalarType); 510 return Type(); 511 }); 512 513 addConversion([this](FloatType floatType) -> Optional<Type> { 514 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 515 return convertScalarType(this->targetEnv, this->options, scalarType); 516 return Type(); 517 }); 518 519 addConversion([this](VectorType vectorType) { 520 return convertVectorType(this->targetEnv, this->options, vectorType); 521 }); 522 523 addConversion([this](TensorType tensorType) { 524 return convertTensorType(this->targetEnv, this->options, tensorType); 525 }); 526 527 addConversion([this](MemRefType memRefType) { 528 return convertMemrefType(this->targetEnv, this->options, memRefType); 529 }); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // FuncOp Conversion Patterns 534 //===----------------------------------------------------------------------===// 535 536 namespace { 537 /// A pattern for rewriting function signature to convert arguments of functions 538 /// to be of valid SPIR-V types. 539 class FuncOpConversion final : public OpConversionPattern<FuncOp> { 540 public: 541 using OpConversionPattern<FuncOp>::OpConversionPattern; 542 543 LogicalResult 544 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 545 ConversionPatternRewriter &rewriter) const override; 546 }; 547 } // namespace 548 549 LogicalResult 550 FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 551 ConversionPatternRewriter &rewriter) const { 552 auto fnType = funcOp.getType(); 553 if (fnType.getNumResults() > 1) 554 return failure(); 555 556 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 557 for (auto argType : enumerate(fnType.getInputs())) { 558 auto convertedType = getTypeConverter()->convertType(argType.value()); 559 if (!convertedType) 560 return failure(); 561 signatureConverter.addInputs(argType.index(), convertedType); 562 } 563 564 Type resultType; 565 if (fnType.getNumResults() == 1) { 566 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 567 if (!resultType) 568 return failure(); 569 } 570 571 // Create the converted spv.func op. 572 auto newFuncOp = rewriter.create<spirv::FuncOp>( 573 funcOp.getLoc(), funcOp.getName(), 574 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 575 resultType ? TypeRange(resultType) 576 : TypeRange())); 577 578 // Copy over all attributes other than the function name and type. 579 for (const auto &namedAttr : funcOp->getAttrs()) { 580 if (namedAttr.first != function_like_impl::getTypeAttrName() && 581 namedAttr.first != SymbolTable::getSymbolAttrName()) 582 newFuncOp->setAttr(namedAttr.first, namedAttr.second); 583 } 584 585 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 586 newFuncOp.end()); 587 if (failed(rewriter.convertRegionTypes( 588 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 589 return failure(); 590 rewriter.eraseOp(funcOp); 591 return success(); 592 } 593 594 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 595 RewritePatternSet &patterns) { 596 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 597 } 598 599 //===----------------------------------------------------------------------===// 600 // Builtin Variables 601 //===----------------------------------------------------------------------===// 602 603 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 604 spirv::BuiltIn builtin) { 605 // Look through all global variables in the given `body` block and check if 606 // there is a spv.GlobalVariable that has the same `builtin` attribute. 607 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 608 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 609 spirv::SPIRVDialect::getAttributeName( 610 spirv::Decoration::BuiltIn))) { 611 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 612 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 613 return varOp; 614 } 615 } 616 } 617 return nullptr; 618 } 619 620 /// Gets name of global variable for a builtin. 621 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 622 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 623 } 624 625 /// Gets or inserts a global variable for a builtin within `body` block. 626 static spirv::GlobalVariableOp 627 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 628 Type integerType, OpBuilder &builder) { 629 if (auto varOp = getBuiltinVariable(body, builtin)) 630 return varOp; 631 632 OpBuilder::InsertionGuard guard(builder); 633 builder.setInsertionPointToStart(&body); 634 635 spirv::GlobalVariableOp newVarOp; 636 switch (builtin) { 637 case spirv::BuiltIn::NumWorkgroups: 638 case spirv::BuiltIn::WorkgroupSize: 639 case spirv::BuiltIn::WorkgroupId: 640 case spirv::BuiltIn::LocalInvocationId: 641 case spirv::BuiltIn::GlobalInvocationId: { 642 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), 643 spirv::StorageClass::Input); 644 std::string name = getBuiltinVarName(builtin); 645 newVarOp = 646 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 647 break; 648 } 649 case spirv::BuiltIn::SubgroupId: 650 case spirv::BuiltIn::NumSubgroups: 651 case spirv::BuiltIn::SubgroupSize: { 652 auto ptrType = 653 spirv::PointerType::get(integerType, spirv::StorageClass::Input); 654 std::string name = getBuiltinVarName(builtin); 655 newVarOp = 656 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 657 break; 658 } 659 default: 660 emitError(loc, "unimplemented builtin variable generation for ") 661 << stringifyBuiltIn(builtin); 662 } 663 return newVarOp; 664 } 665 666 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 667 spirv::BuiltIn builtin, 668 Type integerType, 669 OpBuilder &builder) { 670 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 671 if (!parent) { 672 op->emitError("expected operation to be within a module-like op"); 673 return nullptr; 674 } 675 676 spirv::GlobalVariableOp varOp = 677 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), 678 builtin, integerType, builder); 679 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 680 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 681 } 682 683 //===----------------------------------------------------------------------===// 684 // Push constant storage 685 //===----------------------------------------------------------------------===// 686 687 /// Returns the pointer type for the push constant storage containing 688 /// `elementCount` 32-bit integer values. 689 static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 690 Builder &builder, 691 Type indexType) { 692 auto arrayType = spirv::ArrayType::get(indexType, elementCount, 693 /*stride=*/4); 694 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 695 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 696 } 697 698 /// Returns the push constant varible containing `elementCount` 32-bit integer 699 /// values in `body`. Returns null op if such an op does not exit. 700 static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 701 unsigned elementCount) { 702 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 703 auto ptrType = varOp.type().dyn_cast<spirv::PointerType>(); 704 if (!ptrType) 705 continue; 706 707 // Note that Vulkan requires "There must be no more than one push constant 708 // block statically used per shader entry point." So we should always reuse 709 // the existing one. 710 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 711 auto numElements = ptrType.getPointeeType() 712 .cast<spirv::StructType>() 713 .getElementType(0) 714 .cast<spirv::ArrayType>() 715 .getNumElements(); 716 if (numElements == elementCount) 717 return varOp; 718 } 719 } 720 return nullptr; 721 } 722 723 /// Gets or inserts a global variable for push constant storage containing 724 /// `elementCount` 32-bit integer values in `block`. 725 static spirv::GlobalVariableOp 726 getOrInsertPushConstantVariable(Location loc, Block &block, 727 unsigned elementCount, OpBuilder &b, 728 Type indexType) { 729 if (auto varOp = getPushConstantVariable(block, elementCount)) 730 return varOp; 731 732 auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 733 auto type = getPushConstantStorageType(elementCount, builder, indexType); 734 const char *name = "__push_constant_var__"; 735 return builder.create<spirv::GlobalVariableOp>(loc, type, name, 736 /*initializer=*/nullptr); 737 } 738 739 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, 740 unsigned offset, Type integerType, 741 OpBuilder &builder) { 742 Location loc = op->getLoc(); 743 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 744 if (!parent) { 745 op->emitError("expected operation to be within a module-like op"); 746 return nullptr; 747 } 748 749 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 750 loc, parent->getRegion(0).front(), elementCount, builder, integerType); 751 752 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); 753 Value offsetOp = builder.create<spirv::ConstantOp>( 754 loc, integerType, builder.getI32IntegerAttr(offset)); 755 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 756 auto acOp = builder.create<spirv::AccessChainOp>( 757 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); 758 return builder.create<spirv::LoadOp>(loc, acOp); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // Index calculation 763 //===----------------------------------------------------------------------===// 764 765 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 766 int64_t offset, Type integerType, 767 Location loc, OpBuilder &builder) { 768 assert(indices.size() == strides.size() && 769 "must provide indices for all dimensions"); 770 771 // TODO: Consider moving to use affine.apply and patterns converting 772 // affine.apply to standard ops. This needs converting to SPIR-V passes to be 773 // broken down into progressive small steps so we can have intermediate steps 774 // using other dialects. At the moment SPIR-V is the final sink. 775 776 Value linearizedIndex = builder.create<spirv::ConstantOp>( 777 loc, integerType, IntegerAttr::get(integerType, offset)); 778 for (auto index : llvm::enumerate(indices)) { 779 Value strideVal = builder.create<spirv::ConstantOp>( 780 loc, integerType, 781 IntegerAttr::get(integerType, strides[index.index()])); 782 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 783 linearizedIndex = 784 builder.create<spirv::IAddOp>(loc, linearizedIndex, update); 785 } 786 return linearizedIndex; 787 } 788 789 spirv::AccessChainOp mlir::spirv::getElementPtr( 790 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 791 ValueRange indices, Location loc, OpBuilder &builder) { 792 // Get base and offset of the MemRefType and verify they are static. 793 794 int64_t offset; 795 SmallVector<int64_t, 4> strides; 796 if (failed(getStridesAndOffset(baseType, strides, offset)) || 797 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 798 offset == MemRefType::getDynamicStrideOrOffset()) { 799 return nullptr; 800 } 801 802 auto indexType = typeConverter.getIndexType(); 803 804 SmallVector<Value, 2> linearizedIndices; 805 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 806 807 // Add a '0' at the start to index into the struct. 808 linearizedIndices.push_back(zero); 809 810 if (baseType.getRank() == 0) { 811 linearizedIndices.push_back(zero); 812 } else { 813 linearizedIndices.push_back( 814 linearizeIndex(indices, strides, offset, indexType, loc, builder)); 815 } 816 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 817 } 818 819 //===----------------------------------------------------------------------===// 820 // SPIR-V ConversionTarget 821 //===----------------------------------------------------------------------===// 822 823 std::unique_ptr<SPIRVConversionTarget> 824 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 825 std::unique_ptr<SPIRVConversionTarget> target( 826 // std::make_unique does not work here because the constructor is private. 827 new SPIRVConversionTarget(targetAttr)); 828 SPIRVConversionTarget *targetPtr = target.get(); 829 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 830 // We need to capture the raw pointer here because it is stable: 831 // target will be destroyed once this function is returned. 832 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 833 return target; 834 } 835 836 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 837 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 838 839 bool SPIRVConversionTarget::isLegalOp(Operation *op) { 840 // Make sure this op is available at the given version. Ops not implementing 841 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 842 // SPIR-V versions. 843 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 844 if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { 845 LLVM_DEBUG(llvm::dbgs() 846 << op->getName() << " illegal: requiring min version " 847 << spirv::stringifyVersion(minVersion.getMinVersion()) 848 << "\n"); 849 return false; 850 } 851 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 852 if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { 853 LLVM_DEBUG(llvm::dbgs() 854 << op->getName() << " illegal: requiring max version " 855 << spirv::stringifyVersion(maxVersion.getMaxVersion()) 856 << "\n"); 857 return false; 858 } 859 860 // Make sure this op's required extensions are allowed to use. Ops not 861 // implementing QueryExtensionInterface do not require extensions to be 862 // available. 863 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 864 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 865 extensions.getExtensions()))) 866 return false; 867 868 // Make sure this op's required extensions are allowed to use. Ops not 869 // implementing QueryCapabilityInterface do not require capabilities to be 870 // available. 871 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 872 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 873 capabilities.getCapabilities()))) 874 return false; 875 876 SmallVector<Type, 4> valueTypes; 877 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 878 valueTypes.append(op->result_type_begin(), op->result_type_end()); 879 880 // Special treatment for global variables, whose type requirements are 881 // conveyed by type attributes. 882 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 883 valueTypes.push_back(globalVar.type()); 884 885 // Make sure the op's operands/results use types that are allowed by the 886 // target environment. 887 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 888 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 889 for (Type valueType : valueTypes) { 890 typeExtensions.clear(); 891 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 892 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 893 typeExtensions))) 894 return false; 895 896 typeCapabilities.clear(); 897 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 898 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 899 typeCapabilities))) 900 return false; 901 } 902 903 return true; 904 } 905