1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/Debug.h" 20 21 #include <functional> 22 23 #define DEBUG_TYPE "mlir-spirv-conversion" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Utility functions 29 //===----------------------------------------------------------------------===// 30 31 /// Checks that `candidates` extension requirements are possible to be satisfied 32 /// with the given `targetEnv`. 33 /// 34 /// `candidates` is a vector of vector for extension requirements following 35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 36 /// convention. 37 template <typename LabelT> 38 static LogicalResult checkExtensionRequirements( 39 LabelT label, const spirv::TargetEnv &targetEnv, 40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 41 for (const auto &ors : candidates) { 42 if (targetEnv.allows(ors)) 43 continue; 44 45 SmallVector<StringRef, 4> extStrings; 46 for (spirv::Extension ext : ors) 47 extStrings.push_back(spirv::stringifyExtension(ext)); 48 49 LLVM_DEBUG(llvm::dbgs() 50 << label << " illegal: requires at least one extension in [" 51 << llvm::join(extStrings, ", ") 52 << "] but none allowed in target environment\n"); 53 return failure(); 54 } 55 return success(); 56 } 57 58 /// Checks that `candidates`capability requirements are possible to be satisfied 59 /// with the given `isAllowedFn`. 60 /// 61 /// `candidates` is a vector of vector for capability requirements following 62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 63 /// convention. 64 template <typename LabelT> 65 static LogicalResult checkCapabilityRequirements( 66 LabelT label, const spirv::TargetEnv &targetEnv, 67 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 68 for (const auto &ors : candidates) { 69 if (targetEnv.allows(ors)) 70 continue; 71 72 SmallVector<StringRef, 4> capStrings; 73 for (spirv::Capability cap : ors) 74 capStrings.push_back(spirv::stringifyCapability(cap)); 75 76 LLVM_DEBUG(llvm::dbgs() 77 << label << " illegal: requires at least one capability in [" 78 << llvm::join(capStrings, ", ") 79 << "] but none allowed in target environment\n"); 80 return failure(); 81 } 82 return success(); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Type Conversion 87 //===----------------------------------------------------------------------===// 88 89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { 90 // Convert to 32-bit integers for now. Might need a way to control this in 91 // future. 92 // TODO: It is probably better to make it 64-bit integers. To 93 // this some support is needed in SPIR-V dialect for Conversion 94 // instructions. The Vulkan spec requires the builtins like 95 // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be 96 // SExtended to 64-bit for index computations. 97 return IntegerType::get(context, 32); 98 } 99 100 /// Mapping between SPIR-V storage classes to memref memory spaces. 101 /// 102 /// Note: memref does not have a defined semantics for each memory space; it 103 /// depends on the context where it is used. There are no particular reasons 104 /// behind the number assignments; we try to follow NVVM conventions and largely 105 /// give common storage classes a smaller number. The hope is use symbolic 106 /// memory space representation eventually after memref supports it. 107 // TODO: swap Generic and StorageBuffer assignment to be more akin 108 // to NVVM. 109 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 110 MAP_FN(spirv::StorageClass::Generic, 1) \ 111 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 112 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 113 MAP_FN(spirv::StorageClass::Uniform, 4) \ 114 MAP_FN(spirv::StorageClass::Private, 5) \ 115 MAP_FN(spirv::StorageClass::Function, 6) \ 116 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 117 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 118 MAP_FN(spirv::StorageClass::Input, 9) \ 119 MAP_FN(spirv::StorageClass::Output, 10) \ 120 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 121 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 122 MAP_FN(spirv::StorageClass::Image, 13) \ 123 MAP_FN(spirv::StorageClass::CallableDataNV, 14) \ 124 MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \ 125 MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \ 126 MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \ 127 MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \ 128 MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \ 129 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) 130 131 unsigned 132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 133 #define STORAGE_SPACE_MAP_FN(storage, space) \ 134 case storage: \ 135 return space; 136 137 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 138 #undef STORAGE_SPACE_MAP_FN 139 llvm_unreachable("unhandled storage class!"); 140 } 141 142 Optional<spirv::StorageClass> 143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 144 #define STORAGE_SPACE_MAP_FN(storage, space) \ 145 case space: \ 146 return storage; 147 148 switch (space) { 149 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 150 default: 151 return llvm::None; 152 } 153 #undef STORAGE_SPACE_MAP_FN 154 } 155 156 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const { 157 return options; 158 } 159 160 #undef STORAGE_SPACE_MAP_LIST 161 162 // TODO: This is a utility function that should probably be exposed by the 163 // SPIR-V dialect. Keeping it local till the use case arises. 164 static Optional<int64_t> 165 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) { 166 if (type.isa<spirv::ScalarType>()) { 167 auto bitWidth = type.getIntOrFloatBitWidth(); 168 // According to the SPIR-V spec: 169 // "There is no physical size or bit pattern defined for values with boolean 170 // type. If they are stored (in conjunction with OpVariable), they can only 171 // be used with logical addressing operations, not physical, and only with 172 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 173 // Private, Function, Input, and Output." 174 if (bitWidth == 1) 175 return llvm::None; 176 return bitWidth / 8; 177 } 178 179 if (auto vecType = type.dyn_cast<VectorType>()) { 180 auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 181 if (!elementSize) 182 return llvm::None; 183 return vecType.getNumElements() * elementSize.getValue(); 184 } 185 186 if (auto memRefType = type.dyn_cast<MemRefType>()) { 187 // TODO: Layout should also be controlled by the ABI attributes. For now 188 // using the layout from MemRef. 189 int64_t offset; 190 SmallVector<int64_t, 4> strides; 191 if (!memRefType.hasStaticShape() || 192 failed(getStridesAndOffset(memRefType, strides, offset))) 193 return llvm::None; 194 195 // To get the size of the memref object in memory, the total size is the 196 // max(stride * dimension-size) computed for all dimensions times the size 197 // of the element. 198 auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 199 if (!elementSize) 200 return llvm::None; 201 202 if (memRefType.getRank() == 0) 203 return elementSize; 204 205 auto dims = memRefType.getShape(); 206 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 207 offset == MemRefType::getDynamicStrideOrOffset() || 208 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) 209 return llvm::None; 210 211 int64_t memrefSize = -1; 212 for (auto shape : enumerate(dims)) 213 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 214 215 return (offset + memrefSize) * elementSize.getValue(); 216 } 217 218 if (auto tensorType = type.dyn_cast<TensorType>()) { 219 if (!tensorType.hasStaticShape()) 220 return llvm::None; 221 222 auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 223 if (!elementSize) 224 return llvm::None; 225 226 int64_t size = elementSize.getValue(); 227 for (auto shape : tensorType.getShape()) 228 size *= shape; 229 230 return size; 231 } 232 233 // TODO: Add size computation for other types. 234 return llvm::None; 235 } 236 237 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 238 static Type convertScalarType(const spirv::TargetEnv &targetEnv, 239 const SPIRVTypeConverter::Options &options, 240 spirv::ScalarType type, 241 Optional<spirv::StorageClass> storageClass = {}) { 242 // Get extension and capability requirements for the given type. 243 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 244 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 245 type.getExtensions(extensions, storageClass); 246 type.getCapabilities(capabilities, storageClass); 247 248 // If all requirements are met, then we can accept this type as-is. 249 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 250 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 251 return type; 252 253 // Otherwise we need to adjust the type, which really means adjusting the 254 // bitwidth given this is a scalar type. 255 256 if (!options.emulateNon32BitScalarTypes) 257 return nullptr; 258 259 if (auto floatType = type.dyn_cast<FloatType>()) { 260 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 261 return Builder(targetEnv.getContext()).getF32Type(); 262 } 263 264 auto intType = type.cast<IntegerType>(); 265 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 266 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 267 intType.getSignedness()); 268 } 269 270 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 271 static Type convertVectorType(const spirv::TargetEnv &targetEnv, 272 const SPIRVTypeConverter::Options &options, 273 VectorType type, 274 Optional<spirv::StorageClass> storageClass = {}) { 275 if (type.getRank() == 1 && type.getNumElements() == 1) 276 return type.getElementType(); 277 278 if (!spirv::CompositeType::isValid(type)) { 279 // TODO: Vector types with more than four elements can be translated into 280 // array types. 281 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); 282 return nullptr; 283 } 284 285 // Get extension and capability requirements for the given type. 286 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 287 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 288 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 289 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 290 291 // If all requirements are met, then we can accept this type as-is. 292 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 293 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 294 return type; 295 296 auto elementType = convertScalarType( 297 targetEnv, options, type.getElementType().cast<spirv::ScalarType>(), 298 storageClass); 299 if (elementType) 300 return VectorType::get(type.getShape(), elementType); 301 return nullptr; 302 } 303 304 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 305 /// 306 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can 307 /// create composite constants with OpConstantComposite to embed relative large 308 /// constant values and use OpCompositeExtract and OpCompositeInsert to 309 /// manipulate, like what we do for vectors. 310 static Type convertTensorType(const spirv::TargetEnv &targetEnv, 311 const SPIRVTypeConverter::Options &options, 312 TensorType type) { 313 // TODO: Handle dynamic shapes. 314 if (!type.hasStaticShape()) { 315 LLVM_DEBUG(llvm::dbgs() 316 << type << " illegal: dynamic shape unimplemented\n"); 317 return nullptr; 318 } 319 320 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 321 if (!scalarType) { 322 LLVM_DEBUG(llvm::dbgs() 323 << type << " illegal: cannot convert non-scalar element type\n"); 324 return nullptr; 325 } 326 327 Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType); 328 Optional<int64_t> tensorSize = getTypeNumBytes(options, type); 329 if (!scalarSize || !tensorSize) { 330 LLVM_DEBUG(llvm::dbgs() 331 << type << " illegal: cannot deduce element count\n"); 332 return nullptr; 333 } 334 335 auto arrayElemCount = *tensorSize / *scalarSize; 336 auto arrayElemType = convertScalarType(targetEnv, options, scalarType); 337 if (!arrayElemType) 338 return nullptr; 339 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 340 if (!arrayElemSize) { 341 LLVM_DEBUG(llvm::dbgs() 342 << type << " illegal: cannot deduce converted element size\n"); 343 return nullptr; 344 } 345 346 return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 347 } 348 349 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, 350 const SPIRVTypeConverter::Options &options, 351 MemRefType type) { 352 if (!type.hasStaticShape()) { 353 LLVM_DEBUG(llvm::dbgs() 354 << type << " dynamic shape on i1 is not supported yet\n"); 355 return nullptr; 356 } 357 358 Optional<spirv::StorageClass> storageClass = 359 SPIRVTypeConverter::getStorageClassForMemorySpace( 360 type.getMemorySpaceAsInt()); 361 if (!storageClass) { 362 LLVM_DEBUG(llvm::dbgs() 363 << type << " illegal: cannot convert memory space\n"); 364 return nullptr; 365 } 366 367 unsigned numBoolBits = options.boolNumBits; 368 if (numBoolBits != 8) { 369 LLVM_DEBUG(llvm::dbgs() 370 << "using non-8-bit storage for bool types unimplemented"); 371 return nullptr; 372 } 373 auto elementType = IntegerType::get(type.getContext(), numBoolBits) 374 .dyn_cast<spirv::ScalarType>(); 375 if (!elementType) 376 return nullptr; 377 Type arrayElemType = 378 convertScalarType(targetEnv, options, elementType, storageClass); 379 if (!arrayElemType) 380 return nullptr; 381 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 382 if (!arrayElemSize) { 383 LLVM_DEBUG(llvm::dbgs() 384 << type << " illegal: cannot deduce converted element size\n"); 385 return nullptr; 386 } 387 388 int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; 389 auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; 390 auto arrayType = 391 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 392 393 // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with 394 // workgroup storage class do not need the struct to be laid out explicitly. 395 auto structType = *storageClass == spirv::StorageClass::Workgroup 396 ? spirv::StructType::get(arrayType) 397 : spirv::StructType::get(arrayType, 0); 398 return spirv::PointerType::get(structType, *storageClass); 399 } 400 401 static Type convertMemrefType(const spirv::TargetEnv &targetEnv, 402 const SPIRVTypeConverter::Options &options, 403 MemRefType type) { 404 if (type.getElementType().isa<IntegerType>() && 405 type.getElementTypeBitWidth() == 1) { 406 return convertBoolMemrefType(targetEnv, options, type); 407 } 408 409 Optional<spirv::StorageClass> storageClass = 410 SPIRVTypeConverter::getStorageClassForMemorySpace( 411 type.getMemorySpaceAsInt()); 412 if (!storageClass) { 413 LLVM_DEBUG(llvm::dbgs() 414 << type << " illegal: cannot convert memory space\n"); 415 return nullptr; 416 } 417 418 Type arrayElemType; 419 Type elementType = type.getElementType(); 420 if (auto vecType = elementType.dyn_cast<VectorType>()) { 421 arrayElemType = 422 convertVectorType(targetEnv, options, vecType, storageClass); 423 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 424 arrayElemType = 425 convertScalarType(targetEnv, options, scalarType, storageClass); 426 } else { 427 LLVM_DEBUG( 428 llvm::dbgs() 429 << type 430 << " unhandled: can only convert scalar or vector element type\n"); 431 return nullptr; 432 } 433 if (!arrayElemType) 434 return nullptr; 435 436 Optional<int64_t> elementSize = getTypeNumBytes(options, elementType); 437 if (!elementSize) { 438 LLVM_DEBUG(llvm::dbgs() 439 << type << " illegal: cannot deduce element size\n"); 440 return nullptr; 441 } 442 443 if (!type.hasStaticShape()) { 444 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize); 445 // Wrap in a struct to satisfy Vulkan interface requirements. 446 auto structType = spirv::StructType::get(arrayType, 0); 447 return spirv::PointerType::get(structType, *storageClass); 448 } 449 450 Optional<int64_t> memrefSize = getTypeNumBytes(options, type); 451 if (!memrefSize) { 452 LLVM_DEBUG(llvm::dbgs() 453 << type << " illegal: cannot deduce element count\n"); 454 return nullptr; 455 } 456 457 auto arrayElemCount = *memrefSize / *elementSize; 458 459 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType); 460 if (!arrayElemSize) { 461 LLVM_DEBUG(llvm::dbgs() 462 << type << " illegal: cannot deduce converted element size\n"); 463 return nullptr; 464 } 465 466 auto arrayType = 467 spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); 468 469 // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with 470 // workgroup storage class do not need the struct to be laid out explicitly. 471 auto structType = *storageClass == spirv::StorageClass::Workgroup 472 ? spirv::StructType::get(arrayType) 473 : spirv::StructType::get(arrayType, 0); 474 return spirv::PointerType::get(structType, *storageClass); 475 } 476 477 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 478 Options options) 479 : targetEnv(targetAttr), options(options) { 480 // Add conversions. The order matters here: later ones will be tried earlier. 481 482 // Allow all SPIR-V dialect specific types. This assumes all builtin types 483 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 484 // were tried before. 485 // 486 // TODO: this assumes that the SPIR-V types are valid to use in 487 // the given target environment, which should be the case if the whole 488 // pipeline is driven by the same target environment. Still, we probably still 489 // want to validate and convert to be safe. 490 addConversion([](spirv::SPIRVType type) { return type; }); 491 492 addConversion([](IndexType indexType) { 493 return SPIRVTypeConverter::getIndexType(indexType.getContext()); 494 }); 495 496 addConversion([this](IntegerType intType) -> Optional<Type> { 497 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 498 return convertScalarType(this->targetEnv, this->options, scalarType); 499 return Type(); 500 }); 501 502 addConversion([this](FloatType floatType) -> Optional<Type> { 503 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 504 return convertScalarType(this->targetEnv, this->options, scalarType); 505 return Type(); 506 }); 507 508 addConversion([this](VectorType vectorType) { 509 return convertVectorType(this->targetEnv, this->options, vectorType); 510 }); 511 512 addConversion([this](TensorType tensorType) { 513 return convertTensorType(this->targetEnv, this->options, tensorType); 514 }); 515 516 addConversion([this](MemRefType memRefType) { 517 return convertMemrefType(this->targetEnv, this->options, memRefType); 518 }); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // FuncOp Conversion Patterns 523 //===----------------------------------------------------------------------===// 524 525 namespace { 526 /// A pattern for rewriting function signature to convert arguments of functions 527 /// to be of valid SPIR-V types. 528 class FuncOpConversion final : public OpConversionPattern<FuncOp> { 529 public: 530 using OpConversionPattern<FuncOp>::OpConversionPattern; 531 532 LogicalResult 533 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 534 ConversionPatternRewriter &rewriter) const override; 535 }; 536 } // namespace 537 538 LogicalResult 539 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 540 ConversionPatternRewriter &rewriter) const { 541 auto fnType = funcOp.getType(); 542 if (fnType.getNumResults() > 1) 543 return failure(); 544 545 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 546 for (auto argType : enumerate(fnType.getInputs())) { 547 auto convertedType = getTypeConverter()->convertType(argType.value()); 548 if (!convertedType) 549 return failure(); 550 signatureConverter.addInputs(argType.index(), convertedType); 551 } 552 553 Type resultType; 554 if (fnType.getNumResults() == 1) { 555 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 556 if (!resultType) 557 return failure(); 558 } 559 560 // Create the converted spv.func op. 561 auto newFuncOp = rewriter.create<spirv::FuncOp>( 562 funcOp.getLoc(), funcOp.getName(), 563 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 564 resultType ? TypeRange(resultType) 565 : TypeRange())); 566 567 // Copy over all attributes other than the function name and type. 568 for (const auto &namedAttr : funcOp->getAttrs()) { 569 if (namedAttr.first != impl::getTypeAttrName() && 570 namedAttr.first != SymbolTable::getSymbolAttrName()) 571 newFuncOp->setAttr(namedAttr.first, namedAttr.second); 572 } 573 574 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 575 newFuncOp.end()); 576 if (failed(rewriter.convertRegionTypes( 577 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 578 return failure(); 579 rewriter.eraseOp(funcOp); 580 return success(); 581 } 582 583 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 584 RewritePatternSet &patterns) { 585 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 586 } 587 588 //===----------------------------------------------------------------------===// 589 // Builtin Variables 590 //===----------------------------------------------------------------------===// 591 592 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 593 spirv::BuiltIn builtin) { 594 // Look through all global variables in the given `body` block and check if 595 // there is a spv.GlobalVariable that has the same `builtin` attribute. 596 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 597 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 598 spirv::SPIRVDialect::getAttributeName( 599 spirv::Decoration::BuiltIn))) { 600 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 601 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 602 return varOp; 603 } 604 } 605 } 606 return nullptr; 607 } 608 609 /// Gets name of global variable for a builtin. 610 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 611 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 612 } 613 614 /// Gets or inserts a global variable for a builtin within `body` block. 615 static spirv::GlobalVariableOp 616 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 617 OpBuilder &builder) { 618 if (auto varOp = getBuiltinVariable(body, builtin)) 619 return varOp; 620 621 OpBuilder::InsertionGuard guard(builder); 622 builder.setInsertionPointToStart(&body); 623 624 spirv::GlobalVariableOp newVarOp; 625 switch (builtin) { 626 case spirv::BuiltIn::NumWorkgroups: 627 case spirv::BuiltIn::WorkgroupSize: 628 case spirv::BuiltIn::WorkgroupId: 629 case spirv::BuiltIn::LocalInvocationId: 630 case spirv::BuiltIn::GlobalInvocationId: { 631 auto ptrType = spirv::PointerType::get( 632 VectorType::get({3}, builder.getIntegerType(32)), 633 spirv::StorageClass::Input); 634 std::string name = getBuiltinVarName(builtin); 635 newVarOp = 636 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 637 break; 638 } 639 case spirv::BuiltIn::SubgroupId: 640 case spirv::BuiltIn::NumSubgroups: 641 case spirv::BuiltIn::SubgroupSize: { 642 auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), 643 spirv::StorageClass::Input); 644 std::string name = getBuiltinVarName(builtin); 645 newVarOp = 646 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 647 break; 648 } 649 default: 650 emitError(loc, "unimplemented builtin variable generation for ") 651 << stringifyBuiltIn(builtin); 652 } 653 return newVarOp; 654 } 655 656 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 657 spirv::BuiltIn builtin, 658 OpBuilder &builder) { 659 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 660 if (!parent) { 661 op->emitError("expected operation to be within a module-like op"); 662 return nullptr; 663 } 664 665 spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( 666 *parent->getRegion(0).begin(), op->getLoc(), builtin, builder); 667 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 668 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 669 } 670 671 //===----------------------------------------------------------------------===// 672 // Push constant storage 673 //===----------------------------------------------------------------------===// 674 675 /// Returns the pointer type for the push constant storage containing 676 /// `elementCount` 32-bit integer values. 677 static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 678 Builder &builder) { 679 auto arrayType = spirv::ArrayType::get( 680 SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, 681 /*stride=*/4); 682 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 683 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 684 } 685 686 /// Returns the push constant varible containing `elementCount` 32-bit integer 687 /// values in `body`. Returns null op if such an op does not exit. 688 static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 689 unsigned elementCount) { 690 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 691 auto ptrType = varOp.type().cast<spirv::PointerType>(); 692 // Note that Vulkan requires "There must be no more than one push constant 693 // block statically used per shader entry point." So we should always reuse 694 // the existing one. 695 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 696 auto numElements = ptrType.getPointeeType() 697 .cast<spirv::StructType>() 698 .getElementType(0) 699 .cast<spirv::ArrayType>() 700 .getNumElements(); 701 if (numElements == elementCount) 702 return varOp; 703 } 704 } 705 return nullptr; 706 } 707 708 /// Gets or inserts a global variable for push constant storage containing 709 /// `elementCount` 32-bit integer values in `block`. 710 static spirv::GlobalVariableOp 711 getOrInsertPushConstantVariable(Location loc, Block &block, 712 unsigned elementCount, OpBuilder &b) { 713 if (auto varOp = getPushConstantVariable(block, elementCount)) 714 return varOp; 715 716 auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 717 auto type = getPushConstantStorageType(elementCount, builder); 718 const char *name = "__push_constant_var__"; 719 return builder.create<spirv::GlobalVariableOp>(loc, type, name, 720 /*initializer=*/nullptr); 721 } 722 723 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, 724 unsigned offset, OpBuilder &builder) { 725 Location loc = op->getLoc(); 726 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 727 if (!parent) { 728 op->emitError("expected operation to be within a module-like op"); 729 return nullptr; 730 } 731 732 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 733 loc, parent->getRegion(0).front(), elementCount, builder); 734 735 auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); 736 Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); 737 Value offsetOp = builder.create<spirv::ConstantOp>( 738 loc, i32Type, builder.getI32IntegerAttr(offset)); 739 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 740 auto acOp = builder.create<spirv::AccessChainOp>( 741 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); 742 return builder.create<spirv::LoadOp>(loc, acOp); 743 } 744 745 //===----------------------------------------------------------------------===// 746 // Index calculation 747 //===----------------------------------------------------------------------===// 748 749 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 750 int64_t offset, Location loc, 751 OpBuilder &builder) { 752 assert(indices.size() == strides.size() && 753 "must provide indices for all dimensions"); 754 755 auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext()); 756 757 // TODO: Consider moving to use affine.apply and patterns converting 758 // affine.apply to standard ops. This needs converting to SPIR-V passes to be 759 // broken down into progressive small steps so we can have intermediate steps 760 // using other dialects. At the moment SPIR-V is the final sink. 761 762 Value linearizedIndex = builder.create<spirv::ConstantOp>( 763 loc, indexType, IntegerAttr::get(indexType, offset)); 764 for (auto index : llvm::enumerate(indices)) { 765 Value strideVal = builder.create<spirv::ConstantOp>( 766 loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); 767 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 768 linearizedIndex = 769 builder.create<spirv::IAddOp>(loc, linearizedIndex, update); 770 } 771 return linearizedIndex; 772 } 773 774 spirv::AccessChainOp mlir::spirv::getElementPtr( 775 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 776 ValueRange indices, Location loc, OpBuilder &builder) { 777 // Get base and offset of the MemRefType and verify they are static. 778 779 int64_t offset; 780 SmallVector<int64_t, 4> strides; 781 if (failed(getStridesAndOffset(baseType, strides, offset)) || 782 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 783 offset == MemRefType::getDynamicStrideOrOffset()) { 784 return nullptr; 785 } 786 787 auto indexType = typeConverter.getIndexType(builder.getContext()); 788 789 SmallVector<Value, 2> linearizedIndices; 790 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 791 792 // Add a '0' at the start to index into the struct. 793 linearizedIndices.push_back(zero); 794 795 if (baseType.getRank() == 0) { 796 linearizedIndices.push_back(zero); 797 } else { 798 linearizedIndices.push_back( 799 linearizeIndex(indices, strides, offset, loc, builder)); 800 } 801 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 802 } 803 804 //===----------------------------------------------------------------------===// 805 // SPIR-V ConversionTarget 806 //===----------------------------------------------------------------------===// 807 808 std::unique_ptr<SPIRVConversionTarget> 809 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 810 std::unique_ptr<SPIRVConversionTarget> target( 811 // std::make_unique does not work here because the constructor is private. 812 new SPIRVConversionTarget(targetAttr)); 813 SPIRVConversionTarget *targetPtr = target.get(); 814 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 815 // We need to capture the raw pointer here because it is stable: 816 // target will be destroyed once this function is returned. 817 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 818 return target; 819 } 820 821 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 822 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 823 824 bool SPIRVConversionTarget::isLegalOp(Operation *op) { 825 // Make sure this op is available at the given version. Ops not implementing 826 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 827 // SPIR-V versions. 828 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 829 if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { 830 LLVM_DEBUG(llvm::dbgs() 831 << op->getName() << " illegal: requiring min version " 832 << spirv::stringifyVersion(minVersion.getMinVersion()) 833 << "\n"); 834 return false; 835 } 836 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 837 if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { 838 LLVM_DEBUG(llvm::dbgs() 839 << op->getName() << " illegal: requiring max version " 840 << spirv::stringifyVersion(maxVersion.getMaxVersion()) 841 << "\n"); 842 return false; 843 } 844 845 // Make sure this op's required extensions are allowed to use. Ops not 846 // implementing QueryExtensionInterface do not require extensions to be 847 // available. 848 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 849 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 850 extensions.getExtensions()))) 851 return false; 852 853 // Make sure this op's required extensions are allowed to use. Ops not 854 // implementing QueryCapabilityInterface do not require capabilities to be 855 // available. 856 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 857 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 858 capabilities.getCapabilities()))) 859 return false; 860 861 SmallVector<Type, 4> valueTypes; 862 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 863 valueTypes.append(op->result_type_begin(), op->result_type_end()); 864 865 // Special treatment for global variables, whose type requirements are 866 // conveyed by type attributes. 867 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 868 valueTypes.push_back(globalVar.type()); 869 870 // Make sure the op's operands/results use types that are allowed by the 871 // target environment. 872 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 873 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 874 for (Type valueType : valueTypes) { 875 typeExtensions.clear(); 876 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 877 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 878 typeExtensions))) 879 return false; 880 881 typeCapabilities.clear(); 882 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 883 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 884 typeCapabilities))) 885 return false; 886 } 887 888 return true; 889 } 890