1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Pass/Pass.h" 25 26 #include "llvm/ADT/SmallString.h" 27 28 using namespace mlir; 29 30 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 31 static constexpr const char *kCInterfaceVulkanLaunch = 32 "_mlir_ciface_vulkanLaunch"; 33 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 34 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 35 static constexpr const char *kInitVulkan = "initVulkan"; 36 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 37 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 38 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 39 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 40 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 41 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 42 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 43 44 namespace { 45 46 /// A pass to convert vulkan launch call op into a sequence of Vulkan 47 /// runtime calls in the following order: 48 /// 49 /// * initVulkan -- initializes vulkan runtime 50 /// * bindMemRef -- binds memref 51 /// * setBinaryShader -- sets the binary shader data 52 /// * setEntryPoint -- sets the entry point name 53 /// * setNumWorkGroups -- sets the number of a local workgroups 54 /// * runOnVulkan -- runs vulkan runtime 55 /// * deinitVulkan -- deinitializes vulkan runtime 56 /// 57 class VulkanLaunchFuncToVulkanCallsPass 58 : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { 59 private: 60 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 61 62 llvm::LLVMContext &getLLVMContext() { 63 return getLLVMDialect()->getLLVMContext(); 64 } 65 66 void initializeCachedTypes() { 67 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 68 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 69 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 70 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 71 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 72 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 73 initializeMemRefTypes(); 74 } 75 76 void initializeMemRefTypes() { 77 // According to the MLIR doc memref argument is converted into a 78 // pointer-to-struct argument of type: 79 // template <typename Elem, size_t Rank> 80 // struct { 81 // Elem *allocated; 82 // Elem *aligned; 83 // int64_t offset; 84 // int64_t sizes[Rank]; // omitted when rank == 0 85 // int64_t strides[Rank]; // omitted when rank == 0 86 // }; 87 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 88 auto llvmArrayOneElementSizeType = 89 LLVM::LLVMType::getArrayTy(getInt64Type(), 1); 90 91 // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. 92 llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( 93 llvmDialect, 94 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 95 llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); 96 } 97 98 LLVM::LLVMType getFloatType() { return llvmFloatType; } 99 LLVM::LLVMType getVoidType() { return llvmVoidType; } 100 LLVM::LLVMType getPointerType() { return llvmPointerType; } 101 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 102 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 103 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 104 105 /// Creates a LLVM global for the given `name`. 106 Value createEntryPointNameConstant(StringRef name, Location loc, 107 OpBuilder &builder); 108 109 /// Declares all needed runtime functions. 110 void declareVulkanFunctions(Location loc); 111 112 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 113 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 114 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 115 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 116 } 117 118 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 119 /// op. 120 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 121 return (callOp.callee() && 122 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 123 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 124 } 125 126 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 127 /// runtime calls. 128 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 129 130 /// Creates call to `bindMemRef` for each memref operand. 131 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 132 Value vulkanRuntime); 133 134 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 135 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 136 137 public: 138 void runOnModule() override; 139 140 private: 141 LLVM::LLVMDialect *llvmDialect; 142 LLVM::LLVMType llvmFloatType; 143 LLVM::LLVMType llvmVoidType; 144 LLVM::LLVMType llvmPointerType; 145 LLVM::LLVMType llvmInt32Type; 146 LLVM::LLVMType llvmInt64Type; 147 LLVM::LLVMType llvmMemRef1DFloat; 148 149 // TODO: Use an associative array to support multiple vulkan launch calls. 150 std::pair<StringAttr, StringAttr> spirvAttributes; 151 }; 152 153 } // anonymous namespace 154 155 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { 156 initializeCachedTypes(); 157 158 // Collect SPIR-V attributes such as `spirv_blob` and 159 // `spirv_entry_point_name`. 160 getModule().walk([this](LLVM::CallOp op) { 161 if (isVulkanLaunchCallOp(op)) 162 collectSPIRVAttributes(op); 163 }); 164 165 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 166 getModule().walk([this](LLVM::CallOp op) { 167 if (isCInterfaceVulkanLaunchCallOp(op)) 168 translateVulkanLaunchCall(op); 169 }); 170 } 171 172 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 173 LLVM::CallOp vulkanLaunchCallOp) { 174 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 175 // for the given vulkan launch call. 176 auto spirvBlobAttr = 177 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 178 if (!spirvBlobAttr) { 179 vulkanLaunchCallOp.emitError() 180 << "missing " << kSPIRVBlobAttrName << " attribute"; 181 return signalPassFailure(); 182 } 183 184 auto spirvEntryPointNameAttr = 185 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 186 if (!spirvEntryPointNameAttr) { 187 vulkanLaunchCallOp.emitError() 188 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 189 return signalPassFailure(); 190 } 191 192 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 193 } 194 195 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 196 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 197 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 198 gpu::LaunchOp::kNumConfigOperands) 199 return; 200 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 201 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 202 203 // Create LLVM constant for the descriptor set index. 204 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 205 // pass does. 206 Value descriptorSet = builder.create<LLVM::ConstantOp>( 207 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 208 209 for (auto en : 210 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 211 gpu::LaunchOp::kNumConfigOperands))) { 212 // Create LLVM constant for the descriptor binding index. 213 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 214 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 215 // Create call to `bindMemRef`. 216 builder.create<LLVM::CallOp>( 217 loc, ArrayRef<Type>{getVoidType()}, 218 // TODO: Add support for memref with other ranks. 219 builder.getSymbolRefAttr(kBindMemRef1DFloat), 220 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 221 en.value()}); 222 } 223 } 224 225 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 226 ModuleOp module = getModule(); 227 OpBuilder builder(module.getBody()->getTerminator()); 228 229 if (!module.lookupSymbol(kSetEntryPoint)) { 230 builder.create<LLVM::LLVMFuncOp>( 231 loc, kSetEntryPoint, 232 LLVM::LLVMType::getFunctionTy(getVoidType(), 233 {getPointerType(), getPointerType()}, 234 /*isVarArg=*/false)); 235 } 236 237 if (!module.lookupSymbol(kSetNumWorkGroups)) { 238 builder.create<LLVM::LLVMFuncOp>( 239 loc, kSetNumWorkGroups, 240 LLVM::LLVMType::getFunctionTy( 241 getVoidType(), 242 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 243 /*isVarArg=*/false)); 244 } 245 246 if (!module.lookupSymbol(kSetBinaryShader)) { 247 builder.create<LLVM::LLVMFuncOp>( 248 loc, kSetBinaryShader, 249 LLVM::LLVMType::getFunctionTy( 250 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 251 /*isVarArg=*/false)); 252 } 253 254 if (!module.lookupSymbol(kRunOnVulkan)) { 255 builder.create<LLVM::LLVMFuncOp>( 256 loc, kRunOnVulkan, 257 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 258 /*isVarArg=*/false)); 259 } 260 261 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 262 builder.create<LLVM::LLVMFuncOp>( 263 loc, kBindMemRef1DFloat, 264 LLVM::LLVMType::getFunctionTy(getVoidType(), 265 {getPointerType(), getInt32Type(), 266 getInt32Type(), 267 getMemRef1DFloat().getPointerTo()}, 268 /*isVarArg=*/false)); 269 } 270 271 if (!module.lookupSymbol(kInitVulkan)) { 272 builder.create<LLVM::LLVMFuncOp>( 273 loc, kInitVulkan, 274 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 275 /*isVarArg=*/false)); 276 } 277 278 if (!module.lookupSymbol(kDeinitVulkan)) { 279 builder.create<LLVM::LLVMFuncOp>( 280 loc, kDeinitVulkan, 281 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 282 /*isVarArg=*/false)); 283 } 284 } 285 286 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 287 StringRef name, Location loc, OpBuilder &builder) { 288 SmallString<16> shaderName(name.begin(), name.end()); 289 // Append `\0` to follow C style string given that LLVM::createGlobalString() 290 // won't handle this directly for us. 291 shaderName.push_back('\0'); 292 293 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 294 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 295 shaderName, LLVM::Linkage::Internal, 296 getLLVMDialect()); 297 } 298 299 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 300 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 301 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 302 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 303 // Create call to `initVulkan`. 304 auto initVulkanCall = builder.create<LLVM::CallOp>( 305 loc, ArrayRef<Type>{getPointerType()}, 306 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 307 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 308 // need to pass that pointer to each Vulkan runtime call. 309 auto vulkanRuntime = initVulkanCall.getResult(0); 310 311 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 312 // that data to runtime call. 313 Value ptrToSPIRVBinary = LLVM::createGlobalString( 314 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 315 LLVM::Linkage::Internal, getLLVMDialect()); 316 317 // Create LLVM constant for the size of SPIR-V binary shader. 318 Value binarySize = builder.create<LLVM::ConstantOp>( 319 loc, getInt32Type(), 320 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 321 322 // Create call to `bindMemRef` for each memref operand. 323 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 324 325 // Create call to `setBinaryShader` runtime function with the given pointer to 326 // SPIR-V binary and binary size. 327 builder.create<LLVM::CallOp>( 328 loc, ArrayRef<Type>{getVoidType()}, 329 builder.getSymbolRefAttr(kSetBinaryShader), 330 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 331 // Create LLVM global with entry point name. 332 Value entryPointName = createEntryPointNameConstant( 333 spirvAttributes.second.getValue(), loc, builder); 334 // Create call to `setEntryPoint` runtime function with the given pointer to 335 // entry point name. 336 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 337 builder.getSymbolRefAttr(kSetEntryPoint), 338 ArrayRef<Value>{vulkanRuntime, entryPointName}); 339 340 // Create number of local workgroup for each dimension. 341 builder.create<LLVM::CallOp>( 342 loc, ArrayRef<Type>{getVoidType()}, 343 builder.getSymbolRefAttr(kSetNumWorkGroups), 344 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 345 cInterfaceVulkanLaunchCallOp.getOperand(1), 346 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 347 348 // Create call to `runOnVulkan` runtime function. 349 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 350 builder.getSymbolRefAttr(kRunOnVulkan), 351 ArrayRef<Value>{vulkanRuntime}); 352 353 // Create call to 'deinitVulkan' runtime function. 354 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 355 builder.getSymbolRefAttr(kDeinitVulkan), 356 ArrayRef<Value>{vulkanRuntime}); 357 358 // Declare runtime functions. 359 declareVulkanFunctions(loc); 360 361 cInterfaceVulkanLaunchCallOp.erase(); 362 } 363 364 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 365 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 366 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 367 } 368 369 static PassRegistration<VulkanLaunchFuncToVulkanCallsPass> 370 pass("launch-func-to-vulkan", 371 "Convert vulkanLaunch external call to Vulkan runtime external calls"); 372