//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

#include <functional>

#define DEBUG_TYPE "mlir-spirv-conversion"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Checks that `candidates` extension requirements can be satisfied with the
/// given `targetEnv`.
///
/// `candidates` is a vector of vectors of extension requirements, following
/// the ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    SmallVector<StringRef, 4> extStrings;
    for (spirv::Extension ext : ors)
      extStrings.push_back(spirv::stringifyExtension(ext));

    LLVM_DEBUG(llvm::dbgs()
               << label << " illegal: requires at least one extension in ["
               << llvm::join(extStrings, ", ")
               << "] but none allowed in target environment\n");
    return failure();
  }
  return success();
}

/// Checks that `candidates` capability requirements can be satisfied with the
/// given `targetEnv`.
///
/// `candidates` is a vector of vectors of capability requirements, following
/// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    SmallVector<StringRef, 4> capStrings;
    for (spirv::Capability cap : ors)
      capStrings.push_back(spirv::stringifyCapability(cap));

    LLVM_DEBUG(llvm::dbgs()
               << label << " illegal: requires at least one capability in ["
               << llvm::join(capStrings, ", ")
               << "] but none allowed in target environment\n");
    return failure();
  }
  return success();
}
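
// As an illustrative example of the convention above (not a combination used
// verbatim in this file): a requirement of
//   (SPV_KHR_8bit_storage OR SPV_KHR_16bit_storage)
//     AND SPV_KHR_storage_buffer_storage_class
// would be passed as a `candidates` vector holding two inner lists,
//   {{Extension::SPV_KHR_8bit_storage, Extension::SPV_KHR_16bit_storage},
//    {Extension::SPV_KHR_storage_buffer_storage_class}},
// and the check succeeds only if the target environment allows at least one
// extension from each inner list.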

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//

Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
  // Convert to 32-bit integers for now. Might need a way to control this in
  // future.
  // TODO: It is probably better to make it 64-bit integers. To do this, some
  // support is needed in SPIR-V dialect for Conversion instructions. The
  // Vulkan spec requires the builtins like GlobalInvocationID, etc. to be
  // 32-bit (unsigned) integers which should be SExtended to 64-bit for index
  // computations.
  return IntegerType::get(context, 32);
}

/// Mapping from SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and
/// largely give common storage classes a smaller number. The hope is to use a
/// symbolic memory space representation eventually, once memref supports it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)

unsigned
SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case storage:                                                                \
    return space;

  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
#undef STORAGE_SPACE_MAP_FN
  llvm_unreachable("unhandled storage class!");
}

Optional<spirv::StorageClass>
SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case space:                                                                  \
    return storage;

  switch (space) {
    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
  default:
    return llvm::None;
  }
#undef STORAGE_SPACE_MAP_FN
}

#undef STORAGE_SPACE_MAP_LIST

// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t> getTypeNumBytes(Type t) {
  if (t.isa<spirv::ScalarType>()) {
    auto bitWidth = t.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with
    // boolean type. If they are stored (in conjunction with OpVariable), they
    // can only be used with logical addressing operations, not physical, and
    // only with non-externally visible shader Storage Classes: Workgroup,
    // CrossWorkgroup, Private, Function, Input, and Output."
    if (bitWidth == 1) {
      return llvm::None;
    }
    return bitWidth / 8;
  }
  if (auto vecType = t.dyn_cast<VectorType>()) {
    auto elementSize = getTypeNumBytes(vecType.getElementType());
    if (!elementSize)
      return llvm::None;
    return vecType.getNumElements() * *elementSize;
  }
  if (auto memRefType = t.dyn_cast<MemRefType>()) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(getStridesAndOffset(memRefType, strides, offset))) {
      return llvm::None;
    }
    // To get the size of the memref object in memory, the total size is the
    // max(stride * dimension-size) computed for all dimensions times the size
    // of the element.
    auto elementSize = getTypeNumBytes(memRefType.getElementType());
    if (!elementSize) {
      return llvm::None;
    }
    if (memRefType.getRank() == 0) {
      return elementSize;
    }
    auto dims = memRefType.getShape();
    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
        offset == MemRefType::getDynamicStrideOrOffset() ||
        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
      return llvm::None;
    }
    int64_t memrefSize = -1;
    for (auto shape : enumerate(dims)) {
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
    }
    return (offset + memrefSize) * elementSize.getValue();
  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
    if (!tensorType.hasStaticShape()) {
      return llvm::None;
    }
    auto elementSize = getTypeNumBytes(tensorType.getElementType());
    if (!elementSize) {
      return llvm::None;
    }
    int64_t size = elementSize.getValue();
    for (auto shape : tensorType.getShape()) {
      size *= shape;
    }
    return size;
  }
  // TODO: Add size computation for other types.
  return llvm::None;
}
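
// As a worked example of the memref case above (illustrative only): for
// memref<2x3xf32> with the default identity layout, the strides are [3, 1]
// and the offset is 0, so memrefSize = max(2 * 3, 3 * 1) = 6 and the result
// is (0 + 6) * 4 = 24 bytes. Boolean and dynamically shaped types report
// llvm::None, as handled above.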

Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
  return getTypeNumBytes(t);
}

/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  type.getExtensions(extensions, storageClass);
  type.getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise we need to adjust the type, which really means adjusting the
  // bitwidth given this is a scalar type.
  // TODO: We are unconditionally converting the bitwidth here. This might be
  // okay for non-interface types (i.e., types used in Private/Function
  // storage classes), but not for interface types (i.e., types used in
  // StorageBuffer/Uniform/PushConstant/etc. storage classes), because the
  // latter actually affects the ABI contract with the runtime. So we may want
  // to expose a control on SPIRVTypeConverter to fail conversion if we cannot
  // change there.

  if (auto floatType = type.dyn_cast<FloatType>()) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(targetEnv.getContext()).getF32Type();
  }

  auto intType = type.cast<IntegerType>();
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
                          intType.getSignedness());
}
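
// For example (illustrative only): if the target environment does not allow
// the Int64 capability, an i64 scalar is converted to i32 with the same
// signedness; likewise f64 becomes f32 when Float64 is unavailable. If all
// requirements are met, the type is returned unchanged.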

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  if (!spirv::CompositeType::isValid(type)) {
    // TODO: One-element vector types can be translated into scalar types.
    // Vector types with more than four elements can be translated into array
    // types.
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: 1- and > 4-element unimplemented\n");
    return llvm::None;
  }

  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
  type.cast<spirv::CompositeType>().getCapabilities(capabilities,
                                                    storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  auto elementType = convertScalarType(
      targetEnv, type.getElementType().cast<spirv::ScalarType>(),
      storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), *elementType);
  return llvm::None;
}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively
/// large constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate them, like what we do for vectors.
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
                                        TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return llvm::None;
  }

  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return llvm::None;
  }

  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
  Optional<int64_t> tensorSize = getTypeNumBytes(type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }

  auto arrayElemCount = *tensorSize / *scalarSize;
  auto arrayElemType = convertScalarType(targetEnv, scalarType);
  if (!arrayElemType)
    return llvm::None;
  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }

  return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
}
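
// For example (illustrative only): assuming its requirements are met,
// tensor<2x3xf32> has a total size of 24 bytes and a 4-byte element, so it is
// converted to a SPIR-V array of 6 f32 elements with a 4-byte stride. For
// vectors, only the element type is adjusted, e.g. vector<4xi64> may become
// vector<4xi32> when Int64 is unavailable.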

static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
                                        MemRefType type) {
  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
  if (!storageClass) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert memory space\n");
    return llvm::None;
  }

  Optional<Type> arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>()) {
    arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
    arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return llvm::None;
  }
  if (!arrayElemType)
    return llvm::None;

  Optional<int64_t> elementSize = getTypeNumBytes(elementType);
  if (!elementSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element size\n");
    return llvm::None;
  }

  if (!type.hasStaticShape()) {
    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
    // Wrap in a struct to satisfy Vulkan interface requirements.
    auto structType = spirv::StructType::get(arrayType, 0);
    return spirv::PointerType::get(structType, *storageClass);
  }

  Optional<int64_t> memrefSize = getTypeNumBytes(type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }

  auto arrayElemCount = *memrefSize / *elementSize;

  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }

  auto arrayType =
      spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);

  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
  // workgroup storage class do not need the struct to be laid out explicitly.
  auto structType = *storageClass == spirv::StorageClass::Workgroup
                        ? spirv::StructType::get(arrayType)
                        : spirv::StructType::get(arrayType, 0);
  return spirv::PointerType::get(structType, *storageClass);
}
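
// For example (illustrative only): memref<4xf32> in the default memory space
// (0, which maps to StorageBuffer above) becomes, roughly,
//   !spv.ptr<!spv.struct<!spv.array<4 x f32, stride=4> [0]>, StorageBuffer>
// while a dynamically shaped memref<?xf32> wraps an !spv.rtarray instead of
// the fixed-size array.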

SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // All other cases failed. Then we cannot convert this type.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}
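
// A minimal usage sketch (illustrative only; the names below are not defined
// in this file): given a target environment attribute, e.g. looked up from
// the surrounding module, a pass would typically construct the converter and
// query it:
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   Type converted = typeConverter.convertType(someMemRefType);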

//===----------------------------------------------------------------------===//
// FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//

namespace {
/// A pattern for rewriting function signatures to convert the arguments of
/// functions to valid SPIR-V types.
class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  auto fnType = funcOp.getType();
  if (fnType.getNumResults() > 1)
    return failure();

  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  for (auto argType : enumerate(fnType.getInputs())) {
    auto convertedType = typeConverter.convertType(argType.value());
    if (!convertedType)
      return failure();
    signatureConverter.addInputs(argType.index(), convertedType);
  }

  Type resultType;
  if (fnType.getNumResults() == 1)
    resultType = typeConverter.convertType(fnType.getResult(0));

  // Create the converted spv.func op.
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               resultType ? TypeRange(resultType)
                                          : TypeRange()));

  // Copy over all attributes other than the function name and type.
  for (const auto &namedAttr : funcOp.getAttrs()) {
    if (namedAttr.first != impl::getTypeAttrName() &&
        namedAttr.first != SymbolTable::getSymbolAttrName())
      newFuncOp->setAttr(namedAttr.first, namedAttr.second);
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return failure();
  rewriter.eraseOp(funcOp);
  return success();
}

void mlir::populateBuiltinFuncToSPIRVPatterns(
    MLIRContext *context, SPIRVTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<FuncOpConversion>(context, typeConverter);
}
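
// For example (illustrative only), under a typical Vulkan target environment
// a builtin func op such as
//   func @foo(%arg0: memref<4xf32>, %arg1: i64)
// is rewritten into an spv.func whose arguments use the converted types
// (roughly, a pointer-to-struct for the memref and i32 for the i64 when Int64
// is unavailable), with the entry block arguments remapped accordingly.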

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//

static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                  spirv::BuiltIn builtin) {
  // Look through all global variables in the given `body` block and check if
  // there is a spv.globalVariable that has the same `builtin` attribute.
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
            spirv::SPIRVDialect::getAttributeName(
                spirv::Decoration::BuiltIn))) {
      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
      if (varBuiltIn && varBuiltIn.getValue() == builtin) {
        return varOp;
      }
    }
  }
  return nullptr;
}

/// Gets name of global variable for a builtin.
static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
}

/// Gets or inserts a global variable for a builtin within `body` block.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           OpBuilder &builder) {
  if (auto varOp = getBuiltinVariable(body, builtin))
    return varOp;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&body);

  spirv::GlobalVariableOp newVarOp;
  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId: {
    auto ptrType = spirv::PointerType::get(
        VectorType::get({3}, builder.getIntegerType(32)),
        spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  case spirv::BuiltIn::SubgroupId:
  case spirv::BuiltIn::NumSubgroups:
  case spirv::BuiltIn::SubgroupSize: {
    auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
                                           spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  default:
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
  }
  return newVarOp;
}

Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           OpBuilder &builder) {
  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}
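
// For example, requesting spirv::BuiltIn::GlobalInvocationId creates (at most
// once per symbol table) a global variable named
// "__builtin_var_GlobalInvocationId__" of, roughly, type
// !spv.ptr<vector<3xi32>, Input> decorated with that builtin, and returns the
// value produced by a spirv::AddressOfOp + spirv::LoadOp pair at the
// requesting op's location.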

//===----------------------------------------------------------------------===//
// Index calculation
//===----------------------------------------------------------------------===//

spirv::AccessChainOp mlir::spirv::getElementPtr(
    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
    ValueRange indices, Location loc, OpBuilder &builder) {
  // Get base and offset of the MemRefType and verify they are static.

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
      offset == MemRefType::getDynamicStrideOrOffset()) {
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType(builder.getContext());

  SmallVector<Value, 2> linearizedIndices;
  // Add a '0' at the start to index into the struct.
  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
  linearizedIndices.push_back(zero);

  if (baseType.getRank() == 0) {
    linearizedIndices.push_back(zero);
  } else {
    // TODO: Instead of this logic, use affine.apply and add patterns for
    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
    // ops by the DialectConversion framework.
    Value ptrLoc = builder.create<spirv::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, offset));
    assert(indices.size() == strides.size() &&
           "must provide indices for all dimensions");
    for (auto index : llvm::enumerate(indices)) {
      Value strideVal = builder.create<spirv::ConstantOp>(
          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
      Value update =
          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
    }
    linearizedIndices.push_back(ptrLoc);
  }
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
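
// For example, for a memref<2x3xf32> base with the default identity layout
// (strides [3, 1], offset 0) and indices (%i, %j), the emitted access chain
// uses the index list [0, 0 + %i * 3 + %j * 1]: the leading 0 steps into the
// struct wrapping the array, and the second value is the linearized element
// index built with spv.IMul/spv.IAdd on 32-bit index values.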

//===----------------------------------------------------------------------===//
// Set ABI attributes for lowering entry functions.
//===----------------------------------------------------------------------===//

LogicalResult
mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
                         spirv::EntryPointABIAttr entryPointInfo,
                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  // Set the attributes for argument and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
  return success();
}

//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//

std::unique_ptr<spirv::SPIRVConversionTarget>
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
  std::unique_ptr<SPIRVConversionTarget> target(
      // std::make_unique does not work here because the constructor is
      // private.
      new SPIRVConversionTarget(targetAttr));
  SPIRVConversionTarget *targetPtr = target.get();
  target->addDynamicallyLegalDialect<SPIRVDialect>(
      // We need to capture the raw pointer here because it is stable:
      // target will be destroyed once this function is returned.
      [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
  return target;
}

spirv::SPIRVConversionTarget::SPIRVConversionTarget(
    spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}

bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
    if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(minVersion.getMinVersion())
                 << "\n");
      return false;
    }
  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
    if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
                 << "\n");
      return false;
    }

  // Make sure this op's required extensions are allowed by the target
  // environment. Ops not implementing QueryExtensionInterface do not require
  // extensions to be available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed by the target
  // environment. Ops not implementing QueryCapabilityInterface do not require
  // capabilities to be available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes.
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.type());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}
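
// A rough end-to-end usage sketch (illustrative only; the actual drivers live
// in the individual conversion passes, not in this file):
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   OwningRewritePatternList patterns;
//   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
//   std::unique_ptr<ConversionTarget> target =
//       spirv::SPIRVConversionTarget::get(targetAttr);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     return signalPassFailure();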