1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/GPU/GPUDialect.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/IR/Attributes.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/Function.h" 24 #include "mlir/IR/Module.h" 25 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 32 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 33 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat"; 34 static constexpr const char *kCInterfaceVulkanLaunch = 35 "_mlir_ciface_vulkanLaunch"; 36 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 37 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 38 static constexpr const char *kInitVulkan = "initVulkan"; 39 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 40 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 41 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 42 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 43 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 44 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 45 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 46 47 namespace { 48 49 /// A pass to convert vulkan launch call op into a sequence of Vulkan 50 /// runtime calls in the following order: 51 /// 52 /// * initVulkan -- initializes vulkan runtime 53 /// * bindMemRef -- binds memref 54 /// * setBinaryShader -- sets the binary shader data 55 /// * setEntryPoint -- sets the entry point name 56 /// * setNumWorkGroups -- sets the number of a local workgroups 57 /// * runOnVulkan -- runs vulkan runtime 58 /// * deinitVulkan -- deinitializes vulkan runtime 59 /// 60 class VulkanLaunchFuncToVulkanCallsPass 61 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 62 VulkanLaunchFuncToVulkanCallsPass> { 63 private: 64 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 65 66 llvm::LLVMContext &getLLVMContext() { 67 return getLLVMDialect()->getLLVMContext(); 68 } 69 70 void initializeCachedTypes() { 71 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 72 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 73 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 74 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 75 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 76 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 77 llvmMemRef1DFloat = getMemRefType(1); 78 llvmMemRef2DFloat = getMemRefType(2); 79 llvmMemRef3DFloat = getMemRefType(3); 80 } 81 82 LLVM::LLVMType getMemRefType(uint32_t rank) { 83 // According to the MLIR doc memref argument is converted into a 84 // pointer-to-struct argument of type: 85 // template <typename Elem, size_t Rank> 86 // struct { 87 // Elem *allocated; 88 // Elem *aligned; 89 // int64_t offset; 90 // int64_t sizes[Rank]; // omitted when rank == 0 91 // int64_t strides[Rank]; // omitted when rank == 0 92 // }; 93 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 94 auto llvmArrayRankElementSizeType = 95 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 96 97 // Create a type 98 // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`. 99 return LLVM::LLVMType::getStructTy( 100 llvmDialect, 101 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 102 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 103 } 104 105 LLVM::LLVMType getFloatType() { return llvmFloatType; } 106 LLVM::LLVMType getVoidType() { return llvmVoidType; } 107 LLVM::LLVMType getPointerType() { return llvmPointerType; } 108 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 109 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 110 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 111 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 112 LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } 113 114 /// Creates a LLVM global for the given `name`. 115 Value createEntryPointNameConstant(StringRef name, Location loc, 116 OpBuilder &builder); 117 118 /// Declares all needed runtime functions. 119 void declareVulkanFunctions(Location loc); 120 121 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 122 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 123 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 124 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 125 } 126 127 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 128 /// op. 129 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 130 return (callOp.callee() && 131 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 132 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 133 } 134 135 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 136 /// runtime calls. 137 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 138 139 /// Creates call to `bindMemRef` for each memref operand. 140 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 141 Value vulkanRuntime); 142 143 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 144 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 145 146 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 147 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 148 149 public: 150 void runOnOperation() override; 151 152 private: 153 LLVM::LLVMDialect *llvmDialect; 154 LLVM::LLVMType llvmFloatType; 155 LLVM::LLVMType llvmVoidType; 156 LLVM::LLVMType llvmPointerType; 157 LLVM::LLVMType llvmInt32Type; 158 LLVM::LLVMType llvmInt64Type; 159 LLVM::LLVMType llvmMemRef1DFloat; 160 LLVM::LLVMType llvmMemRef2DFloat; 161 LLVM::LLVMType llvmMemRef3DFloat; 162 163 // TODO: Use an associative array to support multiple vulkan launch calls. 164 std::pair<StringAttr, StringAttr> spirvAttributes; 165 }; 166 167 } // anonymous namespace 168 169 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 170 initializeCachedTypes(); 171 172 // Collect SPIR-V attributes such as `spirv_blob` and 173 // `spirv_entry_point_name`. 174 getOperation().walk([this](LLVM::CallOp op) { 175 if (isVulkanLaunchCallOp(op)) 176 collectSPIRVAttributes(op); 177 }); 178 179 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 180 getOperation().walk([this](LLVM::CallOp op) { 181 if (isCInterfaceVulkanLaunchCallOp(op)) 182 translateVulkanLaunchCall(op); 183 }); 184 } 185 186 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 187 LLVM::CallOp vulkanLaunchCallOp) { 188 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 189 // for the given vulkan launch call. 190 auto spirvBlobAttr = 191 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 192 if (!spirvBlobAttr) { 193 vulkanLaunchCallOp.emitError() 194 << "missing " << kSPIRVBlobAttrName << " attribute"; 195 return signalPassFailure(); 196 } 197 198 auto spirvEntryPointNameAttr = 199 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 200 if (!spirvEntryPointNameAttr) { 201 vulkanLaunchCallOp.emitError() 202 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 203 return signalPassFailure(); 204 } 205 206 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 207 } 208 209 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 210 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 211 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 212 gpu::LaunchOp::kNumConfigOperands) 213 return; 214 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 215 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 216 217 // Create LLVM constant for the descriptor set index. 218 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 219 // pass does. 220 Value descriptorSet = builder.create<LLVM::ConstantOp>( 221 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 222 223 for (auto en : 224 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 225 gpu::LaunchOp::kNumConfigOperands))) { 226 // Create LLVM constant for the descriptor binding index. 227 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 228 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 229 230 auto ptrToMemRefDescriptor = en.value(); 231 uint32_t rank = 0; 232 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 233 cInterfaceVulkanLaunchCallOp.emitError() 234 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 235 return signalPassFailure(); 236 } 237 238 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 239 // Create call to `bindMemRef`. 240 builder.create<LLVM::CallOp>( 241 loc, ArrayRef<Type>{getVoidType()}, 242 builder.getSymbolRefAttr( 243 StringRef(symbolName.data(), symbolName.size())), 244 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 245 ptrToMemRefDescriptor}); 246 } 247 } 248 249 LogicalResult 250 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 251 uint32_t &rank) { 252 auto llvmPtrDescriptorTy = 253 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 254 if (!llvmPtrDescriptorTy) 255 return failure(); 256 257 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 258 // template <typename Elem, size_t Rank> 259 // struct { 260 // Elem *allocated; 261 // Elem *aligned; 262 // int64_t offset; 263 // int64_t sizes[Rank]; // omitted when rank == 0 264 // int64_t strides[Rank]; // omitted when rank == 0 265 // }; 266 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 267 return failure(); 268 if (llvmDescriptorTy.getStructNumElements() == 3) { 269 rank = 0; 270 return success(); 271 } 272 273 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 274 return success(); 275 } 276 277 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 278 ModuleOp module = getOperation(); 279 OpBuilder builder(module.getBody()->getTerminator()); 280 281 if (!module.lookupSymbol(kSetEntryPoint)) { 282 builder.create<LLVM::LLVMFuncOp>( 283 loc, kSetEntryPoint, 284 LLVM::LLVMType::getFunctionTy(getVoidType(), 285 {getPointerType(), getPointerType()}, 286 /*isVarArg=*/false)); 287 } 288 289 if (!module.lookupSymbol(kSetNumWorkGroups)) { 290 builder.create<LLVM::LLVMFuncOp>( 291 loc, kSetNumWorkGroups, 292 LLVM::LLVMType::getFunctionTy( 293 getVoidType(), 294 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 295 /*isVarArg=*/false)); 296 } 297 298 if (!module.lookupSymbol(kSetBinaryShader)) { 299 builder.create<LLVM::LLVMFuncOp>( 300 loc, kSetBinaryShader, 301 LLVM::LLVMType::getFunctionTy( 302 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 303 /*isVarArg=*/false)); 304 } 305 306 if (!module.lookupSymbol(kRunOnVulkan)) { 307 builder.create<LLVM::LLVMFuncOp>( 308 loc, kRunOnVulkan, 309 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 310 /*isVarArg=*/false)); 311 } 312 313 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 314 builder.create<LLVM::LLVMFuncOp>( 315 loc, kBindMemRef1DFloat, 316 LLVM::LLVMType::getFunctionTy(getVoidType(), 317 {getPointerType(), getInt32Type(), 318 getInt32Type(), 319 getMemRef1DFloat().getPointerTo()}, 320 /*isVarArg=*/false)); 321 } 322 323 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 324 builder.create<LLVM::LLVMFuncOp>( 325 loc, kBindMemRef2DFloat, 326 LLVM::LLVMType::getFunctionTy(getVoidType(), 327 {getPointerType(), getInt32Type(), 328 getInt32Type(), 329 getMemRef2DFloat().getPointerTo()}, 330 /*isVarArg=*/false)); 331 } 332 333 if (!module.lookupSymbol(kBindMemRef3DFloat)) { 334 builder.create<LLVM::LLVMFuncOp>( 335 loc, kBindMemRef3DFloat, 336 LLVM::LLVMType::getFunctionTy(getVoidType(), 337 {getPointerType(), getInt32Type(), 338 getInt32Type(), 339 getMemRef3DFloat().getPointerTo()}, 340 /*isVarArg=*/false)); 341 } 342 343 if (!module.lookupSymbol(kInitVulkan)) { 344 builder.create<LLVM::LLVMFuncOp>( 345 loc, kInitVulkan, 346 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 347 /*isVarArg=*/false)); 348 } 349 350 if (!module.lookupSymbol(kDeinitVulkan)) { 351 builder.create<LLVM::LLVMFuncOp>( 352 loc, kDeinitVulkan, 353 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 354 /*isVarArg=*/false)); 355 } 356 } 357 358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 359 StringRef name, Location loc, OpBuilder &builder) { 360 SmallString<16> shaderName(name.begin(), name.end()); 361 // Append `\0` to follow C style string given that LLVM::createGlobalString() 362 // won't handle this directly for us. 363 shaderName.push_back('\0'); 364 365 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 366 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 367 shaderName, LLVM::Linkage::Internal, 368 getLLVMDialect()); 369 } 370 371 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 372 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 373 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 374 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 375 // Create call to `initVulkan`. 376 auto initVulkanCall = builder.create<LLVM::CallOp>( 377 loc, ArrayRef<Type>{getPointerType()}, 378 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 379 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 380 // need to pass that pointer to each Vulkan runtime call. 381 auto vulkanRuntime = initVulkanCall.getResult(0); 382 383 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 384 // that data to runtime call. 385 Value ptrToSPIRVBinary = LLVM::createGlobalString( 386 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 387 LLVM::Linkage::Internal, getLLVMDialect()); 388 389 // Create LLVM constant for the size of SPIR-V binary shader. 390 Value binarySize = builder.create<LLVM::ConstantOp>( 391 loc, getInt32Type(), 392 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 393 394 // Create call to `bindMemRef` for each memref operand. 395 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 396 397 // Create call to `setBinaryShader` runtime function with the given pointer to 398 // SPIR-V binary and binary size. 399 builder.create<LLVM::CallOp>( 400 loc, ArrayRef<Type>{getVoidType()}, 401 builder.getSymbolRefAttr(kSetBinaryShader), 402 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 403 // Create LLVM global with entry point name. 404 Value entryPointName = createEntryPointNameConstant( 405 spirvAttributes.second.getValue(), loc, builder); 406 // Create call to `setEntryPoint` runtime function with the given pointer to 407 // entry point name. 408 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 409 builder.getSymbolRefAttr(kSetEntryPoint), 410 ArrayRef<Value>{vulkanRuntime, entryPointName}); 411 412 // Create number of local workgroup for each dimension. 413 builder.create<LLVM::CallOp>( 414 loc, ArrayRef<Type>{getVoidType()}, 415 builder.getSymbolRefAttr(kSetNumWorkGroups), 416 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 417 cInterfaceVulkanLaunchCallOp.getOperand(1), 418 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 419 420 // Create call to `runOnVulkan` runtime function. 421 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 422 builder.getSymbolRefAttr(kRunOnVulkan), 423 ArrayRef<Value>{vulkanRuntime}); 424 425 // Create call to 'deinitVulkan' runtime function. 426 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 427 builder.getSymbolRefAttr(kDeinitVulkan), 428 ArrayRef<Value>{vulkanRuntime}); 429 430 // Declare runtime functions. 431 declareVulkanFunctions(loc); 432 433 cInterfaceVulkanLaunchCallOp.erase(); 434 } 435 436 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 437 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 438 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 439 } 440