1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Pass/Pass.h" 25 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 32 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 33 static constexpr const char *kCInterfaceVulkanLaunch = 34 "_mlir_ciface_vulkanLaunch"; 35 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 36 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 37 static constexpr const char *kInitVulkan = "initVulkan"; 38 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 39 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 40 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 41 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 42 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 43 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 44 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 45 46 namespace { 47 48 /// A pass to convert vulkan launch call op into a sequence of Vulkan 49 /// runtime calls in the following order: 50 /// 51 /// * initVulkan -- initializes vulkan runtime 52 /// * bindMemRef -- binds memref 53 /// * setBinaryShader -- sets the binary shader data 54 /// * setEntryPoint -- sets the entry point name 55 /// * setNumWorkGroups -- sets the number of a local workgroups 56 /// * runOnVulkan -- runs vulkan runtime 57 /// * deinitVulkan -- deinitializes vulkan runtime 58 /// 59 class VulkanLaunchFuncToVulkanCallsPass 60 : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { 61 private: 62 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 63 64 llvm::LLVMContext &getLLVMContext() { 65 return getLLVMDialect()->getLLVMContext(); 66 } 67 68 void initializeCachedTypes() { 69 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 70 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 71 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 72 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 73 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 74 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 75 initializeMemRefTypes(); 76 } 77 78 void initializeMemRefTypes() { 79 // According to the MLIR doc memref argument is converted into a 80 // pointer-to-struct argument of type: 81 // template <typename Elem, size_t Rank> 82 // struct { 83 // Elem *allocated; 84 // Elem *aligned; 85 // int64_t offset; 86 // int64_t sizes[Rank]; // omitted when rank == 0 87 // int64_t strides[Rank]; // omitted when rank == 0 88 // }; 89 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 90 auto llvmArrayOneElementSizeType = 91 LLVM::LLVMType::getArrayTy(getInt64Type(), 1); 92 auto llvmArrayTwoElementSizeType = 93 LLVM::LLVMType::getArrayTy(getInt64Type(), 2); 94 95 // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. 96 llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( 97 llvmDialect, 98 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 99 llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); 100 101 // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`. 102 llvmMemRef2DFloat = LLVM::LLVMType::getStructTy( 103 llvmDialect, 104 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 105 llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType}); 106 } 107 108 LLVM::LLVMType getFloatType() { return llvmFloatType; } 109 LLVM::LLVMType getVoidType() { return llvmVoidType; } 110 LLVM::LLVMType getPointerType() { return llvmPointerType; } 111 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 112 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 113 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 114 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 115 116 /// Creates a LLVM global for the given `name`. 117 Value createEntryPointNameConstant(StringRef name, Location loc, 118 OpBuilder &builder); 119 120 /// Declares all needed runtime functions. 121 void declareVulkanFunctions(Location loc); 122 123 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 124 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 125 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 126 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 127 } 128 129 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 130 /// op. 131 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 132 return (callOp.callee() && 133 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 134 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 135 } 136 137 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 138 /// runtime calls. 139 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 140 141 /// Creates call to `bindMemRef` for each memref operand. 142 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 143 Value vulkanRuntime); 144 145 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 146 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 147 148 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 149 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 150 151 public: 152 void runOnModule() override; 153 154 private: 155 LLVM::LLVMDialect *llvmDialect; 156 LLVM::LLVMType llvmFloatType; 157 LLVM::LLVMType llvmVoidType; 158 LLVM::LLVMType llvmPointerType; 159 LLVM::LLVMType llvmInt32Type; 160 LLVM::LLVMType llvmInt64Type; 161 LLVM::LLVMType llvmMemRef1DFloat; 162 LLVM::LLVMType llvmMemRef2DFloat; 163 164 // TODO: Use an associative array to support multiple vulkan launch calls. 165 std::pair<StringAttr, StringAttr> spirvAttributes; 166 }; 167 168 } // anonymous namespace 169 170 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { 171 initializeCachedTypes(); 172 173 // Collect SPIR-V attributes such as `spirv_blob` and 174 // `spirv_entry_point_name`. 175 getModule().walk([this](LLVM::CallOp op) { 176 if (isVulkanLaunchCallOp(op)) 177 collectSPIRVAttributes(op); 178 }); 179 180 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 181 getModule().walk([this](LLVM::CallOp op) { 182 if (isCInterfaceVulkanLaunchCallOp(op)) 183 translateVulkanLaunchCall(op); 184 }); 185 } 186 187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 188 LLVM::CallOp vulkanLaunchCallOp) { 189 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 190 // for the given vulkan launch call. 191 auto spirvBlobAttr = 192 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 193 if (!spirvBlobAttr) { 194 vulkanLaunchCallOp.emitError() 195 << "missing " << kSPIRVBlobAttrName << " attribute"; 196 return signalPassFailure(); 197 } 198 199 auto spirvEntryPointNameAttr = 200 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 201 if (!spirvEntryPointNameAttr) { 202 vulkanLaunchCallOp.emitError() 203 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 204 return signalPassFailure(); 205 } 206 207 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 208 } 209 210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 211 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 212 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 213 gpu::LaunchOp::kNumConfigOperands) 214 return; 215 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 216 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 217 218 // Create LLVM constant for the descriptor set index. 219 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 220 // pass does. 221 Value descriptorSet = builder.create<LLVM::ConstantOp>( 222 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 223 224 for (auto en : 225 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 226 gpu::LaunchOp::kNumConfigOperands))) { 227 // Create LLVM constant for the descriptor binding index. 228 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 229 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 230 231 auto ptrToMemRefDescriptor = en.value(); 232 uint32_t rank = 0; 233 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 234 cInterfaceVulkanLaunchCallOp.emitError() 235 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 236 return signalPassFailure(); 237 } 238 239 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 240 // Create call to `bindMemRef`. 241 builder.create<LLVM::CallOp>( 242 loc, ArrayRef<Type>{getVoidType()}, 243 builder.getSymbolRefAttr( 244 StringRef(symbolName.data(), symbolName.size())), 245 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 246 ptrToMemRefDescriptor}); 247 } 248 } 249 250 LogicalResult 251 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 252 uint32_t &rank) { 253 auto llvmPtrDescriptorTy = 254 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 255 if (!llvmPtrDescriptorTy) 256 return failure(); 257 258 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 259 // template <typename Elem, size_t Rank> 260 // struct { 261 // Elem *allocated; 262 // Elem *aligned; 263 // int64_t offset; 264 // int64_t sizes[Rank]; // omitted when rank == 0 265 // int64_t strides[Rank]; // omitted when rank == 0 266 // }; 267 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 268 return failure(); 269 if (llvmDescriptorTy.getStructNumElements() == 3) { 270 rank = 0; 271 return success(); 272 } 273 274 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 275 return success(); 276 } 277 278 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 279 ModuleOp module = getModule(); 280 OpBuilder builder(module.getBody()->getTerminator()); 281 282 if (!module.lookupSymbol(kSetEntryPoint)) { 283 builder.create<LLVM::LLVMFuncOp>( 284 loc, kSetEntryPoint, 285 LLVM::LLVMType::getFunctionTy(getVoidType(), 286 {getPointerType(), getPointerType()}, 287 /*isVarArg=*/false)); 288 } 289 290 if (!module.lookupSymbol(kSetNumWorkGroups)) { 291 builder.create<LLVM::LLVMFuncOp>( 292 loc, kSetNumWorkGroups, 293 LLVM::LLVMType::getFunctionTy( 294 getVoidType(), 295 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 296 /*isVarArg=*/false)); 297 } 298 299 if (!module.lookupSymbol(kSetBinaryShader)) { 300 builder.create<LLVM::LLVMFuncOp>( 301 loc, kSetBinaryShader, 302 LLVM::LLVMType::getFunctionTy( 303 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 304 /*isVarArg=*/false)); 305 } 306 307 if (!module.lookupSymbol(kRunOnVulkan)) { 308 builder.create<LLVM::LLVMFuncOp>( 309 loc, kRunOnVulkan, 310 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 311 /*isVarArg=*/false)); 312 } 313 314 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 315 builder.create<LLVM::LLVMFuncOp>( 316 loc, kBindMemRef1DFloat, 317 LLVM::LLVMType::getFunctionTy(getVoidType(), 318 {getPointerType(), getInt32Type(), 319 getInt32Type(), 320 getMemRef1DFloat().getPointerTo()}, 321 /*isVarArg=*/false)); 322 } 323 324 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 325 builder.create<LLVM::LLVMFuncOp>( 326 loc, kBindMemRef2DFloat, 327 LLVM::LLVMType::getFunctionTy(getVoidType(), 328 {getPointerType(), getInt32Type(), 329 getInt32Type(), 330 getMemRef2DFloat().getPointerTo()}, 331 /*isVarArg=*/false)); 332 } 333 334 if (!module.lookupSymbol(kInitVulkan)) { 335 builder.create<LLVM::LLVMFuncOp>( 336 loc, kInitVulkan, 337 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 338 /*isVarArg=*/false)); 339 } 340 341 if (!module.lookupSymbol(kDeinitVulkan)) { 342 builder.create<LLVM::LLVMFuncOp>( 343 loc, kDeinitVulkan, 344 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 345 /*isVarArg=*/false)); 346 } 347 } 348 349 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 350 StringRef name, Location loc, OpBuilder &builder) { 351 SmallString<16> shaderName(name.begin(), name.end()); 352 // Append `\0` to follow C style string given that LLVM::createGlobalString() 353 // won't handle this directly for us. 354 shaderName.push_back('\0'); 355 356 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 357 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 358 shaderName, LLVM::Linkage::Internal, 359 getLLVMDialect()); 360 } 361 362 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 363 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 364 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 365 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 366 // Create call to `initVulkan`. 367 auto initVulkanCall = builder.create<LLVM::CallOp>( 368 loc, ArrayRef<Type>{getPointerType()}, 369 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 370 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 371 // need to pass that pointer to each Vulkan runtime call. 372 auto vulkanRuntime = initVulkanCall.getResult(0); 373 374 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 375 // that data to runtime call. 376 Value ptrToSPIRVBinary = LLVM::createGlobalString( 377 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 378 LLVM::Linkage::Internal, getLLVMDialect()); 379 380 // Create LLVM constant for the size of SPIR-V binary shader. 381 Value binarySize = builder.create<LLVM::ConstantOp>( 382 loc, getInt32Type(), 383 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 384 385 // Create call to `bindMemRef` for each memref operand. 386 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 387 388 // Create call to `setBinaryShader` runtime function with the given pointer to 389 // SPIR-V binary and binary size. 390 builder.create<LLVM::CallOp>( 391 loc, ArrayRef<Type>{getVoidType()}, 392 builder.getSymbolRefAttr(kSetBinaryShader), 393 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 394 // Create LLVM global with entry point name. 395 Value entryPointName = createEntryPointNameConstant( 396 spirvAttributes.second.getValue(), loc, builder); 397 // Create call to `setEntryPoint` runtime function with the given pointer to 398 // entry point name. 399 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 400 builder.getSymbolRefAttr(kSetEntryPoint), 401 ArrayRef<Value>{vulkanRuntime, entryPointName}); 402 403 // Create number of local workgroup for each dimension. 404 builder.create<LLVM::CallOp>( 405 loc, ArrayRef<Type>{getVoidType()}, 406 builder.getSymbolRefAttr(kSetNumWorkGroups), 407 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 408 cInterfaceVulkanLaunchCallOp.getOperand(1), 409 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 410 411 // Create call to `runOnVulkan` runtime function. 412 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 413 builder.getSymbolRefAttr(kRunOnVulkan), 414 ArrayRef<Value>{vulkanRuntime}); 415 416 // Create call to 'deinitVulkan' runtime function. 417 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 418 builder.getSymbolRefAttr(kDeinitVulkan), 419 ArrayRef<Value>{vulkanRuntime}); 420 421 // Declare runtime functions. 422 declareVulkanFunctions(loc); 423 424 cInterfaceVulkanLaunchCallOp.erase(); 425 } 426 427 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 428 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 429 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 430 } 431 432 static PassRegistration<VulkanLaunchFuncToVulkanCallsPass> 433 pass("launch-func-to-vulkan", 434 "Convert vulkanLaunch external call to Vulkan runtime external calls"); 435