1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 using namespace mlir; 29 30 static constexpr const char *kCInterfaceVulkanLaunch = 31 "_mlir_ciface_vulkanLaunch"; 32 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 33 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 34 static constexpr const char *kInitVulkan = "initVulkan"; 35 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 36 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 37 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 38 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 39 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 40 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 41 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 42 43 namespace { 44 45 /// A pass to convert vulkan launch call op into a sequence of Vulkan 46 /// runtime calls in the following order: 47 /// 48 /// * initVulkan -- initializes vulkan runtime 49 /// * bindMemRef -- binds memref 50 /// * setBinaryShader -- sets the binary shader data 51 /// * setEntryPoint -- sets the entry point name 52 /// * setNumWorkGroups -- sets the number of a local workgroups 53 /// * runOnVulkan -- runs vulkan runtime 54 /// * deinitVulkan -- deinitializes vulkan runtime 55 /// 56 class VulkanLaunchFuncToVulkanCallsPass 57 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 58 VulkanLaunchFuncToVulkanCallsPass> { 59 private: 60 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 61 62 void initializeCachedTypes() { 63 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 64 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 65 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 66 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 67 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 68 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 69 } 70 71 LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { 72 // According to the MLIR doc memref argument is converted into a 73 // pointer-to-struct argument of type: 74 // template <typename Elem, size_t Rank> 75 // struct { 76 // Elem *allocated; 77 // Elem *aligned; 78 // int64_t offset; 79 // int64_t sizes[Rank]; // omitted when rank == 0 80 // int64_t strides[Rank]; // omitted when rank == 0 81 // }; 82 auto llvmPtrToElementType = elemenType.getPointerTo(); 83 auto llvmArrayRankElementSizeType = 84 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 85 86 // Create a type 87 // `!llvm<"{ `element-type`*, `element-type`*, i64, 88 // [`rank` x i64], [`rank` x i64]}">`. 89 return LLVM::LLVMType::getStructTy( 90 llvmDialect, 91 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 92 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 93 } 94 95 LLVM::LLVMType getVoidType() { return llvmVoidType; } 96 LLVM::LLVMType getPointerType() { return llvmPointerType; } 97 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 98 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 99 100 /// Creates an LLVM global for the given `name`. 101 Value createEntryPointNameConstant(StringRef name, Location loc, 102 OpBuilder &builder); 103 104 /// Declares all needed runtime functions. 105 void declareVulkanFunctions(Location loc); 106 107 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 108 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 109 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 110 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 111 } 112 113 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 114 /// op. 115 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 116 return (callOp.callee() && 117 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 118 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 119 } 120 121 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 122 /// runtime calls. 123 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 124 125 /// Creates call to `bindMemRef` for each memref operand. 126 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 127 Value vulkanRuntime); 128 129 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 130 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 131 132 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 133 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 134 uint32_t &rank, LLVM::LLVMType &type); 135 136 /// Returns a string representation from the given `type`. 137 StringRef stringifyType(LLVM::LLVMType type) { 138 if (type.isFloatTy()) 139 return "Float"; 140 if (type.isHalfTy()) 141 return "Half"; 142 if (type.isIntegerTy(32)) 143 return "Int32"; 144 if (type.isIntegerTy(16)) 145 return "Int16"; 146 if (type.isIntegerTy(8)) 147 return "Int8"; 148 149 llvm_unreachable("unsupported type"); 150 } 151 152 public: 153 void runOnOperation() override; 154 155 private: 156 LLVM::LLVMDialect *llvmDialect; 157 LLVM::LLVMType llvmFloatType; 158 LLVM::LLVMType llvmVoidType; 159 LLVM::LLVMType llvmPointerType; 160 LLVM::LLVMType llvmInt32Type; 161 LLVM::LLVMType llvmInt64Type; 162 163 // TODO: Use an associative array to support multiple vulkan launch calls. 164 std::pair<StringAttr, StringAttr> spirvAttributes; 165 /// The number of vulkan launch configuration operands, placed at the leading 166 /// positions of the operand list. 167 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 168 }; 169 170 } // anonymous namespace 171 172 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 173 initializeCachedTypes(); 174 175 // Collect SPIR-V attributes such as `spirv_blob` and 176 // `spirv_entry_point_name`. 177 getOperation().walk([this](LLVM::CallOp op) { 178 if (isVulkanLaunchCallOp(op)) 179 collectSPIRVAttributes(op); 180 }); 181 182 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 183 getOperation().walk([this](LLVM::CallOp op) { 184 if (isCInterfaceVulkanLaunchCallOp(op)) 185 translateVulkanLaunchCall(op); 186 }); 187 } 188 189 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 190 LLVM::CallOp vulkanLaunchCallOp) { 191 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 192 // for the given vulkan launch call. 193 auto spirvBlobAttr = 194 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 195 if (!spirvBlobAttr) { 196 vulkanLaunchCallOp.emitError() 197 << "missing " << kSPIRVBlobAttrName << " attribute"; 198 return signalPassFailure(); 199 } 200 201 auto spirvEntryPointNameAttr = 202 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 203 if (!spirvEntryPointNameAttr) { 204 vulkanLaunchCallOp.emitError() 205 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 206 return signalPassFailure(); 207 } 208 209 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 210 } 211 212 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 213 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 214 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 215 kVulkanLaunchNumConfigOperands) 216 return; 217 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 218 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 219 220 // Create LLVM constant for the descriptor set index. 221 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 222 // pass does. 223 Value descriptorSet = builder.create<LLVM::ConstantOp>( 224 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 225 226 for (auto en : 227 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 228 kVulkanLaunchNumConfigOperands))) { 229 // Create LLVM constant for the descriptor binding index. 230 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 231 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 232 233 auto ptrToMemRefDescriptor = en.value(); 234 uint32_t rank = 0; 235 LLVM::LLVMType type; 236 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 237 cInterfaceVulkanLaunchCallOp.emitError() 238 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 239 return signalPassFailure(); 240 } 241 242 auto symbolName = 243 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 244 // Special case for fp16 type. Since it is not a supported type in C we use 245 // int16_t and bitcast the descriptor. 246 if (type.isHalfTy()) { 247 auto memRefTy = 248 getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect)); 249 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>( 250 loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor); 251 } 252 // Create call to `bindMemRef`. 253 builder.create<LLVM::CallOp>( 254 loc, ArrayRef<Type>{getVoidType()}, 255 builder.getSymbolRefAttr( 256 StringRef(symbolName.data(), symbolName.size())), 257 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 258 ptrToMemRefDescriptor}); 259 } 260 } 261 262 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 263 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { 264 auto llvmPtrDescriptorTy = 265 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 266 if (!llvmPtrDescriptorTy) 267 return failure(); 268 269 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 270 // template <typename Elem, size_t Rank> 271 // struct { 272 // Elem *allocated; 273 // Elem *aligned; 274 // int64_t offset; 275 // int64_t sizes[Rank]; // omitted when rank == 0 276 // int64_t strides[Rank]; // omitted when rank == 0 277 // }; 278 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 279 return failure(); 280 281 type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); 282 if (llvmDescriptorTy.getStructNumElements() == 3) { 283 rank = 0; 284 return success(); 285 } 286 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 287 return success(); 288 } 289 290 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 291 ModuleOp module = getOperation(); 292 OpBuilder builder(module.getBody()->getTerminator()); 293 294 if (!module.lookupSymbol(kSetEntryPoint)) { 295 builder.create<LLVM::LLVMFuncOp>( 296 loc, kSetEntryPoint, 297 LLVM::LLVMType::getFunctionTy(getVoidType(), 298 {getPointerType(), getPointerType()}, 299 /*isVarArg=*/false)); 300 } 301 302 if (!module.lookupSymbol(kSetNumWorkGroups)) { 303 builder.create<LLVM::LLVMFuncOp>( 304 loc, kSetNumWorkGroups, 305 LLVM::LLVMType::getFunctionTy( 306 getVoidType(), 307 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 308 /*isVarArg=*/false)); 309 } 310 311 if (!module.lookupSymbol(kSetBinaryShader)) { 312 builder.create<LLVM::LLVMFuncOp>( 313 loc, kSetBinaryShader, 314 LLVM::LLVMType::getFunctionTy( 315 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 316 /*isVarArg=*/false)); 317 } 318 319 if (!module.lookupSymbol(kRunOnVulkan)) { 320 builder.create<LLVM::LLVMFuncOp>( 321 loc, kRunOnVulkan, 322 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 323 /*isVarArg=*/false)); 324 } 325 326 for (unsigned i = 1; i <= 3; i++) { 327 for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect), 328 LLVM::LLVMType::getInt32Ty(llvmDialect), 329 LLVM::LLVMType::getInt16Ty(llvmDialect), 330 LLVM::LLVMType::getInt8Ty(llvmDialect), 331 LLVM::LLVMType::getHalfTy(llvmDialect)}) { 332 std::string fnName = "bindMemRef" + std::to_string(i) + "D" + 333 std::string(stringifyType(type)); 334 if (type.isHalfTy()) 335 type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect)); 336 if (!module.lookupSymbol(fnName)) { 337 auto fnType = LLVM::LLVMType::getFunctionTy( 338 getVoidType(), 339 {getPointerType(), getInt32Type(), getInt32Type(), 340 getMemRefType(i, type).getPointerTo()}, 341 /*isVarArg=*/false); 342 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); 343 } 344 } 345 } 346 347 if (!module.lookupSymbol(kInitVulkan)) { 348 builder.create<LLVM::LLVMFuncOp>( 349 loc, kInitVulkan, 350 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 351 /*isVarArg=*/false)); 352 } 353 354 if (!module.lookupSymbol(kDeinitVulkan)) { 355 builder.create<LLVM::LLVMFuncOp>( 356 loc, kDeinitVulkan, 357 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 358 /*isVarArg=*/false)); 359 } 360 } 361 362 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 363 StringRef name, Location loc, OpBuilder &builder) { 364 SmallString<16> shaderName(name.begin(), name.end()); 365 // Append `\0` to follow C style string given that LLVM::createGlobalString() 366 // won't handle this directly for us. 367 shaderName.push_back('\0'); 368 369 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 370 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 371 shaderName, LLVM::Linkage::Internal, 372 getLLVMDialect()); 373 } 374 375 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 376 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 377 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 378 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 379 // Create call to `initVulkan`. 380 auto initVulkanCall = builder.create<LLVM::CallOp>( 381 loc, ArrayRef<Type>{getPointerType()}, 382 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 383 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 384 // need to pass that pointer to each Vulkan runtime call. 385 auto vulkanRuntime = initVulkanCall.getResult(0); 386 387 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 388 // that data to runtime call. 389 Value ptrToSPIRVBinary = LLVM::createGlobalString( 390 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 391 LLVM::Linkage::Internal, getLLVMDialect()); 392 393 // Create LLVM constant for the size of SPIR-V binary shader. 394 Value binarySize = builder.create<LLVM::ConstantOp>( 395 loc, getInt32Type(), 396 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 397 398 // Create call to `bindMemRef` for each memref operand. 399 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 400 401 // Create call to `setBinaryShader` runtime function with the given pointer to 402 // SPIR-V binary and binary size. 403 builder.create<LLVM::CallOp>( 404 loc, ArrayRef<Type>{getVoidType()}, 405 builder.getSymbolRefAttr(kSetBinaryShader), 406 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 407 // Create LLVM global with entry point name. 408 Value entryPointName = createEntryPointNameConstant( 409 spirvAttributes.second.getValue(), loc, builder); 410 // Create call to `setEntryPoint` runtime function with the given pointer to 411 // entry point name. 412 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 413 builder.getSymbolRefAttr(kSetEntryPoint), 414 ArrayRef<Value>{vulkanRuntime, entryPointName}); 415 416 // Create number of local workgroup for each dimension. 417 builder.create<LLVM::CallOp>( 418 loc, ArrayRef<Type>{getVoidType()}, 419 builder.getSymbolRefAttr(kSetNumWorkGroups), 420 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 421 cInterfaceVulkanLaunchCallOp.getOperand(1), 422 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 423 424 // Create call to `runOnVulkan` runtime function. 425 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 426 builder.getSymbolRefAttr(kRunOnVulkan), 427 ArrayRef<Value>{vulkanRuntime}); 428 429 // Create call to 'deinitVulkan' runtime function. 430 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 431 builder.getSymbolRefAttr(kDeinitVulkan), 432 ArrayRef<Value>{vulkanRuntime}); 433 434 // Declare runtime functions. 435 declareVulkanFunctions(loc); 436 437 cInterfaceVulkanLaunchCallOp.erase(); 438 } 439 440 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 441 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 442 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 443 } 444