1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/Debug.h" 20 21 #include <functional> 22 23 #define DEBUG_TYPE "mlir-spirv-conversion" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Utility functions 29 //===----------------------------------------------------------------------===// 30 31 /// Checks that `candidates` extension requirements are possible to be satisfied 32 /// with the given `targetEnv`. 33 /// 34 /// `candidates` is a vector of vector for extension requirements following 35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 36 /// convention. 37 template <typename LabelT> 38 static LogicalResult checkExtensionRequirements( 39 LabelT label, const spirv::TargetEnv &targetEnv, 40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 41 for (const auto &ors : candidates) { 42 if (targetEnv.allows(ors)) 43 continue; 44 45 LLVM_DEBUG({ 46 SmallVector<StringRef> extStrings; 47 for (spirv::Extension ext : ors) 48 extStrings.push_back(spirv::stringifyExtension(ext)); 49 50 llvm::dbgs() << label << " illegal: requires at least one extension in [" 51 << llvm::join(extStrings, ", ") 52 << "] but none allowed in target environment\n"; 53 }); 54 return failure(); 55 } 56 return success(); 57 } 58 59 /// Checks that `candidates`capability requirements are possible to be satisfied 60 /// with the given `isAllowedFn`. 61 /// 62 /// `candidates` is a vector of vector for capability requirements following 63 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 64 /// convention. 65 template <typename LabelT> 66 static LogicalResult checkCapabilityRequirements( 67 LabelT label, const spirv::TargetEnv &targetEnv, 68 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 69 for (const auto &ors : candidates) { 70 if (targetEnv.allows(ors)) 71 continue; 72 73 LLVM_DEBUG({ 74 SmallVector<StringRef> capStrings; 75 for (spirv::Capability cap : ors) 76 capStrings.push_back(spirv::stringifyCapability(cap)); 77 78 llvm::dbgs() << label << " illegal: requires at least one capability in [" 79 << llvm::join(capStrings, ", ") 80 << "] but none allowed in target environment\n"; 81 }); 82 return failure(); 83 } 84 return success(); 85 } 86 87 /// Returns true if the given `storageClass` needs explicit layout when used in 88 /// Shader environments. 89 static bool needsExplicitLayout(spirv::StorageClass storageClass) { 90 switch (storageClass) { 91 case spirv::StorageClass::PhysicalStorageBuffer: 92 case spirv::StorageClass::PushConstant: 93 case spirv::StorageClass::StorageBuffer: 94 case spirv::StorageClass::Uniform: 95 return true; 96 default: 97 return false; 98 } 99 } 100 101 /// Wraps the given `elementType` in a struct and gets the pointer to the 102 /// struct. This is used to satisfy Vulkan interface requirements. 103 static spirv::PointerType 104 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { 105 auto structType = needsExplicitLayout(storageClass) 106 ? spirv::StructType::get(elementType, /*offsetInfo=*/0) 107 : spirv::StructType::get(elementType); 108 return spirv::PointerType::get(structType, storageClass); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // Type Conversion 113 //===----------------------------------------------------------------------===// 114 115 Type SPIRVTypeConverter::getIndexType() const { 116 return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); 117 } 118 119 /// Mapping between SPIR-V storage classes to memref memory spaces. 120 /// 121 /// Note: memref does not have a defined semantics for each memory space; it 122 /// depends on the context where it is used. There are no particular reasons 123 /// behind the number assignments; we try to follow NVVM conventions and largely 124 /// give common storage classes a smaller number. The hope is use symbolic 125 /// memory space representation eventually after memref supports it. 126 // TODO: swap Generic and StorageBuffer assignment to be more akin 127 // to NVVM. 128 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 129 MAP_FN(spirv::StorageClass::Generic, 1) \ 130 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 131 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 132 MAP_FN(spirv::StorageClass::Uniform, 4) \ 133 MAP_FN(spirv::StorageClass::Private, 5) \ 134 MAP_FN(spirv::StorageClass::Function, 6) \ 135 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 136 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 137 MAP_FN(spirv::StorageClass::Input, 9) \ 138 MAP_FN(spirv::StorageClass::Output, 10) \ 139 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 140 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 141 MAP_FN(spirv::StorageClass::Image, 13) \ 142 MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \ 143 MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \ 144 MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \ 145 MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \ 146 MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \ 147 MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \ 148 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \ 149 MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \ 150 MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \ 151 MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23) 152 153 unsigned 154 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 155 #define STORAGE_SPACE_MAP_FN(storage, space) \ 156 case storage: \ 157 return space; 158 159 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 160 #undef STORAGE_SPACE_MAP_FN 161 llvm_unreachable("unhandled storage class!"); 162 } 163 164 Optional<spirv::StorageClass> 165 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 166 #define STORAGE_SPACE_MAP_FN(storage, space) \ 167 case space: \ 168 return storage; 169 170 switch (space) { 171 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 172 default: 173 return llvm::None; 174 } 175 #undef STORAGE_SPACE_MAP_FN 176 } 177 178 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { 179 return options; 180 } 181 182 MLIRContext *SPIRVTypeConverter::getContext() const { 183 return targetEnv.getAttr().getContext(); 184 } 185 186 #undef STORAGE_SPACE_MAP_LIST 187 188 // TODO: This is a utility function that should probably be exposed by the 189 // SPIR-V dialect. Keeping it local till the use case arises. 190 static Optional<int64_t> 191 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) { 192 if (type.isa<spirv::ScalarType>()) { 193 auto bitWidth = type.getIntOrFloatBitWidth(); 194 // According to the SPIR-V spec: 195 // "There is no physical size or bit pattern defined for values with boolean 196 // type. If they are stored (in conjunction with OpVariable), they can only 197 // be used with logical addressing operations, not physical, and only with 198 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 199 // Private, Function, Input, and Output." 200 if (bitWidth == 1) 201 return llvm::None; 202 return bitWidth / 8; 203 } 204 205 if (auto vecType = type.dyn_cast<VectorType>()) { 206 auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 207 if (!elementSize) 208 return llvm::None; 209 return vecType.getNumElements() * elementSize.getValue(); 210 } 211 212 if (auto memRefType = type.dyn_cast<MemRefType>()) { 213 // TODO: Layout should also be controlled by the ABI attributes. For now 214 // using the layout from MemRef. 215 int64_t offset; 216 SmallVector<int64_t, 4> strides; 217 if (!memRefType.hasStaticShape() || 218 failed(getStridesAndOffset(memRefType, strides, offset))) 219 return llvm::None; 220 221 // To get the size of the memref object in memory, the total size is the 222 // max(stride * dimension-size) computed for all dimensions times the size 223 // of the element. 224 auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 225 if (!elementSize) 226 return llvm::None; 227 228 if (memRefType.getRank() == 0) 229 return elementSize; 230 231 auto dims = memRefType.getShape(); 232 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 233 offset == MemRefType::getDynamicStrideOrOffset() || 234 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) 235 return llvm::None; 236 237 int64_t memrefSize = -1; 238 for (auto shape : enumerate(dims)) 239 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 240 241 return (offset + memrefSize) * elementSize.getValue(); 242 } 243 244 if (auto tensorType = type.dyn_cast<TensorType>()) { 245 if (!tensorType.hasStaticShape()) 246 return llvm::None; 247 248 auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 249 if (!elementSize) 250 return llvm::None; 251 252 int64_t size = elementSize.getValue(); 253 for (auto shape : tensorType.getShape()) 254 size *= shape; 255 256 return size; 257 } 258 259 // TODO: Add size computation for other types. 260 return llvm::None; 261 } 262 263 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 264 static Type convertScalarType(const spirv::TargetEnv &targetEnv, 265 const SPIRVTypeConverter::Options &options, 266 spirv::ScalarType type, 267 Optional<spirv::StorageClass> storageClass = {}) { 268 // Get extension and capability requirements for the given type. 269 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 270 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 271 type.getExtensions(extensions, storageClass); 272 type.getCapabilities(capabilities, storageClass); 273 274 // If all requirements are met, then we can accept this type as-is. 275 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 276 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 277 return type; 278 279 // Otherwise we need to adjust the type, which really means adjusting the 280 // bitwidth given this is a scalar type. 281 282 if (!options.emulateNon32BitScalarTypes) 283 return nullptr; 284 285 if (auto floatType = type.dyn_cast<FloatType>()) { 286 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 287 return Builder(targetEnv.getContext()).getF32Type(); 288 } 289 290 auto intType = type.cast<IntegerType>(); 291 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 292 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 293 intType.getSignedness()); 294 } 295 296 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 297 static Type convertVectorType(const spirv::TargetEnv &targetEnv, 298 const SPIRVTypeConverter::Options &options, 299 VectorType type, 300 Optional<spirv::StorageClass> storageClass = {}) { 301 if (type.getRank() == 1 && type.getNumElements() == 1) 302 return type.getElementType(); 303 304 if (!spirv::CompositeType::isValid(type)) { 305 // TODO: Vector types with more than four elements can be translated into 306 // array types. 307 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); 308 return nullptr; 309 } 310 311 // Get extension and capability requirements for the given type. 312 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 313 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 314 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 315 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 316 317 // If all requirements are met, then we can accept this type as-is. 318 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 319 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 320 return type; 321 322 auto elementType = convertScalarType( 323 targetEnv, options, type.getElementType().cast<spirv::ScalarType>(), 324 storageClass); 325 if (elementType) 326 return VectorType::get(type.getShape(), elementType); 327 return nullptr; 328 } 329 330 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 331 /// 332 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can 333 /// create composite constants with OpConstantComposite to embed relative large 334 /// constant values and use OpCompositeExtract and OpCompositeInsert to 335 /// manipulate, like what we do for vectors. 336 static Type convertTensorType(const spirv::TargetEnv &targetEnv, 337 const SPIRVTypeConverter::Options &options, 338 TensorType type) { 339 // TODO: Handle dynamic shapes. 340 if (!type.hasStaticShape()) { 341 LLVM_DEBUG(llvm::dbgs() 342 << type << " illegal: dynamic shape unimplemented\n"); 343 return nullptr; 344 } 345 346 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 347 if (!scalarType) { 348 LLVM_DEBUG(llvm::dbgs() 349 << type << " illegal: cannot convert non-scalar element type\n"); 350 return nullptr; 351 } 352 353 Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType); 354 Optional<int64_t> tensorSize = getTypeNumBytes(options, type); 355 if (!scalarSize || !tensorSize) { 356 LLVM_DEBUG(llvm::dbgs() 357 << type << " illegal: cannot deduce element count\n"); 358 return nullptr; 359 } 360 361 auto arrayElemCount = *tensorSize / *scalarSize; 362 auto arrayElemType = convertScalarType(targetEnv, options, scalarType); 363 if (!arrayElemType) 364 return nullptr; 365 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 366 if (!arrayElemSize) { 367 LLVM_DEBUG(llvm::dbgs() 368 << type << " illegal: cannot deduce converted element size\n"); 369 return nullptr; 370 } 371 372 return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 373 } 374 375 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, 376 const SPIRVTypeConverter::Options &options, 377 MemRefType type) { 378 Optional<spirv::StorageClass> storageClass = 379 SPIRVTypeConverter::getStorageClassForMemorySpace( 380 type.getMemorySpaceAsInt()); 381 if (!storageClass) { 382 LLVM_DEBUG(llvm::dbgs() 383 << type << " illegal: cannot convert memory space\n"); 384 return nullptr; 385 } 386 387 unsigned numBoolBits = options.boolNumBits; 388 if (numBoolBits != 8) { 389 LLVM_DEBUG(llvm::dbgs() 390 << "using non-8-bit storage for bool types unimplemented"); 391 return nullptr; 392 } 393 auto elementType = IntegerType::get(type.getContext(), numBoolBits) 394 .dyn_cast<spirv::ScalarType>(); 395 if (!elementType) 396 return nullptr; 397 Type arrayElemType = 398 convertScalarType(targetEnv, options, elementType, storageClass); 399 if (!arrayElemType) 400 return nullptr; 401 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 402 if (!arrayElemSize) { 403 LLVM_DEBUG(llvm::dbgs() 404 << type << " illegal: cannot deduce converted element size\n"); 405 return nullptr; 406 } 407 408 if (!type.hasStaticShape()) { 409 auto arrayType = 410 spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); 411 return wrapInStructAndGetPointer(arrayType, *storageClass); 412 } 413 414 int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; 415 auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; 416 auto arrayType = 417 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 418 419 return wrapInStructAndGetPointer(arrayType, *storageClass); 420 } 421 422 static Type convertMemrefType(const spirv::TargetEnv &targetEnv, 423 const SPIRVTypeConverter::Options &options, 424 MemRefType type) { 425 if (type.getElementType().isa<IntegerType>() && 426 type.getElementTypeBitWidth() == 1) { 427 return convertBoolMemrefType(targetEnv, options, type); 428 } 429 430 Optional<spirv::StorageClass> storageClass = 431 SPIRVTypeConverter::getStorageClassForMemorySpace( 432 type.getMemorySpaceAsInt()); 433 if (!storageClass) { 434 LLVM_DEBUG(llvm::dbgs() 435 << type << " illegal: cannot convert memory space\n"); 436 return nullptr; 437 } 438 439 Type arrayElemType; 440 Type elementType = type.getElementType(); 441 if (auto vecType = elementType.dyn_cast<VectorType>()) { 442 arrayElemType = 443 convertVectorType(targetEnv, options, vecType, storageClass); 444 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 445 arrayElemType = 446 convertScalarType(targetEnv, options, scalarType, storageClass); 447 } else { 448 LLVM_DEBUG( 449 llvm::dbgs() 450 << type 451 << " unhandled: can only convert scalar or vector element type\n"); 452 return nullptr; 453 } 454 if (!arrayElemType) 455 return nullptr; 456 457 Optional<int64_t> elementSize = getTypeNumBytes(options, elementType); 458 if (!elementSize) { 459 LLVM_DEBUG(llvm::dbgs() 460 << type << " illegal: cannot deduce element size\n"); 461 return nullptr; 462 } 463 464 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 465 if (!arrayElemSize) { 466 LLVM_DEBUG(llvm::dbgs() 467 << type << " illegal: cannot deduce converted element size\n"); 468 return nullptr; 469 } 470 471 if (!type.hasStaticShape()) { 472 auto arrayType = 473 spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); 474 return wrapInStructAndGetPointer(arrayType, *storageClass); 475 } 476 477 Optional<int64_t> memrefSize = getTypeNumBytes(options, type); 478 if (!memrefSize) { 479 LLVM_DEBUG(llvm::dbgs() 480 << type << " illegal: cannot deduce element count\n"); 481 return nullptr; 482 } 483 484 auto arrayElemCount = *memrefSize / *elementSize; 485 486 487 auto arrayType = 488 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 489 490 return wrapInStructAndGetPointer(arrayType, *storageClass); 491 } 492 493 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 494 Options options) 495 : targetEnv(targetAttr), options(options) { 496 // Add conversions. The order matters here: later ones will be tried earlier. 497 498 // Allow all SPIR-V dialect specific types. This assumes all builtin types 499 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 500 // were tried before. 501 // 502 // TODO: this assumes that the SPIR-V types are valid to use in 503 // the given target environment, which should be the case if the whole 504 // pipeline is driven by the same target environment. Still, we probably still 505 // want to validate and convert to be safe. 506 addConversion([](spirv::SPIRVType type) { return type; }); 507 508 addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); 509 510 addConversion([this](IntegerType intType) -> Optional<Type> { 511 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 512 return convertScalarType(this->targetEnv, this->options, scalarType); 513 return Type(); 514 }); 515 516 addConversion([this](FloatType floatType) -> Optional<Type> { 517 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 518 return convertScalarType(this->targetEnv, this->options, scalarType); 519 return Type(); 520 }); 521 522 addConversion([this](VectorType vectorType) { 523 return convertVectorType(this->targetEnv, this->options, vectorType); 524 }); 525 526 addConversion([this](TensorType tensorType) { 527 return convertTensorType(this->targetEnv, this->options, tensorType); 528 }); 529 530 addConversion([this](MemRefType memRefType) { 531 return convertMemrefType(this->targetEnv, this->options, memRefType); 532 }); 533 } 534 535 //===----------------------------------------------------------------------===// 536 // FuncOp Conversion Patterns 537 //===----------------------------------------------------------------------===// 538 539 namespace { 540 /// A pattern for rewriting function signature to convert arguments of functions 541 /// to be of valid SPIR-V types. 542 class FuncOpConversion final : public OpConversionPattern<FuncOp> { 543 public: 544 using OpConversionPattern<FuncOp>::OpConversionPattern; 545 546 LogicalResult 547 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 548 ConversionPatternRewriter &rewriter) const override; 549 }; 550 } // namespace 551 552 LogicalResult 553 FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 554 ConversionPatternRewriter &rewriter) const { 555 auto fnType = funcOp.getType(); 556 if (fnType.getNumResults() > 1) 557 return failure(); 558 559 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 560 for (auto argType : enumerate(fnType.getInputs())) { 561 auto convertedType = getTypeConverter()->convertType(argType.value()); 562 if (!convertedType) 563 return failure(); 564 signatureConverter.addInputs(argType.index(), convertedType); 565 } 566 567 Type resultType; 568 if (fnType.getNumResults() == 1) { 569 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 570 if (!resultType) 571 return failure(); 572 } 573 574 // Create the converted spv.func op. 575 auto newFuncOp = rewriter.create<spirv::FuncOp>( 576 funcOp.getLoc(), funcOp.getName(), 577 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 578 resultType ? TypeRange(resultType) 579 : TypeRange())); 580 581 // Copy over all attributes other than the function name and type. 582 for (const auto &namedAttr : funcOp->getAttrs()) { 583 if (namedAttr.getName() != function_like_impl::getTypeAttrName() && 584 namedAttr.getName() != SymbolTable::getSymbolAttrName()) 585 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); 586 } 587 588 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 589 newFuncOp.end()); 590 if (failed(rewriter.convertRegionTypes( 591 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 592 return failure(); 593 rewriter.eraseOp(funcOp); 594 return success(); 595 } 596 597 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 598 RewritePatternSet &patterns) { 599 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // Builtin Variables 604 //===----------------------------------------------------------------------===// 605 606 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 607 spirv::BuiltIn builtin) { 608 // Look through all global variables in the given `body` block and check if 609 // there is a spv.GlobalVariable that has the same `builtin` attribute. 610 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 611 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 612 spirv::SPIRVDialect::getAttributeName( 613 spirv::Decoration::BuiltIn))) { 614 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 615 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 616 return varOp; 617 } 618 } 619 } 620 return nullptr; 621 } 622 623 /// Gets name of global variable for a builtin. 624 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 625 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 626 } 627 628 /// Gets or inserts a global variable for a builtin within `body` block. 629 static spirv::GlobalVariableOp 630 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 631 Type integerType, OpBuilder &builder) { 632 if (auto varOp = getBuiltinVariable(body, builtin)) 633 return varOp; 634 635 OpBuilder::InsertionGuard guard(builder); 636 builder.setInsertionPointToStart(&body); 637 638 spirv::GlobalVariableOp newVarOp; 639 switch (builtin) { 640 case spirv::BuiltIn::NumWorkgroups: 641 case spirv::BuiltIn::WorkgroupSize: 642 case spirv::BuiltIn::WorkgroupId: 643 case spirv::BuiltIn::LocalInvocationId: 644 case spirv::BuiltIn::GlobalInvocationId: { 645 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), 646 spirv::StorageClass::Input); 647 std::string name = getBuiltinVarName(builtin); 648 newVarOp = 649 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 650 break; 651 } 652 case spirv::BuiltIn::SubgroupId: 653 case spirv::BuiltIn::NumSubgroups: 654 case spirv::BuiltIn::SubgroupSize: { 655 auto ptrType = 656 spirv::PointerType::get(integerType, spirv::StorageClass::Input); 657 std::string name = getBuiltinVarName(builtin); 658 newVarOp = 659 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 660 break; 661 } 662 default: 663 emitError(loc, "unimplemented builtin variable generation for ") 664 << stringifyBuiltIn(builtin); 665 } 666 return newVarOp; 667 } 668 669 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 670 spirv::BuiltIn builtin, 671 Type integerType, 672 OpBuilder &builder) { 673 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 674 if (!parent) { 675 op->emitError("expected operation to be within a module-like op"); 676 return nullptr; 677 } 678 679 spirv::GlobalVariableOp varOp = 680 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), 681 builtin, integerType, builder); 682 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 683 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // Push constant storage 688 //===----------------------------------------------------------------------===// 689 690 /// Returns the pointer type for the push constant storage containing 691 /// `elementCount` 32-bit integer values. 692 static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 693 Builder &builder, 694 Type indexType) { 695 auto arrayType = spirv::ArrayType::get(indexType, elementCount, 696 /*stride=*/4); 697 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 698 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 699 } 700 701 /// Returns the push constant varible containing `elementCount` 32-bit integer 702 /// values in `body`. Returns null op if such an op does not exit. 703 static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 704 unsigned elementCount) { 705 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 706 auto ptrType = varOp.type().dyn_cast<spirv::PointerType>(); 707 if (!ptrType) 708 continue; 709 710 // Note that Vulkan requires "There must be no more than one push constant 711 // block statically used per shader entry point." So we should always reuse 712 // the existing one. 713 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 714 auto numElements = ptrType.getPointeeType() 715 .cast<spirv::StructType>() 716 .getElementType(0) 717 .cast<spirv::ArrayType>() 718 .getNumElements(); 719 if (numElements == elementCount) 720 return varOp; 721 } 722 } 723 return nullptr; 724 } 725 726 /// Gets or inserts a global variable for push constant storage containing 727 /// `elementCount` 32-bit integer values in `block`. 728 static spirv::GlobalVariableOp 729 getOrInsertPushConstantVariable(Location loc, Block &block, 730 unsigned elementCount, OpBuilder &b, 731 Type indexType) { 732 if (auto varOp = getPushConstantVariable(block, elementCount)) 733 return varOp; 734 735 auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 736 auto type = getPushConstantStorageType(elementCount, builder, indexType); 737 const char *name = "__push_constant_var__"; 738 return builder.create<spirv::GlobalVariableOp>(loc, type, name, 739 /*initializer=*/nullptr); 740 } 741 742 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, 743 unsigned offset, Type integerType, 744 OpBuilder &builder) { 745 Location loc = op->getLoc(); 746 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 747 if (!parent) { 748 op->emitError("expected operation to be within a module-like op"); 749 return nullptr; 750 } 751 752 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 753 loc, parent->getRegion(0).front(), elementCount, builder, integerType); 754 755 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); 756 Value offsetOp = builder.create<spirv::ConstantOp>( 757 loc, integerType, builder.getI32IntegerAttr(offset)); 758 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 759 auto acOp = builder.create<spirv::AccessChainOp>( 760 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); 761 return builder.create<spirv::LoadOp>(loc, acOp); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // Index calculation 766 //===----------------------------------------------------------------------===// 767 768 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 769 int64_t offset, Type integerType, 770 Location loc, OpBuilder &builder) { 771 assert(indices.size() == strides.size() && 772 "must provide indices for all dimensions"); 773 774 // TODO: Consider moving to use affine.apply and patterns converting 775 // affine.apply to standard ops. This needs converting to SPIR-V passes to be 776 // broken down into progressive small steps so we can have intermediate steps 777 // using other dialects. At the moment SPIR-V is the final sink. 778 779 Value linearizedIndex = builder.create<spirv::ConstantOp>( 780 loc, integerType, IntegerAttr::get(integerType, offset)); 781 for (auto index : llvm::enumerate(indices)) { 782 Value strideVal = builder.create<spirv::ConstantOp>( 783 loc, integerType, 784 IntegerAttr::get(integerType, strides[index.index()])); 785 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 786 linearizedIndex = 787 builder.create<spirv::IAddOp>(loc, linearizedIndex, update); 788 } 789 return linearizedIndex; 790 } 791 792 spirv::AccessChainOp mlir::spirv::getElementPtr( 793 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 794 ValueRange indices, Location loc, OpBuilder &builder) { 795 // Get base and offset of the MemRefType and verify they are static. 796 797 int64_t offset; 798 SmallVector<int64_t, 4> strides; 799 if (failed(getStridesAndOffset(baseType, strides, offset)) || 800 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 801 offset == MemRefType::getDynamicStrideOrOffset()) { 802 return nullptr; 803 } 804 805 auto indexType = typeConverter.getIndexType(); 806 807 SmallVector<Value, 2> linearizedIndices; 808 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 809 810 // Add a '0' at the start to index into the struct. 811 linearizedIndices.push_back(zero); 812 813 if (baseType.getRank() == 0) { 814 linearizedIndices.push_back(zero); 815 } else { 816 linearizedIndices.push_back( 817 linearizeIndex(indices, strides, offset, indexType, loc, builder)); 818 } 819 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 820 } 821 822 //===----------------------------------------------------------------------===// 823 // SPIR-V ConversionTarget 824 //===----------------------------------------------------------------------===// 825 826 std::unique_ptr<SPIRVConversionTarget> 827 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 828 std::unique_ptr<SPIRVConversionTarget> target( 829 // std::make_unique does not work here because the constructor is private. 830 new SPIRVConversionTarget(targetAttr)); 831 SPIRVConversionTarget *targetPtr = target.get(); 832 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 833 // We need to capture the raw pointer here because it is stable: 834 // target will be destroyed once this function is returned. 835 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 836 return target; 837 } 838 839 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 840 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 841 842 bool SPIRVConversionTarget::isLegalOp(Operation *op) { 843 // Make sure this op is available at the given version. Ops not implementing 844 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 845 // SPIR-V versions. 846 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 847 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); 848 if (minVersion && *minVersion > this->targetEnv.getVersion()) { 849 LLVM_DEBUG(llvm::dbgs() 850 << op->getName() << " illegal: requiring min version " 851 << spirv::stringifyVersion(*minVersion) << "\n"); 852 return false; 853 } 854 } 855 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { 856 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); 857 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { 858 LLVM_DEBUG(llvm::dbgs() 859 << op->getName() << " illegal: requiring max version " 860 << spirv::stringifyVersion(*maxVersion) << "\n"); 861 return false; 862 } 863 } 864 865 // Make sure this op's required extensions are allowed to use. Ops not 866 // implementing QueryExtensionInterface do not require extensions to be 867 // available. 868 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 869 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 870 extensions.getExtensions()))) 871 return false; 872 873 // Make sure this op's required extensions are allowed to use. Ops not 874 // implementing QueryCapabilityInterface do not require capabilities to be 875 // available. 876 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 877 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 878 capabilities.getCapabilities()))) 879 return false; 880 881 SmallVector<Type, 4> valueTypes; 882 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 883 valueTypes.append(op->result_type_begin(), op->result_type_end()); 884 885 // Ensure that all types have been converted to SPIRV types. 886 if (llvm::any_of(valueTypes, 887 [](Type t) { return !t.isa<spirv::SPIRVType>(); })) 888 return false; 889 890 // Special treatment for global variables, whose type requirements are 891 // conveyed by type attributes. 892 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 893 valueTypes.push_back(globalVar.type()); 894 895 // Make sure the op's operands/results use types that are allowed by the 896 // target environment. 897 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 898 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 899 for (Type valueType : valueTypes) { 900 typeExtensions.clear(); 901 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 902 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 903 typeExtensions))) 904 return false; 905 906 typeCapabilities.clear(); 907 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 908 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 909 typeCapabilities))) 910 return false; 911 } 912 913 return true; 914 } 915