1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Pass/Pass.h" 25 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 32 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 33 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat"; 34 static constexpr const char *kCInterfaceVulkanLaunch = 35 "_mlir_ciface_vulkanLaunch"; 36 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 37 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 38 static constexpr const char *kInitVulkan = "initVulkan"; 39 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 40 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 41 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 42 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 43 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 44 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 45 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 46 47 namespace { 48 49 /// A pass to convert vulkan launch call op into a sequence of Vulkan 50 /// runtime calls in the following order: 51 /// 52 /// * initVulkan -- initializes vulkan runtime 53 /// * bindMemRef -- binds memref 54 /// * setBinaryShader -- sets the binary shader data 55 /// * setEntryPoint -- sets the entry point name 56 /// * setNumWorkGroups -- sets the number of a local workgroups 57 /// * runOnVulkan -- runs vulkan runtime 58 /// * deinitVulkan -- deinitializes vulkan runtime 59 /// 60 class VulkanLaunchFuncToVulkanCallsPass 61 : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { 62 private: 63 /// Include the generated pass utilities. 64 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls 65 #include "mlir/Conversion/Passes.h.inc" 66 67 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 68 69 llvm::LLVMContext &getLLVMContext() { 70 return getLLVMDialect()->getLLVMContext(); 71 } 72 73 void initializeCachedTypes() { 74 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 75 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 76 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 77 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 78 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 79 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 80 llvmMemRef1DFloat = getMemRefType(1); 81 llvmMemRef2DFloat = getMemRefType(2); 82 llvmMemRef3DFloat = getMemRefType(3); 83 } 84 85 LLVM::LLVMType getMemRefType(uint32_t rank) { 86 // According to the MLIR doc memref argument is converted into a 87 // pointer-to-struct argument of type: 88 // template <typename Elem, size_t Rank> 89 // struct { 90 // Elem *allocated; 91 // Elem *aligned; 92 // int64_t offset; 93 // int64_t sizes[Rank]; // omitted when rank == 0 94 // int64_t strides[Rank]; // omitted when rank == 0 95 // }; 96 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 97 auto llvmArrayRankElementSizeType = 98 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 99 100 // Create a type 101 // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`. 102 return LLVM::LLVMType::getStructTy( 103 llvmDialect, 104 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 105 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 106 } 107 108 LLVM::LLVMType getFloatType() { return llvmFloatType; } 109 LLVM::LLVMType getVoidType() { return llvmVoidType; } 110 LLVM::LLVMType getPointerType() { return llvmPointerType; } 111 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 112 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 113 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 114 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 115 LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } 116 117 /// Creates a LLVM global for the given `name`. 118 Value createEntryPointNameConstant(StringRef name, Location loc, 119 OpBuilder &builder); 120 121 /// Declares all needed runtime functions. 122 void declareVulkanFunctions(Location loc); 123 124 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 125 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 126 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 127 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 128 } 129 130 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 131 /// op. 132 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 133 return (callOp.callee() && 134 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 135 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 136 } 137 138 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 139 /// runtime calls. 140 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 141 142 /// Creates call to `bindMemRef` for each memref operand. 143 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 144 Value vulkanRuntime); 145 146 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 147 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 148 149 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 150 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 151 152 public: 153 void runOnModule() override; 154 155 private: 156 LLVM::LLVMDialect *llvmDialect; 157 LLVM::LLVMType llvmFloatType; 158 LLVM::LLVMType llvmVoidType; 159 LLVM::LLVMType llvmPointerType; 160 LLVM::LLVMType llvmInt32Type; 161 LLVM::LLVMType llvmInt64Type; 162 LLVM::LLVMType llvmMemRef1DFloat; 163 LLVM::LLVMType llvmMemRef2DFloat; 164 LLVM::LLVMType llvmMemRef3DFloat; 165 166 // TODO: Use an associative array to support multiple vulkan launch calls. 167 std::pair<StringAttr, StringAttr> spirvAttributes; 168 }; 169 170 } // anonymous namespace 171 172 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { 173 initializeCachedTypes(); 174 175 // Collect SPIR-V attributes such as `spirv_blob` and 176 // `spirv_entry_point_name`. 177 getModule().walk([this](LLVM::CallOp op) { 178 if (isVulkanLaunchCallOp(op)) 179 collectSPIRVAttributes(op); 180 }); 181 182 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 183 getModule().walk([this](LLVM::CallOp op) { 184 if (isCInterfaceVulkanLaunchCallOp(op)) 185 translateVulkanLaunchCall(op); 186 }); 187 } 188 189 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 190 LLVM::CallOp vulkanLaunchCallOp) { 191 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 192 // for the given vulkan launch call. 193 auto spirvBlobAttr = 194 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 195 if (!spirvBlobAttr) { 196 vulkanLaunchCallOp.emitError() 197 << "missing " << kSPIRVBlobAttrName << " attribute"; 198 return signalPassFailure(); 199 } 200 201 auto spirvEntryPointNameAttr = 202 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 203 if (!spirvEntryPointNameAttr) { 204 vulkanLaunchCallOp.emitError() 205 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 206 return signalPassFailure(); 207 } 208 209 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 210 } 211 212 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 213 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 214 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 215 gpu::LaunchOp::kNumConfigOperands) 216 return; 217 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 218 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 219 220 // Create LLVM constant for the descriptor set index. 221 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 222 // pass does. 223 Value descriptorSet = builder.create<LLVM::ConstantOp>( 224 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 225 226 for (auto en : 227 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 228 gpu::LaunchOp::kNumConfigOperands))) { 229 // Create LLVM constant for the descriptor binding index. 230 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 231 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 232 233 auto ptrToMemRefDescriptor = en.value(); 234 uint32_t rank = 0; 235 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 236 cInterfaceVulkanLaunchCallOp.emitError() 237 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 238 return signalPassFailure(); 239 } 240 241 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 242 // Create call to `bindMemRef`. 243 builder.create<LLVM::CallOp>( 244 loc, ArrayRef<Type>{getVoidType()}, 245 builder.getSymbolRefAttr( 246 StringRef(symbolName.data(), symbolName.size())), 247 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 248 ptrToMemRefDescriptor}); 249 } 250 } 251 252 LogicalResult 253 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 254 uint32_t &rank) { 255 auto llvmPtrDescriptorTy = 256 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 257 if (!llvmPtrDescriptorTy) 258 return failure(); 259 260 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 261 // template <typename Elem, size_t Rank> 262 // struct { 263 // Elem *allocated; 264 // Elem *aligned; 265 // int64_t offset; 266 // int64_t sizes[Rank]; // omitted when rank == 0 267 // int64_t strides[Rank]; // omitted when rank == 0 268 // }; 269 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 270 return failure(); 271 if (llvmDescriptorTy.getStructNumElements() == 3) { 272 rank = 0; 273 return success(); 274 } 275 276 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 277 return success(); 278 } 279 280 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 281 ModuleOp module = getModule(); 282 OpBuilder builder(module.getBody()->getTerminator()); 283 284 if (!module.lookupSymbol(kSetEntryPoint)) { 285 builder.create<LLVM::LLVMFuncOp>( 286 loc, kSetEntryPoint, 287 LLVM::LLVMType::getFunctionTy(getVoidType(), 288 {getPointerType(), getPointerType()}, 289 /*isVarArg=*/false)); 290 } 291 292 if (!module.lookupSymbol(kSetNumWorkGroups)) { 293 builder.create<LLVM::LLVMFuncOp>( 294 loc, kSetNumWorkGroups, 295 LLVM::LLVMType::getFunctionTy( 296 getVoidType(), 297 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 298 /*isVarArg=*/false)); 299 } 300 301 if (!module.lookupSymbol(kSetBinaryShader)) { 302 builder.create<LLVM::LLVMFuncOp>( 303 loc, kSetBinaryShader, 304 LLVM::LLVMType::getFunctionTy( 305 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 306 /*isVarArg=*/false)); 307 } 308 309 if (!module.lookupSymbol(kRunOnVulkan)) { 310 builder.create<LLVM::LLVMFuncOp>( 311 loc, kRunOnVulkan, 312 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 313 /*isVarArg=*/false)); 314 } 315 316 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 317 builder.create<LLVM::LLVMFuncOp>( 318 loc, kBindMemRef1DFloat, 319 LLVM::LLVMType::getFunctionTy(getVoidType(), 320 {getPointerType(), getInt32Type(), 321 getInt32Type(), 322 getMemRef1DFloat().getPointerTo()}, 323 /*isVarArg=*/false)); 324 } 325 326 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 327 builder.create<LLVM::LLVMFuncOp>( 328 loc, kBindMemRef2DFloat, 329 LLVM::LLVMType::getFunctionTy(getVoidType(), 330 {getPointerType(), getInt32Type(), 331 getInt32Type(), 332 getMemRef2DFloat().getPointerTo()}, 333 /*isVarArg=*/false)); 334 } 335 336 if (!module.lookupSymbol(kBindMemRef3DFloat)) { 337 builder.create<LLVM::LLVMFuncOp>( 338 loc, kBindMemRef3DFloat, 339 LLVM::LLVMType::getFunctionTy(getVoidType(), 340 {getPointerType(), getInt32Type(), 341 getInt32Type(), 342 getMemRef3DFloat().getPointerTo()}, 343 /*isVarArg=*/false)); 344 } 345 346 if (!module.lookupSymbol(kInitVulkan)) { 347 builder.create<LLVM::LLVMFuncOp>( 348 loc, kInitVulkan, 349 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 350 /*isVarArg=*/false)); 351 } 352 353 if (!module.lookupSymbol(kDeinitVulkan)) { 354 builder.create<LLVM::LLVMFuncOp>( 355 loc, kDeinitVulkan, 356 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 357 /*isVarArg=*/false)); 358 } 359 } 360 361 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 362 StringRef name, Location loc, OpBuilder &builder) { 363 SmallString<16> shaderName(name.begin(), name.end()); 364 // Append `\0` to follow C style string given that LLVM::createGlobalString() 365 // won't handle this directly for us. 366 shaderName.push_back('\0'); 367 368 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 369 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 370 shaderName, LLVM::Linkage::Internal, 371 getLLVMDialect()); 372 } 373 374 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 375 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 376 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 377 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 378 // Create call to `initVulkan`. 379 auto initVulkanCall = builder.create<LLVM::CallOp>( 380 loc, ArrayRef<Type>{getPointerType()}, 381 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 382 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 383 // need to pass that pointer to each Vulkan runtime call. 384 auto vulkanRuntime = initVulkanCall.getResult(0); 385 386 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 387 // that data to runtime call. 388 Value ptrToSPIRVBinary = LLVM::createGlobalString( 389 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 390 LLVM::Linkage::Internal, getLLVMDialect()); 391 392 // Create LLVM constant for the size of SPIR-V binary shader. 393 Value binarySize = builder.create<LLVM::ConstantOp>( 394 loc, getInt32Type(), 395 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 396 397 // Create call to `bindMemRef` for each memref operand. 398 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 399 400 // Create call to `setBinaryShader` runtime function with the given pointer to 401 // SPIR-V binary and binary size. 402 builder.create<LLVM::CallOp>( 403 loc, ArrayRef<Type>{getVoidType()}, 404 builder.getSymbolRefAttr(kSetBinaryShader), 405 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 406 // Create LLVM global with entry point name. 407 Value entryPointName = createEntryPointNameConstant( 408 spirvAttributes.second.getValue(), loc, builder); 409 // Create call to `setEntryPoint` runtime function with the given pointer to 410 // entry point name. 411 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 412 builder.getSymbolRefAttr(kSetEntryPoint), 413 ArrayRef<Value>{vulkanRuntime, entryPointName}); 414 415 // Create number of local workgroup for each dimension. 416 builder.create<LLVM::CallOp>( 417 loc, ArrayRef<Type>{getVoidType()}, 418 builder.getSymbolRefAttr(kSetNumWorkGroups), 419 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 420 cInterfaceVulkanLaunchCallOp.getOperand(1), 421 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 422 423 // Create call to `runOnVulkan` runtime function. 424 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 425 builder.getSymbolRefAttr(kRunOnVulkan), 426 ArrayRef<Value>{vulkanRuntime}); 427 428 // Create call to 'deinitVulkan' runtime function. 429 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 430 builder.getSymbolRefAttr(kDeinitVulkan), 431 ArrayRef<Value>{vulkanRuntime}); 432 433 // Declare runtime functions. 434 declareVulkanFunctions(loc); 435 436 cInterfaceVulkanLaunchCallOp.erase(); 437 } 438 439 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 440 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 441 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 442 } 443