1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Pass/Pass.h" 25 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 31 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 32 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 33 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat"; 34 static constexpr const char *kCInterfaceVulkanLaunch = 35 "_mlir_ciface_vulkanLaunch"; 36 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 37 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 38 static constexpr const char *kInitVulkan = "initVulkan"; 39 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 40 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 41 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 42 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 43 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 44 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 45 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 46 47 namespace { 48 49 /// A pass to convert vulkan launch call op into a sequence of Vulkan 50 /// runtime calls in the following order: 51 /// 52 /// * initVulkan -- initializes vulkan runtime 53 /// * bindMemRef -- binds memref 54 /// * setBinaryShader -- sets the binary shader data 55 /// * setEntryPoint -- sets the entry point name 56 /// * setNumWorkGroups -- sets the number of a local workgroups 57 /// * runOnVulkan -- runs vulkan runtime 58 /// * deinitVulkan -- deinitializes vulkan runtime 59 /// 60 class VulkanLaunchFuncToVulkanCallsPass 61 : public PassWrapper<VulkanLaunchFuncToVulkanCallsPass, 62 OperationPass<ModuleOp>> { 63 private: 64 /// Include the generated pass utilities. 65 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls 66 #include "mlir/Conversion/Passes.h.inc" 67 68 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 69 70 llvm::LLVMContext &getLLVMContext() { 71 return getLLVMDialect()->getLLVMContext(); 72 } 73 74 void initializeCachedTypes() { 75 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 76 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 77 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 78 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 79 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 80 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 81 llvmMemRef1DFloat = getMemRefType(1); 82 llvmMemRef2DFloat = getMemRefType(2); 83 llvmMemRef3DFloat = getMemRefType(3); 84 } 85 86 LLVM::LLVMType getMemRefType(uint32_t rank) { 87 // According to the MLIR doc memref argument is converted into a 88 // pointer-to-struct argument of type: 89 // template <typename Elem, size_t Rank> 90 // struct { 91 // Elem *allocated; 92 // Elem *aligned; 93 // int64_t offset; 94 // int64_t sizes[Rank]; // omitted when rank == 0 95 // int64_t strides[Rank]; // omitted when rank == 0 96 // }; 97 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 98 auto llvmArrayRankElementSizeType = 99 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 100 101 // Create a type 102 // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`. 103 return LLVM::LLVMType::getStructTy( 104 llvmDialect, 105 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 106 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 107 } 108 109 LLVM::LLVMType getFloatType() { return llvmFloatType; } 110 LLVM::LLVMType getVoidType() { return llvmVoidType; } 111 LLVM::LLVMType getPointerType() { return llvmPointerType; } 112 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 113 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 114 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 115 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 116 LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } 117 118 /// Creates a LLVM global for the given `name`. 119 Value createEntryPointNameConstant(StringRef name, Location loc, 120 OpBuilder &builder); 121 122 /// Declares all needed runtime functions. 123 void declareVulkanFunctions(Location loc); 124 125 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 126 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 127 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 128 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 129 } 130 131 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 132 /// op. 133 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 134 return (callOp.callee() && 135 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 136 callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands); 137 } 138 139 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 140 /// runtime calls. 141 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 142 143 /// Creates call to `bindMemRef` for each memref operand. 144 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 145 Value vulkanRuntime); 146 147 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 148 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 149 150 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 151 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 152 153 public: 154 void runOnOperation() override; 155 156 private: 157 LLVM::LLVMDialect *llvmDialect; 158 LLVM::LLVMType llvmFloatType; 159 LLVM::LLVMType llvmVoidType; 160 LLVM::LLVMType llvmPointerType; 161 LLVM::LLVMType llvmInt32Type; 162 LLVM::LLVMType llvmInt64Type; 163 LLVM::LLVMType llvmMemRef1DFloat; 164 LLVM::LLVMType llvmMemRef2DFloat; 165 LLVM::LLVMType llvmMemRef3DFloat; 166 167 // TODO: Use an associative array to support multiple vulkan launch calls. 168 std::pair<StringAttr, StringAttr> spirvAttributes; 169 }; 170 171 } // anonymous namespace 172 173 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 174 initializeCachedTypes(); 175 176 // Collect SPIR-V attributes such as `spirv_blob` and 177 // `spirv_entry_point_name`. 178 getOperation().walk([this](LLVM::CallOp op) { 179 if (isVulkanLaunchCallOp(op)) 180 collectSPIRVAttributes(op); 181 }); 182 183 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 184 getOperation().walk([this](LLVM::CallOp op) { 185 if (isCInterfaceVulkanLaunchCallOp(op)) 186 translateVulkanLaunchCall(op); 187 }); 188 } 189 190 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 191 LLVM::CallOp vulkanLaunchCallOp) { 192 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 193 // for the given vulkan launch call. 194 auto spirvBlobAttr = 195 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 196 if (!spirvBlobAttr) { 197 vulkanLaunchCallOp.emitError() 198 << "missing " << kSPIRVBlobAttrName << " attribute"; 199 return signalPassFailure(); 200 } 201 202 auto spirvEntryPointNameAttr = 203 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 204 if (!spirvEntryPointNameAttr) { 205 vulkanLaunchCallOp.emitError() 206 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 207 return signalPassFailure(); 208 } 209 210 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 211 } 212 213 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 214 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 215 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 216 gpu::LaunchOp::kNumConfigOperands) 217 return; 218 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 219 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 220 221 // Create LLVM constant for the descriptor set index. 222 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 223 // pass does. 224 Value descriptorSet = builder.create<LLVM::ConstantOp>( 225 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 226 227 for (auto en : 228 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 229 gpu::LaunchOp::kNumConfigOperands))) { 230 // Create LLVM constant for the descriptor binding index. 231 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 232 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 233 234 auto ptrToMemRefDescriptor = en.value(); 235 uint32_t rank = 0; 236 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 237 cInterfaceVulkanLaunchCallOp.emitError() 238 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 239 return signalPassFailure(); 240 } 241 242 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 243 // Create call to `bindMemRef`. 244 builder.create<LLVM::CallOp>( 245 loc, ArrayRef<Type>{getVoidType()}, 246 builder.getSymbolRefAttr( 247 StringRef(symbolName.data(), symbolName.size())), 248 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 249 ptrToMemRefDescriptor}); 250 } 251 } 252 253 LogicalResult 254 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 255 uint32_t &rank) { 256 auto llvmPtrDescriptorTy = 257 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 258 if (!llvmPtrDescriptorTy) 259 return failure(); 260 261 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 262 // template <typename Elem, size_t Rank> 263 // struct { 264 // Elem *allocated; 265 // Elem *aligned; 266 // int64_t offset; 267 // int64_t sizes[Rank]; // omitted when rank == 0 268 // int64_t strides[Rank]; // omitted when rank == 0 269 // }; 270 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 271 return failure(); 272 if (llvmDescriptorTy.getStructNumElements() == 3) { 273 rank = 0; 274 return success(); 275 } 276 277 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 278 return success(); 279 } 280 281 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 282 ModuleOp module = getOperation(); 283 OpBuilder builder(module.getBody()->getTerminator()); 284 285 if (!module.lookupSymbol(kSetEntryPoint)) { 286 builder.create<LLVM::LLVMFuncOp>( 287 loc, kSetEntryPoint, 288 LLVM::LLVMType::getFunctionTy(getVoidType(), 289 {getPointerType(), getPointerType()}, 290 /*isVarArg=*/false)); 291 } 292 293 if (!module.lookupSymbol(kSetNumWorkGroups)) { 294 builder.create<LLVM::LLVMFuncOp>( 295 loc, kSetNumWorkGroups, 296 LLVM::LLVMType::getFunctionTy( 297 getVoidType(), 298 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 299 /*isVarArg=*/false)); 300 } 301 302 if (!module.lookupSymbol(kSetBinaryShader)) { 303 builder.create<LLVM::LLVMFuncOp>( 304 loc, kSetBinaryShader, 305 LLVM::LLVMType::getFunctionTy( 306 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 307 /*isVarArg=*/false)); 308 } 309 310 if (!module.lookupSymbol(kRunOnVulkan)) { 311 builder.create<LLVM::LLVMFuncOp>( 312 loc, kRunOnVulkan, 313 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 314 /*isVarArg=*/false)); 315 } 316 317 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 318 builder.create<LLVM::LLVMFuncOp>( 319 loc, kBindMemRef1DFloat, 320 LLVM::LLVMType::getFunctionTy(getVoidType(), 321 {getPointerType(), getInt32Type(), 322 getInt32Type(), 323 getMemRef1DFloat().getPointerTo()}, 324 /*isVarArg=*/false)); 325 } 326 327 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 328 builder.create<LLVM::LLVMFuncOp>( 329 loc, kBindMemRef2DFloat, 330 LLVM::LLVMType::getFunctionTy(getVoidType(), 331 {getPointerType(), getInt32Type(), 332 getInt32Type(), 333 getMemRef2DFloat().getPointerTo()}, 334 /*isVarArg=*/false)); 335 } 336 337 if (!module.lookupSymbol(kBindMemRef3DFloat)) { 338 builder.create<LLVM::LLVMFuncOp>( 339 loc, kBindMemRef3DFloat, 340 LLVM::LLVMType::getFunctionTy(getVoidType(), 341 {getPointerType(), getInt32Type(), 342 getInt32Type(), 343 getMemRef3DFloat().getPointerTo()}, 344 /*isVarArg=*/false)); 345 } 346 347 if (!module.lookupSymbol(kInitVulkan)) { 348 builder.create<LLVM::LLVMFuncOp>( 349 loc, kInitVulkan, 350 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 351 /*isVarArg=*/false)); 352 } 353 354 if (!module.lookupSymbol(kDeinitVulkan)) { 355 builder.create<LLVM::LLVMFuncOp>( 356 loc, kDeinitVulkan, 357 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 358 /*isVarArg=*/false)); 359 } 360 } 361 362 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 363 StringRef name, Location loc, OpBuilder &builder) { 364 SmallString<16> shaderName(name.begin(), name.end()); 365 // Append `\0` to follow C style string given that LLVM::createGlobalString() 366 // won't handle this directly for us. 367 shaderName.push_back('\0'); 368 369 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 370 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 371 shaderName, LLVM::Linkage::Internal, 372 getLLVMDialect()); 373 } 374 375 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 376 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 377 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 378 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 379 // Create call to `initVulkan`. 380 auto initVulkanCall = builder.create<LLVM::CallOp>( 381 loc, ArrayRef<Type>{getPointerType()}, 382 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 383 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 384 // need to pass that pointer to each Vulkan runtime call. 385 auto vulkanRuntime = initVulkanCall.getResult(0); 386 387 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 388 // that data to runtime call. 389 Value ptrToSPIRVBinary = LLVM::createGlobalString( 390 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 391 LLVM::Linkage::Internal, getLLVMDialect()); 392 393 // Create LLVM constant for the size of SPIR-V binary shader. 394 Value binarySize = builder.create<LLVM::ConstantOp>( 395 loc, getInt32Type(), 396 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 397 398 // Create call to `bindMemRef` for each memref operand. 399 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 400 401 // Create call to `setBinaryShader` runtime function with the given pointer to 402 // SPIR-V binary and binary size. 403 builder.create<LLVM::CallOp>( 404 loc, ArrayRef<Type>{getVoidType()}, 405 builder.getSymbolRefAttr(kSetBinaryShader), 406 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 407 // Create LLVM global with entry point name. 408 Value entryPointName = createEntryPointNameConstant( 409 spirvAttributes.second.getValue(), loc, builder); 410 // Create call to `setEntryPoint` runtime function with the given pointer to 411 // entry point name. 412 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 413 builder.getSymbolRefAttr(kSetEntryPoint), 414 ArrayRef<Value>{vulkanRuntime, entryPointName}); 415 416 // Create number of local workgroup for each dimension. 417 builder.create<LLVM::CallOp>( 418 loc, ArrayRef<Type>{getVoidType()}, 419 builder.getSymbolRefAttr(kSetNumWorkGroups), 420 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 421 cInterfaceVulkanLaunchCallOp.getOperand(1), 422 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 423 424 // Create call to `runOnVulkan` runtime function. 425 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 426 builder.getSymbolRefAttr(kRunOnVulkan), 427 ArrayRef<Value>{vulkanRuntime}); 428 429 // Create call to 'deinitVulkan' runtime function. 430 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 431 builder.getSymbolRefAttr(kDeinitVulkan), 432 ArrayRef<Value>{vulkanRuntime}); 433 434 // Declare runtime functions. 435 declareVulkanFunctions(loc); 436 437 cInterfaceVulkanLaunchCallOp.erase(); 438 } 439 440 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 441 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 442 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 443 } 444