1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Pass/Pass.h" 25 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 32 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 33 static constexpr const char *kCInterfaceVulkanLaunch = 34 "_mlir_ciface_vulkanLaunch"; 35 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 36 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 37 static constexpr const char *kInitVulkan = "initVulkan"; 38 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 39 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 40 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 41 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 42 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 43 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 44 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 45 46 namespace { 47 48 /// A pass to convert vulkan launch call op into a sequence of Vulkan 49 /// runtime calls in the following order: 50 /// 51 /// * initVulkan -- initializes vulkan runtime 52 /// * bindMemRef -- binds memref 53 /// * setBinaryShader -- sets the binary shader data 54 /// * setEntryPoint -- sets the entry point name 55 /// * setNumWorkGroups -- sets the number of a local workgroups 56 /// * runOnVulkan -- runs vulkan runtime 57 /// * deinitVulkan -- deinitializes vulkan runtime 58 /// 59 class VulkanLaunchFuncToVulkanCallsPass 60 : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { 61 private: 62 /// Include the generated pass utilities. 63 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls 64 #include "mlir/Conversion/Passes.h.inc" 65 66 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 67 68 llvm::LLVMContext &getLLVMContext() { 69 return getLLVMDialect()->getLLVMContext(); 70 } 71 72 void initializeCachedTypes() { 73 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 74 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 75 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 76 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 77 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 78 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 79 initializeMemRefTypes(); 80 } 81 82 void initializeMemRefTypes() { 83 // According to the MLIR doc memref argument is converted into a 84 // pointer-to-struct argument of type: 85 // template <typename Elem, size_t Rank> 86 // struct { 87 // Elem *allocated; 88 // Elem *aligned; 89 // int64_t offset; 90 // int64_t sizes[Rank]; // omitted when rank == 0 91 // int64_t strides[Rank]; // omitted when rank == 0 92 // }; 93 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 94 auto llvmArrayOneElementSizeType = 95 LLVM::LLVMType::getArrayTy(getInt64Type(), 1); 96 auto llvmArrayTwoElementSizeType = 97 LLVM::LLVMType::getArrayTy(getInt64Type(), 2); 98 99 // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. 100 llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( 101 llvmDialect, 102 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 103 llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); 104 105 // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`. 106 llvmMemRef2DFloat = LLVM::LLVMType::getStructTy( 107 llvmDialect, 108 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 109 llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType}); 110 } 111 112 LLVM::LLVMType getFloatType() { return llvmFloatType; } 113 LLVM::LLVMType getVoidType() { return llvmVoidType; } 114 LLVM::LLVMType getPointerType() { return llvmPointerType; } 115 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 116 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 117 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 118 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 119 120 /// Creates a LLVM global for the given `name`. 121 Value createEntryPointNameConstant(StringRef name, Location loc, 122 OpBuilder &builder); 123 124 /// Declares all needed runtime functions. 125 void declareVulkanFunctions(Location loc); 126 127 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 128 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 129 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 130 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 131 } 132 133 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 134 /// op. 135 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 136 return (callOp.callee() && 137 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 138 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 139 } 140 141 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 142 /// runtime calls. 143 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 144 145 /// Creates call to `bindMemRef` for each memref operand. 146 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 147 Value vulkanRuntime); 148 149 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 150 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 151 152 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 153 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 154 155 public: 156 void runOnModule() override; 157 158 private: 159 LLVM::LLVMDialect *llvmDialect; 160 LLVM::LLVMType llvmFloatType; 161 LLVM::LLVMType llvmVoidType; 162 LLVM::LLVMType llvmPointerType; 163 LLVM::LLVMType llvmInt32Type; 164 LLVM::LLVMType llvmInt64Type; 165 LLVM::LLVMType llvmMemRef1DFloat; 166 LLVM::LLVMType llvmMemRef2DFloat; 167 168 // TODO: Use an associative array to support multiple vulkan launch calls. 169 std::pair<StringAttr, StringAttr> spirvAttributes; 170 }; 171 172 } // anonymous namespace 173 174 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { 175 initializeCachedTypes(); 176 177 // Collect SPIR-V attributes such as `spirv_blob` and 178 // `spirv_entry_point_name`. 179 getModule().walk([this](LLVM::CallOp op) { 180 if (isVulkanLaunchCallOp(op)) 181 collectSPIRVAttributes(op); 182 }); 183 184 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 185 getModule().walk([this](LLVM::CallOp op) { 186 if (isCInterfaceVulkanLaunchCallOp(op)) 187 translateVulkanLaunchCall(op); 188 }); 189 } 190 191 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 192 LLVM::CallOp vulkanLaunchCallOp) { 193 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 194 // for the given vulkan launch call. 195 auto spirvBlobAttr = 196 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 197 if (!spirvBlobAttr) { 198 vulkanLaunchCallOp.emitError() 199 << "missing " << kSPIRVBlobAttrName << " attribute"; 200 return signalPassFailure(); 201 } 202 203 auto spirvEntryPointNameAttr = 204 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 205 if (!spirvEntryPointNameAttr) { 206 vulkanLaunchCallOp.emitError() 207 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 208 return signalPassFailure(); 209 } 210 211 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 212 } 213 214 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 215 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 216 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 217 gpu::LaunchOp::kNumConfigOperands) 218 return; 219 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 220 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 221 222 // Create LLVM constant for the descriptor set index. 223 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 224 // pass does. 225 Value descriptorSet = builder.create<LLVM::ConstantOp>( 226 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 227 228 for (auto en : 229 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 230 gpu::LaunchOp::kNumConfigOperands))) { 231 // Create LLVM constant for the descriptor binding index. 232 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 233 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 234 235 auto ptrToMemRefDescriptor = en.value(); 236 uint32_t rank = 0; 237 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 238 cInterfaceVulkanLaunchCallOp.emitError() 239 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 240 return signalPassFailure(); 241 } 242 243 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 244 // Create call to `bindMemRef`. 245 builder.create<LLVM::CallOp>( 246 loc, ArrayRef<Type>{getVoidType()}, 247 builder.getSymbolRefAttr( 248 StringRef(symbolName.data(), symbolName.size())), 249 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 250 ptrToMemRefDescriptor}); 251 } 252 } 253 254 LogicalResult 255 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 256 uint32_t &rank) { 257 auto llvmPtrDescriptorTy = 258 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 259 if (!llvmPtrDescriptorTy) 260 return failure(); 261 262 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 263 // template <typename Elem, size_t Rank> 264 // struct { 265 // Elem *allocated; 266 // Elem *aligned; 267 // int64_t offset; 268 // int64_t sizes[Rank]; // omitted when rank == 0 269 // int64_t strides[Rank]; // omitted when rank == 0 270 // }; 271 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 272 return failure(); 273 if (llvmDescriptorTy.getStructNumElements() == 3) { 274 rank = 0; 275 return success(); 276 } 277 278 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 279 return success(); 280 } 281 282 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 283 ModuleOp module = getModule(); 284 OpBuilder builder(module.getBody()->getTerminator()); 285 286 if (!module.lookupSymbol(kSetEntryPoint)) { 287 builder.create<LLVM::LLVMFuncOp>( 288 loc, kSetEntryPoint, 289 LLVM::LLVMType::getFunctionTy(getVoidType(), 290 {getPointerType(), getPointerType()}, 291 /*isVarArg=*/false)); 292 } 293 294 if (!module.lookupSymbol(kSetNumWorkGroups)) { 295 builder.create<LLVM::LLVMFuncOp>( 296 loc, kSetNumWorkGroups, 297 LLVM::LLVMType::getFunctionTy( 298 getVoidType(), 299 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 300 /*isVarArg=*/false)); 301 } 302 303 if (!module.lookupSymbol(kSetBinaryShader)) { 304 builder.create<LLVM::LLVMFuncOp>( 305 loc, kSetBinaryShader, 306 LLVM::LLVMType::getFunctionTy( 307 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 308 /*isVarArg=*/false)); 309 } 310 311 if (!module.lookupSymbol(kRunOnVulkan)) { 312 builder.create<LLVM::LLVMFuncOp>( 313 loc, kRunOnVulkan, 314 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 315 /*isVarArg=*/false)); 316 } 317 318 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 319 builder.create<LLVM::LLVMFuncOp>( 320 loc, kBindMemRef1DFloat, 321 LLVM::LLVMType::getFunctionTy(getVoidType(), 322 {getPointerType(), getInt32Type(), 323 getInt32Type(), 324 getMemRef1DFloat().getPointerTo()}, 325 /*isVarArg=*/false)); 326 } 327 328 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 329 builder.create<LLVM::LLVMFuncOp>( 330 loc, kBindMemRef2DFloat, 331 LLVM::LLVMType::getFunctionTy(getVoidType(), 332 {getPointerType(), getInt32Type(), 333 getInt32Type(), 334 getMemRef2DFloat().getPointerTo()}, 335 /*isVarArg=*/false)); 336 } 337 338 if (!module.lookupSymbol(kInitVulkan)) { 339 builder.create<LLVM::LLVMFuncOp>( 340 loc, kInitVulkan, 341 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 342 /*isVarArg=*/false)); 343 } 344 345 if (!module.lookupSymbol(kDeinitVulkan)) { 346 builder.create<LLVM::LLVMFuncOp>( 347 loc, kDeinitVulkan, 348 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 349 /*isVarArg=*/false)); 350 } 351 } 352 353 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 354 StringRef name, Location loc, OpBuilder &builder) { 355 SmallString<16> shaderName(name.begin(), name.end()); 356 // Append `\0` to follow C style string given that LLVM::createGlobalString() 357 // won't handle this directly for us. 358 shaderName.push_back('\0'); 359 360 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 361 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 362 shaderName, LLVM::Linkage::Internal, 363 getLLVMDialect()); 364 } 365 366 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 367 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 368 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 369 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 370 // Create call to `initVulkan`. 371 auto initVulkanCall = builder.create<LLVM::CallOp>( 372 loc, ArrayRef<Type>{getPointerType()}, 373 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 374 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 375 // need to pass that pointer to each Vulkan runtime call. 376 auto vulkanRuntime = initVulkanCall.getResult(0); 377 378 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 379 // that data to runtime call. 380 Value ptrToSPIRVBinary = LLVM::createGlobalString( 381 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 382 LLVM::Linkage::Internal, getLLVMDialect()); 383 384 // Create LLVM constant for the size of SPIR-V binary shader. 385 Value binarySize = builder.create<LLVM::ConstantOp>( 386 loc, getInt32Type(), 387 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 388 389 // Create call to `bindMemRef` for each memref operand. 390 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 391 392 // Create call to `setBinaryShader` runtime function with the given pointer to 393 // SPIR-V binary and binary size. 394 builder.create<LLVM::CallOp>( 395 loc, ArrayRef<Type>{getVoidType()}, 396 builder.getSymbolRefAttr(kSetBinaryShader), 397 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 398 // Create LLVM global with entry point name. 399 Value entryPointName = createEntryPointNameConstant( 400 spirvAttributes.second.getValue(), loc, builder); 401 // Create call to `setEntryPoint` runtime function with the given pointer to 402 // entry point name. 403 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 404 builder.getSymbolRefAttr(kSetEntryPoint), 405 ArrayRef<Value>{vulkanRuntime, entryPointName}); 406 407 // Create number of local workgroup for each dimension. 408 builder.create<LLVM::CallOp>( 409 loc, ArrayRef<Type>{getVoidType()}, 410 builder.getSymbolRefAttr(kSetNumWorkGroups), 411 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 412 cInterfaceVulkanLaunchCallOp.getOperand(1), 413 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 414 415 // Create call to `runOnVulkan` runtime function. 416 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 417 builder.getSymbolRefAttr(kRunOnVulkan), 418 ArrayRef<Value>{vulkanRuntime}); 419 420 // Create call to 'deinitVulkan' runtime function. 421 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 422 builder.getSymbolRefAttr(kDeinitVulkan), 423 ArrayRef<Value>{vulkanRuntime}); 424 425 // Declare runtime functions. 426 declareVulkanFunctions(loc); 427 428 cInterfaceVulkanLaunchCallOp.erase(); 429 } 430 431 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 432 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 433 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 434 } 435