1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/Support/Debug.h" 21 22 #include <functional> 23 24 #define DEBUG_TYPE "mlir-spirv-conversion" 25 26 using namespace mlir; 27 28 //===----------------------------------------------------------------------===// 29 // Utility functions 30 //===----------------------------------------------------------------------===// 31 32 /// Checks that `candidates` extension requirements are possible to be satisfied 33 /// with the given `targetEnv`. 34 /// 35 /// `candidates` is a vector of vector for extension requirements following 36 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 37 /// convention. 38 template <typename LabelT> 39 static LogicalResult checkExtensionRequirements( 40 LabelT label, const spirv::TargetEnv &targetEnv, 41 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 42 for (const auto &ors : candidates) { 43 if (targetEnv.allows(ors)) 44 continue; 45 46 LLVM_DEBUG({ 47 SmallVector<StringRef> extStrings; 48 for (spirv::Extension ext : ors) 49 extStrings.push_back(spirv::stringifyExtension(ext)); 50 51 llvm::dbgs() << label << " illegal: requires at least one extension in [" 52 << llvm::join(extStrings, ", ") 53 << "] but none allowed in target environment\n"; 54 }); 55 return failure(); 56 } 57 return success(); 58 } 59 60 /// Checks that `candidates`capability requirements are possible to be satisfied 61 /// with the given `isAllowedFn`. 62 /// 63 /// `candidates` is a vector of vector for capability requirements following 64 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 65 /// convention. 66 template <typename LabelT> 67 static LogicalResult checkCapabilityRequirements( 68 LabelT label, const spirv::TargetEnv &targetEnv, 69 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 70 for (const auto &ors : candidates) { 71 if (targetEnv.allows(ors)) 72 continue; 73 74 LLVM_DEBUG({ 75 SmallVector<StringRef> capStrings; 76 for (spirv::Capability cap : ors) 77 capStrings.push_back(spirv::stringifyCapability(cap)); 78 79 llvm::dbgs() << label << " illegal: requires at least one capability in [" 80 << llvm::join(capStrings, ", ") 81 << "] but none allowed in target environment\n"; 82 }); 83 return failure(); 84 } 85 return success(); 86 } 87 88 /// Returns true if the given `storageClass` needs explicit layout when used in 89 /// Shader environments. 90 static bool needsExplicitLayout(spirv::StorageClass storageClass) { 91 switch (storageClass) { 92 case spirv::StorageClass::PhysicalStorageBuffer: 93 case spirv::StorageClass::PushConstant: 94 case spirv::StorageClass::StorageBuffer: 95 case spirv::StorageClass::Uniform: 96 return true; 97 default: 98 return false; 99 } 100 } 101 102 /// Wraps the given `elementType` in a struct and gets the pointer to the 103 /// struct. This is used to satisfy Vulkan interface requirements. 104 static spirv::PointerType 105 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { 106 auto structType = needsExplicitLayout(storageClass) 107 ? spirv::StructType::get(elementType, /*offsetInfo=*/0) 108 : spirv::StructType::get(elementType); 109 return spirv::PointerType::get(structType, storageClass); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // Type Conversion 114 //===----------------------------------------------------------------------===// 115 116 Type SPIRVTypeConverter::getIndexType() const { 117 return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); 118 } 119 120 /// Mapping between SPIR-V storage classes to memref memory spaces. 121 /// 122 /// Note: memref does not have a defined semantics for each memory space; it 123 /// depends on the context where it is used. There are no particular reasons 124 /// behind the number assignments; we try to follow NVVM conventions and largely 125 /// give common storage classes a smaller number. The hope is use symbolic 126 /// memory space representation eventually after memref supports it. 127 // TODO: swap Generic and StorageBuffer assignment to be more akin 128 // to NVVM. 129 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 130 MAP_FN(spirv::StorageClass::Generic, 1) \ 131 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 132 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 133 MAP_FN(spirv::StorageClass::Uniform, 4) \ 134 MAP_FN(spirv::StorageClass::Private, 5) \ 135 MAP_FN(spirv::StorageClass::Function, 6) \ 136 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 137 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 138 MAP_FN(spirv::StorageClass::Input, 9) \ 139 MAP_FN(spirv::StorageClass::Output, 10) \ 140 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 141 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 142 MAP_FN(spirv::StorageClass::Image, 13) \ 143 MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \ 144 MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \ 145 MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \ 146 MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \ 147 MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \ 148 MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \ 149 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \ 150 MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \ 151 MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \ 152 MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23) 153 154 unsigned 155 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 156 #define STORAGE_SPACE_MAP_FN(storage, space) \ 157 case storage: \ 158 return space; 159 160 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 161 #undef STORAGE_SPACE_MAP_FN 162 llvm_unreachable("unhandled storage class!"); 163 } 164 165 Optional<spirv::StorageClass> 166 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 167 #define STORAGE_SPACE_MAP_FN(storage, space) \ 168 case space: \ 169 return storage; 170 171 switch (space) { 172 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 173 default: 174 return llvm::None; 175 } 176 #undef STORAGE_SPACE_MAP_FN 177 } 178 179 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { 180 return options; 181 } 182 183 MLIRContext *SPIRVTypeConverter::getContext() const { 184 return targetEnv.getAttr().getContext(); 185 } 186 187 #undef STORAGE_SPACE_MAP_LIST 188 189 // TODO: This is a utility function that should probably be exposed by the 190 // SPIR-V dialect. Keeping it local till the use case arises. 191 static Optional<int64_t> 192 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) { 193 if (type.isa<spirv::ScalarType>()) { 194 auto bitWidth = type.getIntOrFloatBitWidth(); 195 // According to the SPIR-V spec: 196 // "There is no physical size or bit pattern defined for values with boolean 197 // type. If they are stored (in conjunction with OpVariable), they can only 198 // be used with logical addressing operations, not physical, and only with 199 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 200 // Private, Function, Input, and Output." 201 if (bitWidth == 1) 202 return llvm::None; 203 return bitWidth / 8; 204 } 205 206 if (auto vecType = type.dyn_cast<VectorType>()) { 207 auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 208 if (!elementSize) 209 return llvm::None; 210 return vecType.getNumElements() * elementSize.getValue(); 211 } 212 213 if (auto memRefType = type.dyn_cast<MemRefType>()) { 214 // TODO: Layout should also be controlled by the ABI attributes. For now 215 // using the layout from MemRef. 216 int64_t offset; 217 SmallVector<int64_t, 4> strides; 218 if (!memRefType.hasStaticShape() || 219 failed(getStridesAndOffset(memRefType, strides, offset))) 220 return llvm::None; 221 222 // To get the size of the memref object in memory, the total size is the 223 // max(stride * dimension-size) computed for all dimensions times the size 224 // of the element. 225 auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 226 if (!elementSize) 227 return llvm::None; 228 229 if (memRefType.getRank() == 0) 230 return elementSize; 231 232 auto dims = memRefType.getShape(); 233 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 234 offset == MemRefType::getDynamicStrideOrOffset() || 235 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) 236 return llvm::None; 237 238 int64_t memrefSize = -1; 239 for (const auto &shape : enumerate(dims)) 240 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 241 242 return (offset + memrefSize) * elementSize.getValue(); 243 } 244 245 if (auto tensorType = type.dyn_cast<TensorType>()) { 246 if (!tensorType.hasStaticShape()) 247 return llvm::None; 248 249 auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 250 if (!elementSize) 251 return llvm::None; 252 253 int64_t size = elementSize.getValue(); 254 for (auto shape : tensorType.getShape()) 255 size *= shape; 256 257 return size; 258 } 259 260 // TODO: Add size computation for other types. 261 return llvm::None; 262 } 263 264 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 265 static Type convertScalarType(const spirv::TargetEnv &targetEnv, 266 const SPIRVTypeConverter::Options &options, 267 spirv::ScalarType type, 268 Optional<spirv::StorageClass> storageClass = {}) { 269 // Get extension and capability requirements for the given type. 270 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 271 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 272 type.getExtensions(extensions, storageClass); 273 type.getCapabilities(capabilities, storageClass); 274 275 // If all requirements are met, then we can accept this type as-is. 276 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 277 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 278 return type; 279 280 // Otherwise we need to adjust the type, which really means adjusting the 281 // bitwidth given this is a scalar type. 282 283 if (!options.emulateNon32BitScalarTypes) 284 return nullptr; 285 286 if (auto floatType = type.dyn_cast<FloatType>()) { 287 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 288 return Builder(targetEnv.getContext()).getF32Type(); 289 } 290 291 auto intType = type.cast<IntegerType>(); 292 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 293 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 294 intType.getSignedness()); 295 } 296 297 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 298 static Type convertVectorType(const spirv::TargetEnv &targetEnv, 299 const SPIRVTypeConverter::Options &options, 300 VectorType type, 301 Optional<spirv::StorageClass> storageClass = {}) { 302 if (type.getRank() == 1 && type.getNumElements() == 1) 303 return type.getElementType(); 304 305 if (!spirv::CompositeType::isValid(type)) { 306 // TODO: Vector types with more than four elements can be translated into 307 // array types. 308 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); 309 return nullptr; 310 } 311 312 // Get extension and capability requirements for the given type. 313 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 314 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 315 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 316 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 317 318 // If all requirements are met, then we can accept this type as-is. 319 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 320 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 321 return type; 322 323 auto elementType = convertScalarType( 324 targetEnv, options, type.getElementType().cast<spirv::ScalarType>(), 325 storageClass); 326 if (elementType) 327 return VectorType::get(type.getShape(), elementType); 328 return nullptr; 329 } 330 331 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 332 /// 333 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can 334 /// create composite constants with OpConstantComposite to embed relative large 335 /// constant values and use OpCompositeExtract and OpCompositeInsert to 336 /// manipulate, like what we do for vectors. 337 static Type convertTensorType(const spirv::TargetEnv &targetEnv, 338 const SPIRVTypeConverter::Options &options, 339 TensorType type) { 340 // TODO: Handle dynamic shapes. 341 if (!type.hasStaticShape()) { 342 LLVM_DEBUG(llvm::dbgs() 343 << type << " illegal: dynamic shape unimplemented\n"); 344 return nullptr; 345 } 346 347 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 348 if (!scalarType) { 349 LLVM_DEBUG(llvm::dbgs() 350 << type << " illegal: cannot convert non-scalar element type\n"); 351 return nullptr; 352 } 353 354 Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType); 355 Optional<int64_t> tensorSize = getTypeNumBytes(options, type); 356 if (!scalarSize || !tensorSize) { 357 LLVM_DEBUG(llvm::dbgs() 358 << type << " illegal: cannot deduce element count\n"); 359 return nullptr; 360 } 361 362 auto arrayElemCount = *tensorSize / *scalarSize; 363 auto arrayElemType = convertScalarType(targetEnv, options, scalarType); 364 if (!arrayElemType) 365 return nullptr; 366 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 367 if (!arrayElemSize) { 368 LLVM_DEBUG(llvm::dbgs() 369 << type << " illegal: cannot deduce converted element size\n"); 370 return nullptr; 371 } 372 373 return spirv::ArrayType::get(arrayElemType, arrayElemCount); 374 } 375 376 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, 377 const SPIRVTypeConverter::Options &options, 378 MemRefType type) { 379 Optional<spirv::StorageClass> storageClass = 380 SPIRVTypeConverter::getStorageClassForMemorySpace( 381 type.getMemorySpaceAsInt()); 382 if (!storageClass) { 383 LLVM_DEBUG(llvm::dbgs() 384 << type << " illegal: cannot convert memory space\n"); 385 return nullptr; 386 } 387 388 unsigned numBoolBits = options.boolNumBits; 389 if (numBoolBits != 8) { 390 LLVM_DEBUG(llvm::dbgs() 391 << "using non-8-bit storage for bool types unimplemented"); 392 return nullptr; 393 } 394 auto elementType = IntegerType::get(type.getContext(), numBoolBits) 395 .dyn_cast<spirv::ScalarType>(); 396 if (!elementType) 397 return nullptr; 398 Type arrayElemType = 399 convertScalarType(targetEnv, options, elementType, storageClass); 400 if (!arrayElemType) 401 return nullptr; 402 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 403 if (!arrayElemSize) { 404 LLVM_DEBUG(llvm::dbgs() 405 << type << " illegal: cannot deduce converted element size\n"); 406 return nullptr; 407 } 408 409 if (!type.hasStaticShape()) { 410 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; 411 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); 412 return wrapInStructAndGetPointer(arrayType, *storageClass); 413 } 414 415 int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; 416 auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; 417 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; 418 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); 419 420 return wrapInStructAndGetPointer(arrayType, *storageClass); 421 } 422 423 static Type convertMemrefType(const spirv::TargetEnv &targetEnv, 424 const SPIRVTypeConverter::Options &options, 425 MemRefType type) { 426 if (type.getElementType().isa<IntegerType>() && 427 type.getElementTypeBitWidth() == 1) { 428 return convertBoolMemrefType(targetEnv, options, type); 429 } 430 431 Optional<spirv::StorageClass> storageClass = 432 SPIRVTypeConverter::getStorageClassForMemorySpace( 433 type.getMemorySpaceAsInt()); 434 if (!storageClass) { 435 LLVM_DEBUG(llvm::dbgs() 436 << type << " illegal: cannot convert memory space\n"); 437 return nullptr; 438 } 439 440 Type arrayElemType; 441 Type elementType = type.getElementType(); 442 if (auto vecType = elementType.dyn_cast<VectorType>()) { 443 arrayElemType = 444 convertVectorType(targetEnv, options, vecType, storageClass); 445 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 446 arrayElemType = 447 convertScalarType(targetEnv, options, scalarType, storageClass); 448 } else { 449 LLVM_DEBUG( 450 llvm::dbgs() 451 << type 452 << " unhandled: can only convert scalar or vector element type\n"); 453 return nullptr; 454 } 455 if (!arrayElemType) 456 return nullptr; 457 458 Optional<int64_t> elementSize = getTypeNumBytes(options, elementType); 459 if (!elementSize) { 460 LLVM_DEBUG(llvm::dbgs() 461 << type << " illegal: cannot deduce element size\n"); 462 return nullptr; 463 } 464 465 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 466 if (!arrayElemSize) { 467 LLVM_DEBUG(llvm::dbgs() 468 << type << " illegal: cannot deduce converted element size\n"); 469 return nullptr; 470 } 471 472 if (!type.hasStaticShape()) { 473 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; 474 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); 475 return wrapInStructAndGetPointer(arrayType, *storageClass); 476 } 477 478 Optional<int64_t> memrefSize = getTypeNumBytes(options, type); 479 if (!memrefSize) { 480 LLVM_DEBUG(llvm::dbgs() 481 << type << " illegal: cannot deduce element count\n"); 482 return nullptr; 483 } 484 485 auto arrayElemCount = *memrefSize / *elementSize; 486 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; 487 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); 488 489 return wrapInStructAndGetPointer(arrayType, *storageClass); 490 } 491 492 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 493 Options options) 494 : targetEnv(targetAttr), options(options) { 495 // Add conversions. The order matters here: later ones will be tried earlier. 496 497 // Allow all SPIR-V dialect specific types. This assumes all builtin types 498 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 499 // were tried before. 500 // 501 // TODO: this assumes that the SPIR-V types are valid to use in 502 // the given target environment, which should be the case if the whole 503 // pipeline is driven by the same target environment. Still, we probably still 504 // want to validate and convert to be safe. 505 addConversion([](spirv::SPIRVType type) { return type; }); 506 507 addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); 508 509 addConversion([this](IntegerType intType) -> Optional<Type> { 510 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 511 return convertScalarType(this->targetEnv, this->options, scalarType); 512 return Type(); 513 }); 514 515 addConversion([this](FloatType floatType) -> Optional<Type> { 516 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 517 return convertScalarType(this->targetEnv, this->options, scalarType); 518 return Type(); 519 }); 520 521 addConversion([this](VectorType vectorType) { 522 return convertVectorType(this->targetEnv, this->options, vectorType); 523 }); 524 525 addConversion([this](TensorType tensorType) { 526 return convertTensorType(this->targetEnv, this->options, tensorType); 527 }); 528 529 addConversion([this](MemRefType memRefType) { 530 return convertMemrefType(this->targetEnv, this->options, memRefType); 531 }); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // func::FuncOp Conversion Patterns 536 //===----------------------------------------------------------------------===// 537 538 namespace { 539 /// A pattern for rewriting function signature to convert arguments of functions 540 /// to be of valid SPIR-V types. 541 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> { 542 public: 543 using OpConversionPattern<func::FuncOp>::OpConversionPattern; 544 545 LogicalResult 546 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 547 ConversionPatternRewriter &rewriter) const override; 548 }; 549 } // namespace 550 551 LogicalResult 552 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 553 ConversionPatternRewriter &rewriter) const { 554 auto fnType = funcOp.getFunctionType(); 555 if (fnType.getNumResults() > 1) 556 return failure(); 557 558 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 559 for (const auto &argType : enumerate(fnType.getInputs())) { 560 auto convertedType = getTypeConverter()->convertType(argType.value()); 561 if (!convertedType) 562 return failure(); 563 signatureConverter.addInputs(argType.index(), convertedType); 564 } 565 566 Type resultType; 567 if (fnType.getNumResults() == 1) { 568 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 569 if (!resultType) 570 return failure(); 571 } 572 573 // Create the converted spv.func op. 574 auto newFuncOp = rewriter.create<spirv::FuncOp>( 575 funcOp.getLoc(), funcOp.getName(), 576 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 577 resultType ? TypeRange(resultType) 578 : TypeRange())); 579 580 // Copy over all attributes other than the function name and type. 581 for (const auto &namedAttr : funcOp->getAttrs()) { 582 if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && 583 namedAttr.getName() != SymbolTable::getSymbolAttrName()) 584 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); 585 } 586 587 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 588 newFuncOp.end()); 589 if (failed(rewriter.convertRegionTypes( 590 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 591 return failure(); 592 rewriter.eraseOp(funcOp); 593 return success(); 594 } 595 596 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 597 RewritePatternSet &patterns) { 598 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 599 } 600 601 //===----------------------------------------------------------------------===// 602 // Builtin Variables 603 //===----------------------------------------------------------------------===// 604 605 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 606 spirv::BuiltIn builtin) { 607 // Look through all global variables in the given `body` block and check if 608 // there is a spv.GlobalVariable that has the same `builtin` attribute. 609 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 610 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 611 spirv::SPIRVDialect::getAttributeName( 612 spirv::Decoration::BuiltIn))) { 613 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 614 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 615 return varOp; 616 } 617 } 618 } 619 return nullptr; 620 } 621 622 /// Gets name of global variable for a builtin. 623 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 624 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 625 } 626 627 /// Gets or inserts a global variable for a builtin within `body` block. 628 static spirv::GlobalVariableOp 629 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 630 Type integerType, OpBuilder &builder) { 631 if (auto varOp = getBuiltinVariable(body, builtin)) 632 return varOp; 633 634 OpBuilder::InsertionGuard guard(builder); 635 builder.setInsertionPointToStart(&body); 636 637 spirv::GlobalVariableOp newVarOp; 638 switch (builtin) { 639 case spirv::BuiltIn::NumWorkgroups: 640 case spirv::BuiltIn::WorkgroupSize: 641 case spirv::BuiltIn::WorkgroupId: 642 case spirv::BuiltIn::LocalInvocationId: 643 case spirv::BuiltIn::GlobalInvocationId: { 644 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), 645 spirv::StorageClass::Input); 646 std::string name = getBuiltinVarName(builtin); 647 newVarOp = 648 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 649 break; 650 } 651 case spirv::BuiltIn::SubgroupId: 652 case spirv::BuiltIn::NumSubgroups: 653 case spirv::BuiltIn::SubgroupSize: { 654 auto ptrType = 655 spirv::PointerType::get(integerType, spirv::StorageClass::Input); 656 std::string name = getBuiltinVarName(builtin); 657 newVarOp = 658 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 659 break; 660 } 661 default: 662 emitError(loc, "unimplemented builtin variable generation for ") 663 << stringifyBuiltIn(builtin); 664 } 665 return newVarOp; 666 } 667 668 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 669 spirv::BuiltIn builtin, 670 Type integerType, 671 OpBuilder &builder) { 672 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 673 if (!parent) { 674 op->emitError("expected operation to be within a module-like op"); 675 return nullptr; 676 } 677 678 spirv::GlobalVariableOp varOp = 679 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), 680 builtin, integerType, builder); 681 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 682 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // Push constant storage 687 //===----------------------------------------------------------------------===// 688 689 /// Returns the pointer type for the push constant storage containing 690 /// `elementCount` 32-bit integer values. 691 static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 692 Builder &builder, 693 Type indexType) { 694 auto arrayType = spirv::ArrayType::get(indexType, elementCount, 695 /*stride=*/4); 696 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 697 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 698 } 699 700 /// Returns the push constant varible containing `elementCount` 32-bit integer 701 /// values in `body`. Returns null op if such an op does not exit. 702 static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 703 unsigned elementCount) { 704 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 705 auto ptrType = varOp.type().dyn_cast<spirv::PointerType>(); 706 if (!ptrType) 707 continue; 708 709 // Note that Vulkan requires "There must be no more than one push constant 710 // block statically used per shader entry point." So we should always reuse 711 // the existing one. 712 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 713 auto numElements = ptrType.getPointeeType() 714 .cast<spirv::StructType>() 715 .getElementType(0) 716 .cast<spirv::ArrayType>() 717 .getNumElements(); 718 if (numElements == elementCount) 719 return varOp; 720 } 721 } 722 return nullptr; 723 } 724 725 /// Gets or inserts a global variable for push constant storage containing 726 /// `elementCount` 32-bit integer values in `block`. 727 static spirv::GlobalVariableOp 728 getOrInsertPushConstantVariable(Location loc, Block &block, 729 unsigned elementCount, OpBuilder &b, 730 Type indexType) { 731 if (auto varOp = getPushConstantVariable(block, elementCount)) 732 return varOp; 733 734 auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 735 auto type = getPushConstantStorageType(elementCount, builder, indexType); 736 const char *name = "__push_constant_var__"; 737 return builder.create<spirv::GlobalVariableOp>(loc, type, name, 738 /*initializer=*/nullptr); 739 } 740 741 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, 742 unsigned offset, Type integerType, 743 OpBuilder &builder) { 744 Location loc = op->getLoc(); 745 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 746 if (!parent) { 747 op->emitError("expected operation to be within a module-like op"); 748 return nullptr; 749 } 750 751 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 752 loc, parent->getRegion(0).front(), elementCount, builder, integerType); 753 754 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); 755 Value offsetOp = builder.create<spirv::ConstantOp>( 756 loc, integerType, builder.getI32IntegerAttr(offset)); 757 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 758 auto acOp = builder.create<spirv::AccessChainOp>( 759 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); 760 return builder.create<spirv::LoadOp>(loc, acOp); 761 } 762 763 //===----------------------------------------------------------------------===// 764 // Index calculation 765 //===----------------------------------------------------------------------===// 766 767 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 768 int64_t offset, Type integerType, 769 Location loc, OpBuilder &builder) { 770 assert(indices.size() == strides.size() && 771 "must provide indices for all dimensions"); 772 773 // TODO: Consider moving to use affine.apply and patterns converting 774 // affine.apply to standard ops. This needs converting to SPIR-V passes to be 775 // broken down into progressive small steps so we can have intermediate steps 776 // using other dialects. At the moment SPIR-V is the final sink. 777 778 Value linearizedIndex = builder.create<spirv::ConstantOp>( 779 loc, integerType, IntegerAttr::get(integerType, offset)); 780 for (const auto &index : llvm::enumerate(indices)) { 781 Value strideVal = builder.create<spirv::ConstantOp>( 782 loc, integerType, 783 IntegerAttr::get(integerType, strides[index.index()])); 784 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 785 linearizedIndex = 786 builder.create<spirv::IAddOp>(loc, linearizedIndex, update); 787 } 788 return linearizedIndex; 789 } 790 791 spirv::AccessChainOp mlir::spirv::getElementPtr( 792 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 793 ValueRange indices, Location loc, OpBuilder &builder) { 794 // Get base and offset of the MemRefType and verify they are static. 795 796 int64_t offset; 797 SmallVector<int64_t, 4> strides; 798 if (failed(getStridesAndOffset(baseType, strides, offset)) || 799 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 800 offset == MemRefType::getDynamicStrideOrOffset()) { 801 return nullptr; 802 } 803 804 auto indexType = typeConverter.getIndexType(); 805 806 SmallVector<Value, 2> linearizedIndices; 807 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 808 809 // Add a '0' at the start to index into the struct. 810 linearizedIndices.push_back(zero); 811 812 if (baseType.getRank() == 0) { 813 linearizedIndices.push_back(zero); 814 } else { 815 linearizedIndices.push_back( 816 linearizeIndex(indices, strides, offset, indexType, loc, builder)); 817 } 818 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 819 } 820 821 //===----------------------------------------------------------------------===// 822 // SPIR-V ConversionTarget 823 //===----------------------------------------------------------------------===// 824 825 std::unique_ptr<SPIRVConversionTarget> 826 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 827 std::unique_ptr<SPIRVConversionTarget> target( 828 // std::make_unique does not work here because the constructor is private. 829 new SPIRVConversionTarget(targetAttr)); 830 SPIRVConversionTarget *targetPtr = target.get(); 831 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 832 // We need to capture the raw pointer here because it is stable: 833 // target will be destroyed once this function is returned. 834 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 835 return target; 836 } 837 838 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 839 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 840 841 bool SPIRVConversionTarget::isLegalOp(Operation *op) { 842 // Make sure this op is available at the given version. Ops not implementing 843 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 844 // SPIR-V versions. 845 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 846 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); 847 if (minVersion && *minVersion > this->targetEnv.getVersion()) { 848 LLVM_DEBUG(llvm::dbgs() 849 << op->getName() << " illegal: requiring min version " 850 << spirv::stringifyVersion(*minVersion) << "\n"); 851 return false; 852 } 853 } 854 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { 855 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); 856 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { 857 LLVM_DEBUG(llvm::dbgs() 858 << op->getName() << " illegal: requiring max version " 859 << spirv::stringifyVersion(*maxVersion) << "\n"); 860 return false; 861 } 862 } 863 864 // Make sure this op's required extensions are allowed to use. Ops not 865 // implementing QueryExtensionInterface do not require extensions to be 866 // available. 867 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 868 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 869 extensions.getExtensions()))) 870 return false; 871 872 // Make sure this op's required extensions are allowed to use. Ops not 873 // implementing QueryCapabilityInterface do not require capabilities to be 874 // available. 875 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 876 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 877 capabilities.getCapabilities()))) 878 return false; 879 880 SmallVector<Type, 4> valueTypes; 881 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 882 valueTypes.append(op->result_type_begin(), op->result_type_end()); 883 884 // Ensure that all types have been converted to SPIRV types. 885 if (llvm::any_of(valueTypes, 886 [](Type t) { return !t.isa<spirv::SPIRVType>(); })) 887 return false; 888 889 // Special treatment for global variables, whose type requirements are 890 // conveyed by type attributes. 891 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 892 valueTypes.push_back(globalVar.type()); 893 894 // Make sure the op's operands/results use types that are allowed by the 895 // target environment. 896 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 897 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 898 for (Type valueType : valueTypes) { 899 typeExtensions.clear(); 900 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 901 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 902 typeExtensions))) 903 return false; 904 905 typeCapabilities.clear(); 906 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 907 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 908 typeCapabilities))) 909 return false; 910 } 911 912 return true; 913 } 914