1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Function.h" 22 #include "mlir/IR/Module.h" 23 #include "mlir/Pass/Pass.h" 24 25 #include "llvm/ADT/SmallString.h" 26 27 using namespace mlir; 28 29 static constexpr const char *kBindResource = "bindResource"; 30 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 31 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 32 static constexpr const char *kInitVulkan = "initVulkan"; 33 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 34 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 35 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 36 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 37 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 38 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 39 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 40 41 namespace { 42 43 /// A pass to convert vulkan launch func into a sequence of Vulkan 44 /// runtime calls in the following order: 45 /// 46 /// * initVulkan -- initializes vulkan runtime 47 /// * bindResource -- binds resource 48 /// * setBinaryShader -- sets the binary shader data 49 /// * setEntryPoint -- sets the entry point name 50 /// * setNumWorkGroups -- sets the number of a local workgroups 51 /// * runOnVulkan -- runs vulkan runtime 52 /// * deinitVulkan -- deinitializes vulkan runtime 53 /// 54 class VulkanLaunchFuncToVulkanCallsPass 55 : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { 56 private: 57 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 58 59 llvm::LLVMContext &getLLVMContext() { 60 return getLLVMDialect()->getLLVMContext(); 61 } 62 63 void initializeCachedTypes() { 64 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 65 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 66 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 67 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 68 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 69 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 70 } 71 72 LLVM::LLVMType getFloatType() { return llvmFloatType; } 73 LLVM::LLVMType getVoidType() { return llvmVoidType; } 74 LLVM::LLVMType getPointerType() { return llvmPointerType; } 75 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 76 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 77 78 /// Creates a LLVM global for the given `name`. 79 Value createEntryPointNameConstant(StringRef name, Location loc, 80 OpBuilder &builder); 81 82 /// Declares all needed runtime functions. 83 void declareVulkanFunctions(Location loc); 84 85 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 86 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 87 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 88 callOp.getNumOperands() >= 6); 89 } 90 91 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 92 /// runtime calls. 93 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 94 95 /// Creates call to `bindResource` for each resource operand. 96 void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp, 97 Value vulkanRuntiem); 98 99 public: 100 void runOnModule() override; 101 102 private: 103 LLVM::LLVMDialect *llvmDialect; 104 LLVM::LLVMType llvmFloatType; 105 LLVM::LLVMType llvmVoidType; 106 LLVM::LLVMType llvmPointerType; 107 LLVM::LLVMType llvmInt32Type; 108 LLVM::LLVMType llvmInt64Type; 109 }; 110 111 /// Represents operand adaptor for vulkan launch call operation, to simplify an 112 /// access to the lowered memref. 113 // TODO: We should use 'emit-c-wrappers' option to lower memref type: 114 // https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission. 115 struct VulkanLaunchOpOperandAdaptor { 116 VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; } 117 VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete; 118 VulkanLaunchOpOperandAdaptor 119 operator=(const VulkanLaunchOpOperandAdaptor &) = delete; 120 121 /// Returns a tuple with a pointer to the memory and the size for the index-th 122 /// resource. 123 std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) { 124 assert(index < getResourceCount1D()); 125 // 1D memref calling convention according to "ConversionToLLVMDialect.md": 126 // 0. Allocated pointer. 127 // 1. Aligned pointer. 128 // 2. Offset. 129 // 3. Size in dim 0. 130 // 4. Stride in dim 0. 131 auto offset = numConfigOps + index * loweredMemRefNumOps1D; 132 return std::make_tuple(operands[offset], operands[offset + 3]); 133 } 134 135 /// Returns the number of resources assuming all operands lowered from 136 /// 1D memref. 137 uint32_t getResourceCount1D() { 138 return (operands.size() - numConfigOps) / loweredMemRefNumOps1D; 139 } 140 141 private: 142 /// The number of operands of lowered 1D memref. 143 static constexpr const uint32_t loweredMemRefNumOps1D = 5; 144 /// The number of the first config operands. 145 static constexpr const uint32_t numConfigOps = 6; 146 ArrayRef<Value> operands; 147 }; 148 149 } // anonymous namespace 150 151 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { 152 initializeCachedTypes(); 153 getModule().walk([this](LLVM::CallOp op) { 154 if (isVulkanLaunchCallOp(op)) 155 translateVulkanLaunchCall(op); 156 }); 157 } 158 159 void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls( 160 LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) { 161 if (vulkanLaunchCallOp.getNumOperands() == 6) 162 return; 163 OpBuilder builder(vulkanLaunchCallOp); 164 Location loc = vulkanLaunchCallOp.getLoc(); 165 166 // Create LLVM constant for the descriptor set index. 167 // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV` 168 // pass does. 169 Value descriptorSet = builder.create<LLVM::ConstantOp>( 170 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 171 172 auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()}; 173 VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands); 174 175 for (auto resourceIdx : 176 llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) { 177 // Create LLVM constant for the descriptor binding index. 178 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 179 loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx)); 180 // Get a pointer to the memory and size of that memory. 181 auto resourceDescriptor = 182 vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx); 183 // Create call to `bindResource`. 184 builder.create<LLVM::CallOp>( 185 loc, ArrayRef<Type>{getVoidType()}, 186 builder.getSymbolRefAttr(kBindResource), 187 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 188 // Pointer to the memory. 189 std::get<0>(resourceDescriptor), 190 // Size of the memory. 191 std::get<1>(resourceDescriptor)}); 192 } 193 } 194 195 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 196 ModuleOp module = getModule(); 197 OpBuilder builder(module.getBody()->getTerminator()); 198 199 if (!module.lookupSymbol(kSetEntryPoint)) { 200 builder.create<LLVM::LLVMFuncOp>( 201 loc, kSetEntryPoint, 202 LLVM::LLVMType::getFunctionTy(getVoidType(), 203 {getPointerType(), getPointerType()}, 204 /*isVarArg=*/false)); 205 } 206 207 if (!module.lookupSymbol(kSetNumWorkGroups)) { 208 builder.create<LLVM::LLVMFuncOp>( 209 loc, kSetNumWorkGroups, 210 LLVM::LLVMType::getFunctionTy( 211 getVoidType(), 212 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 213 /*isVarArg=*/false)); 214 } 215 216 if (!module.lookupSymbol(kSetBinaryShader)) { 217 builder.create<LLVM::LLVMFuncOp>( 218 loc, kSetBinaryShader, 219 LLVM::LLVMType::getFunctionTy( 220 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 221 /*isVarArg=*/false)); 222 } 223 224 if (!module.lookupSymbol(kRunOnVulkan)) { 225 builder.create<LLVM::LLVMFuncOp>( 226 loc, kRunOnVulkan, 227 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 228 /*isVarArg=*/false)); 229 } 230 231 if (!module.lookupSymbol(kBindResource)) { 232 builder.create<LLVM::LLVMFuncOp>( 233 loc, kBindResource, 234 LLVM::LLVMType::getFunctionTy( 235 getVoidType(), 236 {getPointerType(), getInt32Type(), getInt32Type(), 237 getFloatType().getPointerTo(), getInt64Type()}, 238 /*isVarArg=*/false)); 239 } 240 241 if (!module.lookupSymbol(kInitVulkan)) { 242 builder.create<LLVM::LLVMFuncOp>( 243 loc, kInitVulkan, 244 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 245 /*isVarArg=*/false)); 246 } 247 248 if (!module.lookupSymbol(kDeinitVulkan)) { 249 builder.create<LLVM::LLVMFuncOp>( 250 loc, kDeinitVulkan, 251 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 252 /*isVarArg=*/false)); 253 } 254 } 255 256 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 257 StringRef name, Location loc, OpBuilder &builder) { 258 SmallString<16> shaderName(name.begin(), name.end()); 259 // Append `\0` to follow C style string given that LLVM::createGlobalString() 260 // won't handle this directly for us. 261 shaderName.push_back('\0'); 262 263 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 264 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 265 shaderName, LLVM::Linkage::Internal, 266 getLLVMDialect()); 267 } 268 269 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 270 LLVM::CallOp vulkanLaunchCallOp) { 271 OpBuilder builder(vulkanLaunchCallOp); 272 Location loc = vulkanLaunchCallOp.getLoc(); 273 274 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 275 // for the given vulkan launch call. 276 auto spirvBlobAttr = 277 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 278 if (!spirvBlobAttr) { 279 vulkanLaunchCallOp.emitError() 280 << "missing " << kSPIRVBlobAttrName << " attribute"; 281 return signalPassFailure(); 282 } 283 284 auto entryPointNameAttr = 285 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 286 if (!entryPointNameAttr) { 287 vulkanLaunchCallOp.emitError() 288 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 289 return signalPassFailure(); 290 } 291 292 // Create call to `initVulkan`. 293 auto initVulkanCall = builder.create<LLVM::CallOp>( 294 loc, ArrayRef<Type>{getPointerType()}, 295 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 296 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 297 // need to pass that pointer to each Vulkan runtime call. 298 auto vulkanRuntime = initVulkanCall.getResult(0); 299 300 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 301 // that data to runtime call. 302 Value ptrToSPIRVBinary = LLVM::createGlobalString( 303 loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(), 304 LLVM::Linkage::Internal, getLLVMDialect()); 305 306 // Create LLVM constant for the size of SPIR-V binary shader. 307 Value binarySize = builder.create<LLVM::ConstantOp>( 308 loc, getInt32Type(), 309 builder.getI32IntegerAttr(spirvBlobAttr.getValue().size())); 310 311 // Create call to `bindResource` for each resource operand. 312 createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime); 313 314 // Create call to `setBinaryShader` runtime function with the given pointer to 315 // SPIR-V binary and binary size. 316 builder.create<LLVM::CallOp>( 317 loc, ArrayRef<Type>{getVoidType()}, 318 builder.getSymbolRefAttr(kSetBinaryShader), 319 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 320 // Create LLVM global with entry point name. 321 Value entryPointName = 322 createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder); 323 // Create call to `setEntryPoint` runtime function with the given pointer to 324 // entry point name. 325 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 326 builder.getSymbolRefAttr(kSetEntryPoint), 327 ArrayRef<Value>{vulkanRuntime, entryPointName}); 328 329 // Create number of local workgroup for each dimension. 330 builder.create<LLVM::CallOp>( 331 loc, ArrayRef<Type>{getVoidType()}, 332 builder.getSymbolRefAttr(kSetNumWorkGroups), 333 ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0), 334 vulkanLaunchCallOp.getOperand(1), 335 vulkanLaunchCallOp.getOperand(2)}); 336 337 // Create call to `runOnVulkan` runtime function. 338 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 339 builder.getSymbolRefAttr(kRunOnVulkan), 340 ArrayRef<Value>{vulkanRuntime}); 341 342 // Create call to 'deinitVulkan' runtime function. 343 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 344 builder.getSymbolRefAttr(kDeinitVulkan), 345 ArrayRef<Value>{vulkanRuntime}); 346 347 // Declare runtime functions. 348 declareVulkanFunctions(loc); 349 350 vulkanLaunchCallOp.erase(); 351 } 352 353 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 354 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 355 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 356 } 357 358 static PassRegistration<VulkanLaunchFuncToVulkanCallsPass> 359 pass("launch-func-to-vulkan", 360 "Convert vulkanLaunch external call to Vulkan runtime external calls"); 361