//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to the SPIR-V dialect: type
// conversion (SPIRVTypeConverter), builtin FuncOp conversion, builtin and
// push constant variable helpers, index linearization, and the SPIR-V
// conversion target.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

#include <functional>

// Debug-output tag used by every LLVM_DEBUG message in this file.
#define DEBUG_TYPE "mlir-spirv-conversion"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Checks that the extension requirements in `candidates` can be satisfied by
/// the given `targetEnv`.
///
/// `candidates` is a vector of vectors of extension requirements following the
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention: every inner list must be accepted by `targetEnv.allows()`.
/// Fails on the first inner list that is not; `label` only appears in the
/// LLVM_DEBUG message describing the failure.
37 template <typename LabelT> 38 static LogicalResult checkExtensionRequirements( 39 LabelT label, const spirv::TargetEnv &targetEnv, 40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 41 for (const auto &ors : candidates) { 42 if (targetEnv.allows(ors)) 43 continue; 44 45 SmallVector<StringRef, 4> extStrings; 46 for (spirv::Extension ext : ors) 47 extStrings.push_back(spirv::stringifyExtension(ext)); 48 49 LLVM_DEBUG(llvm::dbgs() 50 << label << " illegal: requires at least one extension in [" 51 << llvm::join(extStrings, ", ") 52 << "] but none allowed in target environment\n"); 53 return failure(); 54 } 55 return success(); 56 } 57 58 /// Checks that `candidates`capability requirements are possible to be satisfied 59 /// with the given `isAllowedFn`. 60 /// 61 /// `candidates` is a vector of vector for capability requirements following 62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 63 /// convention. 64 template <typename LabelT> 65 static LogicalResult checkCapabilityRequirements( 66 LabelT label, const spirv::TargetEnv &targetEnv, 67 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 68 for (const auto &ors : candidates) { 69 if (targetEnv.allows(ors)) 70 continue; 71 72 SmallVector<StringRef, 4> capStrings; 73 for (spirv::Capability cap : ors) 74 capStrings.push_back(spirv::stringifyCapability(cap)); 75 76 LLVM_DEBUG(llvm::dbgs() 77 << label << " illegal: requires at least one capability in [" 78 << llvm::join(capStrings, ", ") 79 << "] but none allowed in target environment\n"); 80 return failure(); 81 } 82 return success(); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Type Conversion 87 //===----------------------------------------------------------------------===// 88 89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { 90 // Convert to 32-bit integers for now. Might need a way to control this in 91 // future. 
92 // TODO: It is probably better to make it 64-bit integers. To 93 // this some support is needed in SPIR-V dialect for Conversion 94 // instructions. The Vulkan spec requires the builtins like 95 // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be 96 // SExtended to 64-bit for index computations. 97 return IntegerType::get(context, 32); 98 } 99 100 /// Mapping between SPIR-V storage classes to memref memory spaces. 101 /// 102 /// Note: memref does not have a defined semantics for each memory space; it 103 /// depends on the context where it is used. There are no particular reasons 104 /// behind the number assignments; we try to follow NVVM conventions and largely 105 /// give common storage classes a smaller number. The hope is use symbolic 106 /// memory space representation eventually after memref supports it. 107 // TODO: swap Generic and StorageBuffer assignment to be more akin 108 // to NVVM. 109 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 110 MAP_FN(spirv::StorageClass::Generic, 1) \ 111 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 112 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 113 MAP_FN(spirv::StorageClass::Uniform, 4) \ 114 MAP_FN(spirv::StorageClass::Private, 5) \ 115 MAP_FN(spirv::StorageClass::Function, 6) \ 116 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 117 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 118 MAP_FN(spirv::StorageClass::Input, 9) \ 119 MAP_FN(spirv::StorageClass::Output, 10) \ 120 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 121 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 122 MAP_FN(spirv::StorageClass::Image, 13) \ 123 MAP_FN(spirv::StorageClass::CallableDataNV, 14) \ 124 MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \ 125 MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \ 126 MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \ 127 MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \ 128 MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \ 129 
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) 130 131 unsigned 132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 133 #define STORAGE_SPACE_MAP_FN(storage, space) \ 134 case storage: \ 135 return space; 136 137 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 138 #undef STORAGE_SPACE_MAP_FN 139 llvm_unreachable("unhandled storage class!"); 140 } 141 142 Optional<spirv::StorageClass> 143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 144 #define STORAGE_SPACE_MAP_FN(storage, space) \ 145 case space: \ 146 return storage; 147 148 switch (space) { 149 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 150 default: 151 return llvm::None; 152 } 153 #undef STORAGE_SPACE_MAP_FN 154 } 155 156 #undef STORAGE_SPACE_MAP_LIST 157 158 // TODO: This is a utility function that should probably be 159 // exposed by the SPIR-V dialect. Keeping it local till the use case arises. 160 static Optional<int64_t> getTypeNumBytes(Type t) { 161 if (t.isa<spirv::ScalarType>()) { 162 auto bitWidth = t.getIntOrFloatBitWidth(); 163 // According to the SPIR-V spec: 164 // "There is no physical size or bit pattern defined for values with boolean 165 // type. If they are stored (in conjunction with OpVariable), they can only 166 // be used with logical addressing operations, not physical, and only with 167 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 168 // Private, Function, Input, and Output." 169 if (bitWidth == 1) { 170 return llvm::None; 171 } 172 return bitWidth / 8; 173 } 174 175 if (auto vecType = t.dyn_cast<VectorType>()) { 176 auto elementSize = getTypeNumBytes(vecType.getElementType()); 177 if (!elementSize) 178 return llvm::None; 179 return vecType.getNumElements() * *elementSize; 180 } 181 182 if (auto memRefType = t.dyn_cast<MemRefType>()) { 183 // TODO: Layout should also be controlled by the ABI attributes. For now 184 // using the layout from MemRef. 
185 int64_t offset; 186 SmallVector<int64_t, 4> strides; 187 if (!memRefType.hasStaticShape() || 188 failed(getStridesAndOffset(memRefType, strides, offset))) { 189 return llvm::None; 190 } 191 // To get the size of the memref object in memory, the total size is the 192 // max(stride * dimension-size) computed for all dimensions times the size 193 // of the element. 194 auto elementSize = getTypeNumBytes(memRefType.getElementType()); 195 if (!elementSize) { 196 return llvm::None; 197 } 198 if (memRefType.getRank() == 0) { 199 return elementSize; 200 } 201 auto dims = memRefType.getShape(); 202 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 203 offset == MemRefType::getDynamicStrideOrOffset() || 204 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { 205 return llvm::None; 206 } 207 int64_t memrefSize = -1; 208 for (auto shape : enumerate(dims)) { 209 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 210 } 211 return (offset + memrefSize) * elementSize.getValue(); 212 } 213 214 if (auto tensorType = t.dyn_cast<TensorType>()) { 215 if (!tensorType.hasStaticShape()) { 216 return llvm::None; 217 } 218 auto elementSize = getTypeNumBytes(tensorType.getElementType()); 219 if (!elementSize) { 220 return llvm::None; 221 } 222 int64_t size = elementSize.getValue(); 223 for (auto shape : tensorType.getShape()) { 224 size *= shape; 225 } 226 return size; 227 } 228 229 // TODO: Add size computation for other types. 230 return llvm::None; 231 } 232 233 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) { 234 return getTypeNumBytes(t); 235 } 236 237 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 238 static Optional<Type> 239 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type, 240 Optional<spirv::StorageClass> storageClass = {}) { 241 // Get extension and capability requirements for the given type. 
242 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 243 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 244 type.getExtensions(extensions, storageClass); 245 type.getCapabilities(capabilities, storageClass); 246 247 // If all requirements are met, then we can accept this type as-is. 248 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 249 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 250 return type; 251 252 // Otherwise we need to adjust the type, which really means adjusting the 253 // bitwidth given this is a scalar type. 254 // TODO: We are unconditionally converting the bitwidth here, 255 // this might be okay for non-interface types (i.e., types used in 256 // Private/Function storage classes), but not for interface types (i.e., 257 // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes). 258 // This is because the later actually affects the ABI contract with the 259 // runtime. So we may want to expose a control on SPIRVTypeConverter to fail 260 // conversion if we cannot change there. 261 262 if (auto floatType = type.dyn_cast<FloatType>()) { 263 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 264 return Builder(targetEnv.getContext()).getF32Type(); 265 } 266 267 auto intType = type.cast<IntegerType>(); 268 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 269 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 270 intType.getSignedness()); 271 } 272 273 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 
274 static Optional<Type> 275 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type, 276 Optional<spirv::StorageClass> storageClass = {}) { 277 if (type.getRank() == 1 && type.getNumElements() == 1) 278 return type.getElementType(); 279 280 if (!spirv::CompositeType::isValid(type)) { 281 // TODO: Vector types with more than four elements can be translated into 282 // array types. 283 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); 284 return llvm::None; 285 } 286 287 // Get extension and capability requirements for the given type. 288 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 289 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 290 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 291 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 292 293 // If all requirements are met, then we can accept this type as-is. 294 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 295 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 296 return type; 297 298 auto elementType = convertScalarType( 299 targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass); 300 if (elementType) 301 return VectorType::get(type.getShape(), *elementType); 302 return llvm::None; 303 } 304 305 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 306 /// 307 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can 308 /// create composite constants with OpConstantComposite to embed relative large 309 /// constant values and use OpCompositeExtract and OpCompositeInsert to 310 /// manipulate, like what we do for vectors. 311 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv, 312 TensorType type) { 313 // TODO: Handle dynamic shapes. 
314 if (!type.hasStaticShape()) { 315 LLVM_DEBUG(llvm::dbgs() 316 << type << " illegal: dynamic shape unimplemented\n"); 317 return llvm::None; 318 } 319 320 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 321 if (!scalarType) { 322 LLVM_DEBUG(llvm::dbgs() 323 << type << " illegal: cannot convert non-scalar element type\n"); 324 return llvm::None; 325 } 326 327 Optional<int64_t> scalarSize = getTypeNumBytes(scalarType); 328 Optional<int64_t> tensorSize = getTypeNumBytes(type); 329 if (!scalarSize || !tensorSize) { 330 LLVM_DEBUG(llvm::dbgs() 331 << type << " illegal: cannot deduce element count\n"); 332 return llvm::None; 333 } 334 335 auto arrayElemCount = *tensorSize / *scalarSize; 336 auto arrayElemType = convertScalarType(targetEnv, scalarType); 337 if (!arrayElemType) 338 return llvm::None; 339 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType); 340 if (!arrayElemSize) { 341 LLVM_DEBUG(llvm::dbgs() 342 << type << " illegal: cannot deduce converted element size\n"); 343 return llvm::None; 344 } 345 346 return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); 347 } 348 349 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv, 350 MemRefType type) { 351 Optional<spirv::StorageClass> storageClass = 352 SPIRVTypeConverter::getStorageClassForMemorySpace( 353 type.getMemorySpaceAsInt()); 354 if (!storageClass) { 355 LLVM_DEBUG(llvm::dbgs() 356 << type << " illegal: cannot convert memory space\n"); 357 return llvm::None; 358 } 359 360 Optional<Type> arrayElemType; 361 Type elementType = type.getElementType(); 362 if (auto vecType = elementType.dyn_cast<VectorType>()) { 363 arrayElemType = convertVectorType(targetEnv, vecType, storageClass); 364 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 365 arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); 366 } else { 367 LLVM_DEBUG( 368 llvm::dbgs() 369 << type 370 << " unhandled: can only convert 
scalar or vector element type\n"); 371 return llvm::None; 372 } 373 if (!arrayElemType) 374 return llvm::None; 375 376 Optional<int64_t> elementSize = getTypeNumBytes(elementType); 377 if (!elementSize) { 378 LLVM_DEBUG(llvm::dbgs() 379 << type << " illegal: cannot deduce element size\n"); 380 return llvm::None; 381 } 382 383 if (!type.hasStaticShape()) { 384 auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize); 385 // Wrap in a struct to satisfy Vulkan interface requirements. 386 auto structType = spirv::StructType::get(arrayType, 0); 387 return spirv::PointerType::get(structType, *storageClass); 388 } 389 390 Optional<int64_t> memrefSize = getTypeNumBytes(type); 391 if (!memrefSize) { 392 LLVM_DEBUG(llvm::dbgs() 393 << type << " illegal: cannot deduce element count\n"); 394 return llvm::None; 395 } 396 397 auto arrayElemCount = *memrefSize / *elementSize; 398 399 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType); 400 if (!arrayElemSize) { 401 LLVM_DEBUG(llvm::dbgs() 402 << type << " illegal: cannot deduce converted element size\n"); 403 return llvm::None; 404 } 405 406 auto arrayType = 407 spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); 408 409 // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with 410 // workgroup storage class do not need the struct to be laid out explicitly. 411 auto structType = *storageClass == spirv::StorageClass::Workgroup 412 ? spirv::StructType::get(arrayType) 413 : spirv::StructType::get(arrayType, 0); 414 return spirv::PointerType::get(structType, *storageClass); 415 } 416 417 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) 418 : targetEnv(targetAttr) { 419 // Add conversions. The order matters here: later ones will be tried earlier. 420 421 // All other cases failed. Then we cannot convert this type. 422 addConversion([](Type type) { return llvm::None; }); 423 424 // Allow all SPIR-V dialect specific types. 
This assumes all builtin types 425 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 426 // were tried before. 427 // 428 // TODO: this assumes that the SPIR-V types are valid to use in 429 // the given target environment, which should be the case if the whole 430 // pipeline is driven by the same target environment. Still, we probably still 431 // want to validate and convert to be safe. 432 addConversion([](spirv::SPIRVType type) { return type; }); 433 434 addConversion([](IndexType indexType) { 435 return SPIRVTypeConverter::getIndexType(indexType.getContext()); 436 }); 437 438 addConversion([this](IntegerType intType) -> Optional<Type> { 439 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 440 return convertScalarType(targetEnv, scalarType); 441 return llvm::None; 442 }); 443 444 addConversion([this](FloatType floatType) -> Optional<Type> { 445 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 446 return convertScalarType(targetEnv, scalarType); 447 return llvm::None; 448 }); 449 450 addConversion([this](VectorType vectorType) { 451 return convertVectorType(targetEnv, vectorType); 452 }); 453 454 addConversion([this](TensorType tensorType) { 455 return convertTensorType(targetEnv, tensorType); 456 }); 457 458 addConversion([this](MemRefType memRefType) { 459 return convertMemrefType(targetEnv, memRefType); 460 }); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // FuncOp Conversion Patterns 465 //===----------------------------------------------------------------------===// 466 467 namespace { 468 /// A pattern for rewriting function signature to convert arguments of functions 469 /// to be of valid SPIR-V types. 
470 class FuncOpConversion final : public OpConversionPattern<FuncOp> { 471 public: 472 using OpConversionPattern<FuncOp>::OpConversionPattern; 473 474 LogicalResult 475 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 476 ConversionPatternRewriter &rewriter) const override; 477 }; 478 } // namespace 479 480 LogicalResult 481 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 482 ConversionPatternRewriter &rewriter) const { 483 auto fnType = funcOp.getType(); 484 if (fnType.getNumResults() > 1) 485 return failure(); 486 487 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 488 for (auto argType : enumerate(fnType.getInputs())) { 489 auto convertedType = getTypeConverter()->convertType(argType.value()); 490 if (!convertedType) 491 return failure(); 492 signatureConverter.addInputs(argType.index(), convertedType); 493 } 494 495 Type resultType; 496 if (fnType.getNumResults() == 1) 497 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 498 499 // Create the converted spv.func op. 500 auto newFuncOp = rewriter.create<spirv::FuncOp>( 501 funcOp.getLoc(), funcOp.getName(), 502 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 503 resultType ? TypeRange(resultType) 504 : TypeRange())); 505 506 // Copy over all attributes other than the function name and type. 
507 for (const auto &namedAttr : funcOp->getAttrs()) { 508 if (namedAttr.first != impl::getTypeAttrName() && 509 namedAttr.first != SymbolTable::getSymbolAttrName()) 510 newFuncOp->setAttr(namedAttr.first, namedAttr.second); 511 } 512 513 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 514 newFuncOp.end()); 515 if (failed(rewriter.convertRegionTypes( 516 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 517 return failure(); 518 rewriter.eraseOp(funcOp); 519 return success(); 520 } 521 522 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 523 RewritePatternSet &patterns) { 524 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // Builtin Variables 529 //===----------------------------------------------------------------------===// 530 531 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 532 spirv::BuiltIn builtin) { 533 // Look through all global variables in the given `body` block and check if 534 // there is a spv.GlobalVariable that has the same `builtin` attribute. 535 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 536 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 537 spirv::SPIRVDialect::getAttributeName( 538 spirv::Decoration::BuiltIn))) { 539 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 540 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 541 return varOp; 542 } 543 } 544 } 545 return nullptr; 546 } 547 548 /// Gets name of global variable for a builtin. 549 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 550 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 551 } 552 553 /// Gets or inserts a global variable for a builtin within `body` block. 
554 static spirv::GlobalVariableOp 555 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 556 OpBuilder &builder) { 557 if (auto varOp = getBuiltinVariable(body, builtin)) 558 return varOp; 559 560 OpBuilder::InsertionGuard guard(builder); 561 builder.setInsertionPointToStart(&body); 562 563 spirv::GlobalVariableOp newVarOp; 564 switch (builtin) { 565 case spirv::BuiltIn::NumWorkgroups: 566 case spirv::BuiltIn::WorkgroupSize: 567 case spirv::BuiltIn::WorkgroupId: 568 case spirv::BuiltIn::LocalInvocationId: 569 case spirv::BuiltIn::GlobalInvocationId: { 570 auto ptrType = spirv::PointerType::get( 571 VectorType::get({3}, builder.getIntegerType(32)), 572 spirv::StorageClass::Input); 573 std::string name = getBuiltinVarName(builtin); 574 newVarOp = 575 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 576 break; 577 } 578 case spirv::BuiltIn::SubgroupId: 579 case spirv::BuiltIn::NumSubgroups: 580 case spirv::BuiltIn::SubgroupSize: { 581 auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), 582 spirv::StorageClass::Input); 583 std::string name = getBuiltinVarName(builtin); 584 newVarOp = 585 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 586 break; 587 } 588 default: 589 emitError(loc, "unimplemented builtin variable generation for ") 590 << stringifyBuiltIn(builtin); 591 } 592 return newVarOp; 593 } 594 595 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 596 spirv::BuiltIn builtin, 597 OpBuilder &builder) { 598 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 599 if (!parent) { 600 op->emitError("expected operation to be within a module-like op"); 601 return nullptr; 602 } 603 604 spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( 605 *parent->getRegion(0).begin(), op->getLoc(), builtin, builder); 606 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 607 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 608 } 609 610 
//===----------------------------------------------------------------------===// 611 // Push constant storage 612 //===----------------------------------------------------------------------===// 613 614 /// Returns the pointer type for the push constant storage containing 615 /// `elementCount` 32-bit integer values. 616 static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 617 Builder &builder) { 618 auto arrayType = spirv::ArrayType::get( 619 SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, 620 /*stride=*/4); 621 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 622 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 623 } 624 625 /// Returns the push constant varible containing `elementCount` 32-bit integer 626 /// values in `body`. Returns null op if such an op does not exit. 627 static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 628 unsigned elementCount) { 629 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 630 auto ptrType = varOp.type().cast<spirv::PointerType>(); 631 // Note that Vulkan requires "There must be no more than one push constant 632 // block statically used per shader entry point." So we should always reuse 633 // the existing one. 634 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 635 auto numElements = ptrType.getPointeeType() 636 .cast<spirv::StructType>() 637 .getElementType(0) 638 .cast<spirv::ArrayType>() 639 .getNumElements(); 640 if (numElements == elementCount) 641 return varOp; 642 } 643 } 644 return nullptr; 645 } 646 647 /// Gets or inserts a global variable for push constant storage containing 648 /// `elementCount` 32-bit integer values in `block`. 
649 static spirv::GlobalVariableOp 650 getOrInsertPushConstantVariable(Location loc, Block &block, 651 unsigned elementCount, OpBuilder &b) { 652 if (auto varOp = getPushConstantVariable(block, elementCount)) 653 return varOp; 654 655 auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 656 auto type = getPushConstantStorageType(elementCount, builder); 657 const char *name = "__push_constant_var__"; 658 return builder.create<spirv::GlobalVariableOp>(loc, type, name, 659 /*initializer=*/nullptr); 660 } 661 662 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, 663 unsigned offset, OpBuilder &builder) { 664 Location loc = op->getLoc(); 665 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 666 if (!parent) { 667 op->emitError("expected operation to be within a module-like op"); 668 return nullptr; 669 } 670 671 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 672 loc, parent->getRegion(0).front(), elementCount, builder); 673 674 auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); 675 Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); 676 Value offsetOp = builder.create<spirv::ConstantOp>( 677 loc, i32Type, builder.getI32IntegerAttr(offset)); 678 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 679 auto acOp = builder.create<spirv::AccessChainOp>( 680 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); 681 return builder.create<spirv::LoadOp>(loc, acOp); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // Index calculation 686 //===----------------------------------------------------------------------===// 687 688 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 689 int64_t offset, Location loc, 690 OpBuilder &builder) { 691 assert(indices.size() == strides.size() && 692 "must provide indices for all dimensions"); 693 694 auto indexType = 
SPIRVTypeConverter::getIndexType(builder.getContext()); 695 696 // TODO: Consider moving to use affine.apply and patterns converting 697 // affine.apply to standard ops. This needs converting to SPIR-V passes to be 698 // broken down into progressive small steps so we can have intermediate steps 699 // using other dialects. At the moment SPIR-V is the final sink. 700 701 Value linearizedIndex = builder.create<spirv::ConstantOp>( 702 loc, indexType, IntegerAttr::get(indexType, offset)); 703 for (auto index : llvm::enumerate(indices)) { 704 Value strideVal = builder.create<spirv::ConstantOp>( 705 loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); 706 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 707 linearizedIndex = 708 builder.create<spirv::IAddOp>(loc, linearizedIndex, update); 709 } 710 return linearizedIndex; 711 } 712 713 spirv::AccessChainOp mlir::spirv::getElementPtr( 714 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 715 ValueRange indices, Location loc, OpBuilder &builder) { 716 // Get base and offset of the MemRefType and verify they are static. 717 718 int64_t offset; 719 SmallVector<int64_t, 4> strides; 720 if (failed(getStridesAndOffset(baseType, strides, offset)) || 721 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 722 offset == MemRefType::getDynamicStrideOrOffset()) { 723 return nullptr; 724 } 725 726 auto indexType = typeConverter.getIndexType(builder.getContext()); 727 728 SmallVector<Value, 2> linearizedIndices; 729 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 730 731 // Add a '0' at the start to index into the struct. 
732 linearizedIndices.push_back(zero); 733 734 if (baseType.getRank() == 0) { 735 linearizedIndices.push_back(zero); 736 } else { 737 linearizedIndices.push_back( 738 linearizeIndex(indices, strides, offset, loc, builder)); 739 } 740 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // SPIR-V ConversionTarget 745 //===----------------------------------------------------------------------===// 746 747 std::unique_ptr<SPIRVConversionTarget> 748 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 749 std::unique_ptr<SPIRVConversionTarget> target( 750 // std::make_unique does not work here because the constructor is private. 751 new SPIRVConversionTarget(targetAttr)); 752 SPIRVConversionTarget *targetPtr = target.get(); 753 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 754 // We need to capture the raw pointer here because it is stable: 755 // target will be destroyed once this function is returned. 756 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 757 return target; 758 } 759 760 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 761 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 762 763 bool SPIRVConversionTarget::isLegalOp(Operation *op) { 764 // Make sure this op is available at the given version. Ops not implementing 765 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 766 // SPIR-V versions. 
767 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 768 if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { 769 LLVM_DEBUG(llvm::dbgs() 770 << op->getName() << " illegal: requiring min version " 771 << spirv::stringifyVersion(minVersion.getMinVersion()) 772 << "\n"); 773 return false; 774 } 775 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 776 if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { 777 LLVM_DEBUG(llvm::dbgs() 778 << op->getName() << " illegal: requiring max version " 779 << spirv::stringifyVersion(maxVersion.getMaxVersion()) 780 << "\n"); 781 return false; 782 } 783 784 // Make sure this op's required extensions are allowed to use. Ops not 785 // implementing QueryExtensionInterface do not require extensions to be 786 // available. 787 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 788 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 789 extensions.getExtensions()))) 790 return false; 791 792 // Make sure this op's required extensions are allowed to use. Ops not 793 // implementing QueryCapabilityInterface do not require capabilities to be 794 // available. 795 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 796 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 797 capabilities.getCapabilities()))) 798 return false; 799 800 SmallVector<Type, 4> valueTypes; 801 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 802 valueTypes.append(op->result_type_begin(), op->result_type_end()); 803 804 // Special treatment for global variables, whose type requirements are 805 // conveyed by type attributes. 806 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 807 valueTypes.push_back(globalVar.type()); 808 809 // Make sure the op's operands/results use types that are allowed by the 810 // target environment. 
811 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 812 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 813 for (Type valueType : valueTypes) { 814 typeExtensions.clear(); 815 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 816 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 817 typeExtensions))) 818 return false; 819 820 typeCapabilities.clear(); 821 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 822 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 823 typeCapabilities))) 824 return false; 825 } 826 827 return true; 828 } 829