1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities used to lower to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/Debug.h" 20 21 #include <functional> 22 23 #define DEBUG_TYPE "mlir-spirv-conversion" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Utility functions 29 //===----------------------------------------------------------------------===// 30 31 /// Checks that `candidates` extension requirements are possible to be satisfied 32 /// with the given `targetEnv`. 33 /// 34 /// `candidates` is a vector of vector for extension requirements following 35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 36 /// convention. 37 template <typename LabelT> 38 static LogicalResult checkExtensionRequirements( 39 LabelT label, const spirv::TargetEnv &targetEnv, 40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 41 for (const auto &ors : candidates) { 42 if (targetEnv.allows(ors)) 43 continue; 44 45 SmallVector<StringRef, 4> extStrings; 46 for (spirv::Extension ext : ors) 47 extStrings.push_back(spirv::stringifyExtension(ext)); 48 49 LLVM_DEBUG(llvm::dbgs() 50 << label << " illegal: requires at least one extension in [" 51 << llvm::join(extStrings, ", ") 52 << "] but none allowed in target environment\n"); 53 return failure(); 54 } 55 return success(); 56 } 57 58 /// Checks that `candidates`capability requirements are possible to be satisfied 59 /// with the given `isAllowedFn`. 60 /// 61 /// `candidates` is a vector of vector for capability requirements following 62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 63 /// convention. 64 template <typename LabelT> 65 static LogicalResult checkCapabilityRequirements( 66 LabelT label, const spirv::TargetEnv &targetEnv, 67 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 68 for (const auto &ors : candidates) { 69 if (targetEnv.allows(ors)) 70 continue; 71 72 SmallVector<StringRef, 4> capStrings; 73 for (spirv::Capability cap : ors) 74 capStrings.push_back(spirv::stringifyCapability(cap)); 75 76 LLVM_DEBUG(llvm::dbgs() 77 << label << " illegal: requires at least one capability in [" 78 << llvm::join(capStrings, ", ") 79 << "] but none allowed in target environment\n"); 80 return failure(); 81 } 82 return success(); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Type Conversion 87 //===----------------------------------------------------------------------===// 88 89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { 90 // Convert to 32-bit integers for now. Might need a way to control this in 91 // future. 92 // TODO: It is probably better to make it 64-bit integers. To 93 // this some support is needed in SPIR-V dialect for Conversion 94 // instructions. The Vulkan spec requires the builtins like 95 // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be 96 // SExtended to 64-bit for index computations. 97 return IntegerType::get(context, 32); 98 } 99 100 /// Mapping between SPIR-V storage classes to memref memory spaces. 101 /// 102 /// Note: memref does not have a defined semantics for each memory space; it 103 /// depends on the context where it is used. There are no particular reasons 104 /// behind the number assignments; we try to follow NVVM conventions and largely 105 /// give common storage classes a smaller number. The hope is use symbolic 106 /// memory space representation eventually after memref supports it. 107 // TODO: swap Generic and StorageBuffer assignment to be more akin 108 // to NVVM. 109 #define STORAGE_SPACE_MAP_LIST(MAP_FN) \ 110 MAP_FN(spirv::StorageClass::Generic, 1) \ 111 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ 112 MAP_FN(spirv::StorageClass::Workgroup, 3) \ 113 MAP_FN(spirv::StorageClass::Uniform, 4) \ 114 MAP_FN(spirv::StorageClass::Private, 5) \ 115 MAP_FN(spirv::StorageClass::Function, 6) \ 116 MAP_FN(spirv::StorageClass::PushConstant, 7) \ 117 MAP_FN(spirv::StorageClass::UniformConstant, 8) \ 118 MAP_FN(spirv::StorageClass::Input, 9) \ 119 MAP_FN(spirv::StorageClass::Output, 10) \ 120 MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ 121 MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ 122 MAP_FN(spirv::StorageClass::Image, 13) \ 123 MAP_FN(spirv::StorageClass::CallableDataNV, 14) \ 124 MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \ 125 MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \ 126 MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \ 127 MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \ 128 MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \ 129 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) 130 131 unsigned 132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { 133 #define STORAGE_SPACE_MAP_FN(storage, space) \ 134 case storage: \ 135 return space; 136 137 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } 138 #undef STORAGE_SPACE_MAP_FN 139 llvm_unreachable("unhandled storage class!"); 140 } 141 142 Optional<spirv::StorageClass> 143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { 144 #define STORAGE_SPACE_MAP_FN(storage, space) \ 145 case space: \ 146 return storage; 147 148 switch (space) { 149 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) 150 default: 151 return llvm::None; 152 } 153 #undef STORAGE_SPACE_MAP_FN 154 } 155 156 #undef STORAGE_SPACE_MAP_LIST 157 158 // TODO: This is a utility function that should probably be 159 // exposed by the SPIR-V dialect. Keeping it local till the use case arises. 160 static Optional<int64_t> getTypeNumBytes(Type t) { 161 if (t.isa<spirv::ScalarType>()) { 162 auto bitWidth = t.getIntOrFloatBitWidth(); 163 // According to the SPIR-V spec: 164 // "There is no physical size or bit pattern defined for values with boolean 165 // type. If they are stored (in conjunction with OpVariable), they can only 166 // be used with logical addressing operations, not physical, and only with 167 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 168 // Private, Function, Input, and Output." 169 if (bitWidth == 1) { 170 return llvm::None; 171 } 172 return bitWidth / 8; 173 } 174 if (auto vecType = t.dyn_cast<VectorType>()) { 175 auto elementSize = getTypeNumBytes(vecType.getElementType()); 176 if (!elementSize) 177 return llvm::None; 178 return vecType.getNumElements() * *elementSize; 179 } 180 if (auto memRefType = t.dyn_cast<MemRefType>()) { 181 // TODO: Layout should also be controlled by the ABI attributes. For now 182 // using the layout from MemRef. 183 int64_t offset; 184 SmallVector<int64_t, 4> strides; 185 if (!memRefType.hasStaticShape() || 186 failed(getStridesAndOffset(memRefType, strides, offset))) { 187 return llvm::None; 188 } 189 // To get the size of the memref object in memory, the total size is the 190 // max(stride * dimension-size) computed for all dimensions times the size 191 // of the element. 192 auto elementSize = getTypeNumBytes(memRefType.getElementType()); 193 if (!elementSize) { 194 return llvm::None; 195 } 196 if (memRefType.getRank() == 0) { 197 return elementSize; 198 } 199 auto dims = memRefType.getShape(); 200 if (llvm::is_contained(dims, ShapedType::kDynamicSize) || 201 offset == MemRefType::getDynamicStrideOrOffset() || 202 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) { 203 return llvm::None; 204 } 205 int64_t memrefSize = -1; 206 for (auto shape : enumerate(dims)) { 207 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 208 } 209 return (offset + memrefSize) * elementSize.getValue(); 210 } else if (auto tensorType = t.dyn_cast<TensorType>()) { 211 if (!tensorType.hasStaticShape()) { 212 return llvm::None; 213 } 214 auto elementSize = getTypeNumBytes(tensorType.getElementType()); 215 if (!elementSize) { 216 return llvm::None; 217 } 218 int64_t size = elementSize.getValue(); 219 for (auto shape : tensorType.getShape()) { 220 size *= shape; 221 } 222 return size; 223 } 224 // TODO: Add size computation for other types. 225 return llvm::None; 226 } 227 228 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) { 229 return getTypeNumBytes(t); 230 } 231 232 /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 233 static Optional<Type> 234 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type, 235 Optional<spirv::StorageClass> storageClass = {}) { 236 // Get extension and capability requirements for the given type. 237 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 238 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 239 type.getExtensions(extensions, storageClass); 240 type.getCapabilities(capabilities, storageClass); 241 242 // If all requirements are met, then we can accept this type as-is. 243 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 244 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 245 return type; 246 247 // Otherwise we need to adjust the type, which really means adjusting the 248 // bitwidth given this is a scalar type. 249 // TODO: We are unconditionally converting the bitwidth here, 250 // this might be okay for non-interface types (i.e., types used in 251 // Private/Function storage classes), but not for interface types (i.e., 252 // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes). 253 // This is because the later actually affects the ABI contract with the 254 // runtime. So we may want to expose a control on SPIRVTypeConverter to fail 255 // conversion if we cannot change there. 256 257 if (auto floatType = type.dyn_cast<FloatType>()) { 258 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 259 return Builder(targetEnv.getContext()).getF32Type(); 260 } 261 262 auto intType = type.cast<IntegerType>(); 263 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 264 return IntegerType::get(targetEnv.getContext(), /*width=*/32, 265 intType.getSignedness()); 266 } 267 268 /// Converts a vector `type` to a suitable type under the given `targetEnv`. 269 static Optional<Type> 270 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type, 271 Optional<spirv::StorageClass> storageClass = {}) { 272 if (!spirv::CompositeType::isValid(type)) { 273 // TODO: One-element vector types can be translated into scalar 274 // types. Vector types with more than four elements can be translated into 275 // array types. 276 LLVM_DEBUG(llvm::dbgs() 277 << type << " illegal: 1- and > 4-element unimplemented\n"); 278 return llvm::None; 279 } 280 281 // Get extension and capability requirements for the given type. 282 SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 283 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 284 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); 285 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); 286 287 // If all requirements are met, then we can accept this type as-is. 288 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 289 succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 290 return type; 291 292 auto elementType = convertScalarType( 293 targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass); 294 if (elementType) 295 return VectorType::get(type.getShape(), *elementType); 296 return llvm::None; 297 } 298 299 /// Converts a tensor `type` to a suitable type under the given `targetEnv`. 300 /// 301 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can 302 /// create composite constants with OpConstantComposite to embed relative large 303 /// constant values and use OpCompositeExtract and OpCompositeInsert to 304 /// manipulate, like what we do for vectors. 305 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv, 306 TensorType type) { 307 // TODO: Handle dynamic shapes. 308 if (!type.hasStaticShape()) { 309 LLVM_DEBUG(llvm::dbgs() 310 << type << " illegal: dynamic shape unimplemented\n"); 311 return llvm::None; 312 } 313 314 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>(); 315 if (!scalarType) { 316 LLVM_DEBUG(llvm::dbgs() 317 << type << " illegal: cannot convert non-scalar element type\n"); 318 return llvm::None; 319 } 320 321 Optional<int64_t> scalarSize = getTypeNumBytes(scalarType); 322 Optional<int64_t> tensorSize = getTypeNumBytes(type); 323 if (!scalarSize || !tensorSize) { 324 LLVM_DEBUG(llvm::dbgs() 325 << type << " illegal: cannot deduce element count\n"); 326 return llvm::None; 327 } 328 329 auto arrayElemCount = *tensorSize / *scalarSize; 330 auto arrayElemType = convertScalarType(targetEnv, scalarType); 331 if (!arrayElemType) 332 return llvm::None; 333 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType); 334 if (!arrayElemSize) { 335 LLVM_DEBUG(llvm::dbgs() 336 << type << " illegal: cannot deduce converted element size\n"); 337 return llvm::None; 338 } 339 340 return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); 341 } 342 343 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv, 344 MemRefType type) { 345 Optional<spirv::StorageClass> storageClass = 346 SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); 347 if (!storageClass) { 348 LLVM_DEBUG(llvm::dbgs() 349 << type << " illegal: cannot convert memory space\n"); 350 return llvm::None; 351 } 352 353 Optional<Type> arrayElemType; 354 Type elementType = type.getElementType(); 355 if (auto vecType = elementType.dyn_cast<VectorType>()) { 356 arrayElemType = convertVectorType(targetEnv, vecType, storageClass); 357 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { 358 arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); 359 } else { 360 LLVM_DEBUG( 361 llvm::dbgs() 362 << type 363 << " unhandled: can only convert scalar or vector element type\n"); 364 return llvm::None; 365 } 366 if (!arrayElemType) 367 return llvm::None; 368 369 Optional<int64_t> elementSize = getTypeNumBytes(elementType); 370 if (!elementSize) { 371 LLVM_DEBUG(llvm::dbgs() 372 << type << " illegal: cannot deduce element size\n"); 373 return llvm::None; 374 } 375 376 if (!type.hasStaticShape()) { 377 auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize); 378 // Wrap in a struct to satisfy Vulkan interface requirements. 379 auto structType = spirv::StructType::get(arrayType, 0); 380 return spirv::PointerType::get(structType, *storageClass); 381 } 382 383 Optional<int64_t> memrefSize = getTypeNumBytes(type); 384 if (!memrefSize) { 385 LLVM_DEBUG(llvm::dbgs() 386 << type << " illegal: cannot deduce element count\n"); 387 return llvm::None; 388 } 389 390 auto arrayElemCount = *memrefSize / *elementSize; 391 392 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType); 393 if (!arrayElemSize) { 394 LLVM_DEBUG(llvm::dbgs() 395 << type << " illegal: cannot deduce converted element size\n"); 396 return llvm::None; 397 } 398 399 auto arrayType = 400 spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); 401 402 // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with 403 // workgroup storage class do not need the struct to be laid out explicitly. 404 auto structType = *storageClass == spirv::StorageClass::Workgroup 405 ? spirv::StructType::get(arrayType) 406 : spirv::StructType::get(arrayType, 0); 407 return spirv::PointerType::get(structType, *storageClass); 408 } 409 410 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) 411 : targetEnv(targetAttr) { 412 // Add conversions. The order matters here: later ones will be tried earlier. 413 414 // All other cases failed. Then we cannot convert this type. 415 addConversion([](Type type) { return llvm::None; }); 416 417 // Allow all SPIR-V dialect specific types. This assumes all builtin types 418 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 419 // were tried before. 420 // 421 // TODO: this assumes that the SPIR-V types are valid to use in 422 // the given target environment, which should be the case if the whole 423 // pipeline is driven by the same target environment. Still, we probably still 424 // want to validate and convert to be safe. 425 addConversion([](spirv::SPIRVType type) { return type; }); 426 427 addConversion([](IndexType indexType) { 428 return SPIRVTypeConverter::getIndexType(indexType.getContext()); 429 }); 430 431 addConversion([this](IntegerType intType) -> Optional<Type> { 432 if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) 433 return convertScalarType(targetEnv, scalarType); 434 return llvm::None; 435 }); 436 437 addConversion([this](FloatType floatType) -> Optional<Type> { 438 if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) 439 return convertScalarType(targetEnv, scalarType); 440 return llvm::None; 441 }); 442 443 addConversion([this](VectorType vectorType) { 444 return convertVectorType(targetEnv, vectorType); 445 }); 446 447 addConversion([this](TensorType tensorType) { 448 return convertTensorType(targetEnv, tensorType); 449 }); 450 451 addConversion([this](MemRefType memRefType) { 452 return convertMemrefType(targetEnv, memRefType); 453 }); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // FuncOp Conversion Patterns 458 //===----------------------------------------------------------------------===// 459 460 namespace { 461 /// A pattern for rewriting function signature to convert arguments of functions 462 /// to be of valid SPIR-V types. 463 class FuncOpConversion final : public OpConversionPattern<FuncOp> { 464 public: 465 using OpConversionPattern<FuncOp>::OpConversionPattern; 466 467 LogicalResult 468 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 469 ConversionPatternRewriter &rewriter) const override; 470 }; 471 } // namespace 472 473 LogicalResult 474 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands, 475 ConversionPatternRewriter &rewriter) const { 476 auto fnType = funcOp.getType(); 477 if (fnType.getNumResults() > 1) 478 return failure(); 479 480 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 481 for (auto argType : enumerate(fnType.getInputs())) { 482 auto convertedType = getTypeConverter()->convertType(argType.value()); 483 if (!convertedType) 484 return failure(); 485 signatureConverter.addInputs(argType.index(), convertedType); 486 } 487 488 Type resultType; 489 if (fnType.getNumResults() == 1) 490 resultType = getTypeConverter()->convertType(fnType.getResult(0)); 491 492 // Create the converted spv.func op. 493 auto newFuncOp = rewriter.create<spirv::FuncOp>( 494 funcOp.getLoc(), funcOp.getName(), 495 rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 496 resultType ? TypeRange(resultType) 497 : TypeRange())); 498 499 // Copy over all attributes other than the function name and type. 500 for (const auto &namedAttr : funcOp.getAttrs()) { 501 if (namedAttr.first != impl::getTypeAttrName() && 502 namedAttr.first != SymbolTable::getSymbolAttrName()) 503 newFuncOp->setAttr(namedAttr.first, namedAttr.second); 504 } 505 506 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 507 newFuncOp.end()); 508 if (failed(rewriter.convertRegionTypes( 509 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 510 return failure(); 511 rewriter.eraseOp(funcOp); 512 return success(); 513 } 514 515 void mlir::populateBuiltinFuncToSPIRVPatterns( 516 MLIRContext *context, SPIRVTypeConverter &typeConverter, 517 OwningRewritePatternList &patterns) { 518 patterns.insert<FuncOpConversion>(typeConverter, context); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // Builtin Variables 523 //===----------------------------------------------------------------------===// 524 525 static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 526 spirv::BuiltIn builtin) { 527 // Look through all global variables in the given `body` block and check if 528 // there is a spv.globalVariable that has the same `builtin` attribute. 529 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 530 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 531 spirv::SPIRVDialect::getAttributeName( 532 spirv::Decoration::BuiltIn))) { 533 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 534 if (varBuiltIn && varBuiltIn.getValue() == builtin) { 535 return varOp; 536 } 537 } 538 } 539 return nullptr; 540 } 541 542 /// Gets name of global variable for a builtin. 543 static std::string getBuiltinVarName(spirv::BuiltIn builtin) { 544 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; 545 } 546 547 /// Gets or inserts a global variable for a builtin within `body` block. 548 static spirv::GlobalVariableOp 549 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 550 OpBuilder &builder) { 551 if (auto varOp = getBuiltinVariable(body, builtin)) 552 return varOp; 553 554 OpBuilder::InsertionGuard guard(builder); 555 builder.setInsertionPointToStart(&body); 556 557 spirv::GlobalVariableOp newVarOp; 558 switch (builtin) { 559 case spirv::BuiltIn::NumWorkgroups: 560 case spirv::BuiltIn::WorkgroupSize: 561 case spirv::BuiltIn::WorkgroupId: 562 case spirv::BuiltIn::LocalInvocationId: 563 case spirv::BuiltIn::GlobalInvocationId: { 564 auto ptrType = spirv::PointerType::get( 565 VectorType::get({3}, builder.getIntegerType(32)), 566 spirv::StorageClass::Input); 567 std::string name = getBuiltinVarName(builtin); 568 newVarOp = 569 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 570 break; 571 } 572 case spirv::BuiltIn::SubgroupId: 573 case spirv::BuiltIn::NumSubgroups: 574 case spirv::BuiltIn::SubgroupSize: { 575 auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), 576 spirv::StorageClass::Input); 577 std::string name = getBuiltinVarName(builtin); 578 newVarOp = 579 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 580 break; 581 } 582 default: 583 emitError(loc, "unimplemented builtin variable generation for ") 584 << stringifyBuiltIn(builtin); 585 } 586 return newVarOp; 587 } 588 589 Value mlir::spirv::getBuiltinVariableValue(Operation *op, 590 spirv::BuiltIn builtin, 591 OpBuilder &builder) { 592 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 593 if (!parent) { 594 op->emitError("expected operation to be within a module-like op"); 595 return nullptr; 596 } 597 598 spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( 599 *parent->getRegion(0).begin(), op->getLoc(), builtin, builder); 600 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 601 return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 602 } 603 604 //===----------------------------------------------------------------------===// 605 // Index calculation 606 //===----------------------------------------------------------------------===// 607 608 spirv::AccessChainOp mlir::spirv::getElementPtr( 609 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, 610 ValueRange indices, Location loc, OpBuilder &builder) { 611 // Get base and offset of the MemRefType and verify they are static. 612 613 int64_t offset; 614 SmallVector<int64_t, 4> strides; 615 if (failed(getStridesAndOffset(baseType, strides, offset)) || 616 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || 617 offset == MemRefType::getDynamicStrideOrOffset()) { 618 return nullptr; 619 } 620 621 auto indexType = typeConverter.getIndexType(builder.getContext()); 622 623 SmallVector<Value, 2> linearizedIndices; 624 // Add a '0' at the start to index into the struct. 625 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 626 linearizedIndices.push_back(zero); 627 628 if (baseType.getRank() == 0) { 629 linearizedIndices.push_back(zero); 630 } else { 631 // TODO: Instead of this logic, use affine.apply and add patterns for 632 // lowering affine.apply to standard ops. These will get lowered to SPIR-V 633 // ops by the DialectConversion framework. 634 Value ptrLoc = builder.create<spirv::ConstantOp>( 635 loc, indexType, IntegerAttr::get(indexType, offset)); 636 assert(indices.size() == strides.size() && 637 "must provide indices for all dimensions"); 638 for (auto index : llvm::enumerate(indices)) { 639 Value strideVal = builder.create<spirv::ConstantOp>( 640 loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); 641 Value update = 642 builder.create<spirv::IMulOp>(loc, strideVal, index.value()); 643 ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update); 644 } 645 linearizedIndices.push_back(ptrLoc); 646 } 647 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 648 } 649 650 //===----------------------------------------------------------------------===// 651 // Set ABI attributes for lowering entry functions. 652 //===----------------------------------------------------------------------===// 653 654 LogicalResult 655 mlir::spirv::setABIAttrs(spirv::FuncOp funcOp, 656 spirv::EntryPointABIAttr entryPointInfo, 657 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) { 658 // Set the attributes for argument and the function. 659 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); 660 for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) { 661 funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]); 662 } 663 funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); 664 return success(); 665 } 666 667 //===----------------------------------------------------------------------===// 668 // SPIR-V ConversionTarget 669 //===----------------------------------------------------------------------===// 670 671 std::unique_ptr<spirv::SPIRVConversionTarget> 672 spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 673 std::unique_ptr<SPIRVConversionTarget> target( 674 // std::make_unique does not work here because the constructor is private. 675 new SPIRVConversionTarget(targetAttr)); 676 SPIRVConversionTarget *targetPtr = target.get(); 677 target->addDynamicallyLegalDialect<SPIRVDialect>( 678 // We need to capture the raw pointer here because it is stable: 679 // target will be destroyed once this function is returned. 680 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 681 return target; 682 } 683 684 spirv::SPIRVConversionTarget::SPIRVConversionTarget( 685 spirv::TargetEnvAttr targetAttr) 686 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 687 688 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { 689 // Make sure this op is available at the given version. Ops not implementing 690 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 691 // SPIR-V versions. 692 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 693 if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { 694 LLVM_DEBUG(llvm::dbgs() 695 << op->getName() << " illegal: requiring min version " 696 << spirv::stringifyVersion(minVersion.getMinVersion()) 697 << "\n"); 698 return false; 699 } 700 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 701 if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { 702 LLVM_DEBUG(llvm::dbgs() 703 << op->getName() << " illegal: requiring max version " 704 << spirv::stringifyVersion(maxVersion.getMaxVersion()) 705 << "\n"); 706 return false; 707 } 708 709 // Make sure this op's required extensions are allowed to use. Ops not 710 // implementing QueryExtensionInterface do not require extensions to be 711 // available. 712 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 713 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 714 extensions.getExtensions()))) 715 return false; 716 717 // Make sure this op's required extensions are allowed to use. Ops not 718 // implementing QueryCapabilityInterface do not require capabilities to be 719 // available. 720 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 721 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 722 capabilities.getCapabilities()))) 723 return false; 724 725 SmallVector<Type, 4> valueTypes; 726 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 727 valueTypes.append(op->result_type_begin(), op->result_type_end()); 728 729 // Special treatment for global variables, whose type requirements are 730 // conveyed by type attributes. 731 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 732 valueTypes.push_back(globalVar.type()); 733 734 // Make sure the op's operands/results use types that are allowed by the 735 // target environment. 736 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 737 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 738 for (Type valueType : valueTypes) { 739 typeExtensions.clear(); 740 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 741 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 742 typeExtensions))) 743 return false; 744 745 typeCapabilities.clear(); 746 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 747 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 748 typeCapabilities))) 749 return false; 750 } 751 752 return true; 753 } 754