1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 using namespace mlir; 29 30 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; 31 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; 32 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat"; 33 static constexpr const char *kCInterfaceVulkanLaunch = 34 "_mlir_ciface_vulkanLaunch"; 35 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 36 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 37 static constexpr const char *kInitVulkan = "initVulkan"; 38 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 39 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 40 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 41 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 42 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 43 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 44 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 45 46 namespace { 47 48 /// A pass to convert vulkan launch call op into a sequence of Vulkan 49 /// runtime calls in the following order: 50 /// 51 /// * initVulkan -- initializes vulkan runtime 52 /// * bindMemRef -- binds memref 53 /// * setBinaryShader -- sets the binary shader data 54 /// * setEntryPoint -- sets the entry point name 55 /// * setNumWorkGroups -- sets the number of a local workgroups 56 /// * runOnVulkan -- runs vulkan runtime 57 /// * deinitVulkan -- deinitializes vulkan runtime 58 /// 59 class VulkanLaunchFuncToVulkanCallsPass 60 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 61 VulkanLaunchFuncToVulkanCallsPass> { 62 private: 63 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 64 65 llvm::LLVMContext &getLLVMContext() { 66 return getLLVMDialect()->getLLVMContext(); 67 } 68 69 void initializeCachedTypes() { 70 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 71 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 72 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 73 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 74 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 75 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 76 llvmMemRef1DFloat = getMemRefType(1); 77 llvmMemRef2DFloat = getMemRefType(2); 78 llvmMemRef3DFloat = getMemRefType(3); 79 } 80 81 LLVM::LLVMType getMemRefType(uint32_t rank) { 82 // According to the MLIR doc memref argument is converted into a 83 // pointer-to-struct argument of type: 84 // template <typename Elem, size_t Rank> 85 // struct { 86 // Elem *allocated; 87 // Elem *aligned; 88 // int64_t offset; 89 // int64_t sizes[Rank]; // omitted when rank == 0 90 // int64_t strides[Rank]; // omitted when rank == 0 91 // }; 92 auto llvmPtrToFloatType = getFloatType().getPointerTo(); 93 auto llvmArrayRankElementSizeType = 94 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 95 96 // Create a type 97 // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`. 98 return LLVM::LLVMType::getStructTy( 99 llvmDialect, 100 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), 101 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 102 } 103 104 LLVM::LLVMType getFloatType() { return llvmFloatType; } 105 LLVM::LLVMType getVoidType() { return llvmVoidType; } 106 LLVM::LLVMType getPointerType() { return llvmPointerType; } 107 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 108 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 109 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 110 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 111 LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } 112 113 /// Creates a LLVM global for the given `name`. 114 Value createEntryPointNameConstant(StringRef name, Location loc, 115 OpBuilder &builder); 116 117 /// Declares all needed runtime functions. 118 void declareVulkanFunctions(Location loc); 119 120 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 121 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 122 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 123 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 124 } 125 126 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 127 /// op. 128 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 129 return (callOp.callee() && 130 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 131 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 132 } 133 134 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 135 /// runtime calls. 136 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 137 138 /// Creates call to `bindMemRef` for each memref operand. 139 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 140 Value vulkanRuntime); 141 142 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 143 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 144 145 /// Deduces a rank from the given 'ptrToMemRefDescriptor`. 146 LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); 147 148 public: 149 void runOnOperation() override; 150 151 private: 152 LLVM::LLVMDialect *llvmDialect; 153 LLVM::LLVMType llvmFloatType; 154 LLVM::LLVMType llvmVoidType; 155 LLVM::LLVMType llvmPointerType; 156 LLVM::LLVMType llvmInt32Type; 157 LLVM::LLVMType llvmInt64Type; 158 LLVM::LLVMType llvmMemRef1DFloat; 159 LLVM::LLVMType llvmMemRef2DFloat; 160 LLVM::LLVMType llvmMemRef3DFloat; 161 162 // TODO: Use an associative array to support multiple vulkan launch calls. 163 std::pair<StringAttr, StringAttr> spirvAttributes; 164 /// The number of vulkan launch configuration operands, placed at the leading 165 /// positions of the operand list. 166 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 167 }; 168 169 } // anonymous namespace 170 171 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 172 initializeCachedTypes(); 173 174 // Collect SPIR-V attributes such as `spirv_blob` and 175 // `spirv_entry_point_name`. 176 getOperation().walk([this](LLVM::CallOp op) { 177 if (isVulkanLaunchCallOp(op)) 178 collectSPIRVAttributes(op); 179 }); 180 181 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 182 getOperation().walk([this](LLVM::CallOp op) { 183 if (isCInterfaceVulkanLaunchCallOp(op)) 184 translateVulkanLaunchCall(op); 185 }); 186 } 187 188 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 189 LLVM::CallOp vulkanLaunchCallOp) { 190 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 191 // for the given vulkan launch call. 192 auto spirvBlobAttr = 193 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 194 if (!spirvBlobAttr) { 195 vulkanLaunchCallOp.emitError() 196 << "missing " << kSPIRVBlobAttrName << " attribute"; 197 return signalPassFailure(); 198 } 199 200 auto spirvEntryPointNameAttr = 201 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 202 if (!spirvEntryPointNameAttr) { 203 vulkanLaunchCallOp.emitError() 204 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 205 return signalPassFailure(); 206 } 207 208 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 209 } 210 211 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 212 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 213 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 214 kVulkanLaunchNumConfigOperands) 215 return; 216 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 217 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 218 219 // Create LLVM constant for the descriptor set index. 220 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 221 // pass does. 222 Value descriptorSet = builder.create<LLVM::ConstantOp>( 223 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 224 225 for (auto en : 226 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 227 kVulkanLaunchNumConfigOperands))) { 228 // Create LLVM constant for the descriptor binding index. 229 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 230 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 231 232 auto ptrToMemRefDescriptor = en.value(); 233 uint32_t rank = 0; 234 if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { 235 cInterfaceVulkanLaunchCallOp.emitError() 236 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 237 return signalPassFailure(); 238 } 239 240 auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); 241 // Create call to `bindMemRef`. 242 builder.create<LLVM::CallOp>( 243 loc, ArrayRef<Type>{getVoidType()}, 244 builder.getSymbolRefAttr( 245 StringRef(symbolName.data(), symbolName.size())), 246 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 247 ptrToMemRefDescriptor}); 248 } 249 } 250 251 LogicalResult 252 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, 253 uint32_t &rank) { 254 auto llvmPtrDescriptorTy = 255 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 256 if (!llvmPtrDescriptorTy) 257 return failure(); 258 259 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 260 // template <typename Elem, size_t Rank> 261 // struct { 262 // Elem *allocated; 263 // Elem *aligned; 264 // int64_t offset; 265 // int64_t sizes[Rank]; // omitted when rank == 0 266 // int64_t strides[Rank]; // omitted when rank == 0 267 // }; 268 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 269 return failure(); 270 if (llvmDescriptorTy.getStructNumElements() == 3) { 271 rank = 0; 272 return success(); 273 } 274 275 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 276 return success(); 277 } 278 279 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 280 ModuleOp module = getOperation(); 281 OpBuilder builder(module.getBody()->getTerminator()); 282 283 if (!module.lookupSymbol(kSetEntryPoint)) { 284 builder.create<LLVM::LLVMFuncOp>( 285 loc, kSetEntryPoint, 286 LLVM::LLVMType::getFunctionTy(getVoidType(), 287 {getPointerType(), getPointerType()}, 288 /*isVarArg=*/false)); 289 } 290 291 if (!module.lookupSymbol(kSetNumWorkGroups)) { 292 builder.create<LLVM::LLVMFuncOp>( 293 loc, kSetNumWorkGroups, 294 LLVM::LLVMType::getFunctionTy( 295 getVoidType(), 296 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 297 /*isVarArg=*/false)); 298 } 299 300 if (!module.lookupSymbol(kSetBinaryShader)) { 301 builder.create<LLVM::LLVMFuncOp>( 302 loc, kSetBinaryShader, 303 LLVM::LLVMType::getFunctionTy( 304 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 305 /*isVarArg=*/false)); 306 } 307 308 if (!module.lookupSymbol(kRunOnVulkan)) { 309 builder.create<LLVM::LLVMFuncOp>( 310 loc, kRunOnVulkan, 311 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 312 /*isVarArg=*/false)); 313 } 314 315 if (!module.lookupSymbol(kBindMemRef1DFloat)) { 316 builder.create<LLVM::LLVMFuncOp>( 317 loc, kBindMemRef1DFloat, 318 LLVM::LLVMType::getFunctionTy(getVoidType(), 319 {getPointerType(), getInt32Type(), 320 getInt32Type(), 321 getMemRef1DFloat().getPointerTo()}, 322 /*isVarArg=*/false)); 323 } 324 325 if (!module.lookupSymbol(kBindMemRef2DFloat)) { 326 builder.create<LLVM::LLVMFuncOp>( 327 loc, kBindMemRef2DFloat, 328 LLVM::LLVMType::getFunctionTy(getVoidType(), 329 {getPointerType(), getInt32Type(), 330 getInt32Type(), 331 getMemRef2DFloat().getPointerTo()}, 332 /*isVarArg=*/false)); 333 } 334 335 if (!module.lookupSymbol(kBindMemRef3DFloat)) { 336 builder.create<LLVM::LLVMFuncOp>( 337 loc, kBindMemRef3DFloat, 338 LLVM::LLVMType::getFunctionTy(getVoidType(), 339 {getPointerType(), getInt32Type(), 340 getInt32Type(), 341 getMemRef3DFloat().getPointerTo()}, 342 /*isVarArg=*/false)); 343 } 344 345 if (!module.lookupSymbol(kInitVulkan)) { 346 builder.create<LLVM::LLVMFuncOp>( 347 loc, kInitVulkan, 348 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 349 /*isVarArg=*/false)); 350 } 351 352 if (!module.lookupSymbol(kDeinitVulkan)) { 353 builder.create<LLVM::LLVMFuncOp>( 354 loc, kDeinitVulkan, 355 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 356 /*isVarArg=*/false)); 357 } 358 } 359 360 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 361 StringRef name, Location loc, OpBuilder &builder) { 362 SmallString<16> shaderName(name.begin(), name.end()); 363 // Append `\0` to follow C style string given that LLVM::createGlobalString() 364 // won't handle this directly for us. 365 shaderName.push_back('\0'); 366 367 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 368 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 369 shaderName, LLVM::Linkage::Internal, 370 getLLVMDialect()); 371 } 372 373 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 374 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 375 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 376 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 377 // Create call to `initVulkan`. 378 auto initVulkanCall = builder.create<LLVM::CallOp>( 379 loc, ArrayRef<Type>{getPointerType()}, 380 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 381 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 382 // need to pass that pointer to each Vulkan runtime call. 383 auto vulkanRuntime = initVulkanCall.getResult(0); 384 385 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 386 // that data to runtime call. 387 Value ptrToSPIRVBinary = LLVM::createGlobalString( 388 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 389 LLVM::Linkage::Internal, getLLVMDialect()); 390 391 // Create LLVM constant for the size of SPIR-V binary shader. 392 Value binarySize = builder.create<LLVM::ConstantOp>( 393 loc, getInt32Type(), 394 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 395 396 // Create call to `bindMemRef` for each memref operand. 397 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 398 399 // Create call to `setBinaryShader` runtime function with the given pointer to 400 // SPIR-V binary and binary size. 401 builder.create<LLVM::CallOp>( 402 loc, ArrayRef<Type>{getVoidType()}, 403 builder.getSymbolRefAttr(kSetBinaryShader), 404 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 405 // Create LLVM global with entry point name. 406 Value entryPointName = createEntryPointNameConstant( 407 spirvAttributes.second.getValue(), loc, builder); 408 // Create call to `setEntryPoint` runtime function with the given pointer to 409 // entry point name. 410 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 411 builder.getSymbolRefAttr(kSetEntryPoint), 412 ArrayRef<Value>{vulkanRuntime, entryPointName}); 413 414 // Create number of local workgroup for each dimension. 415 builder.create<LLVM::CallOp>( 416 loc, ArrayRef<Type>{getVoidType()}, 417 builder.getSymbolRefAttr(kSetNumWorkGroups), 418 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 419 cInterfaceVulkanLaunchCallOp.getOperand(1), 420 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 421 422 // Create call to `runOnVulkan` runtime function. 423 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 424 builder.getSymbolRefAttr(kRunOnVulkan), 425 ArrayRef<Value>{vulkanRuntime}); 426 427 // Create call to 'deinitVulkan' runtime function. 428 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 429 builder.getSymbolRefAttr(kDeinitVulkan), 430 ArrayRef<Value>{vulkanRuntime}); 431 432 // Declare runtime functions. 433 declareVulkanFunctions(loc); 434 435 cInterfaceVulkanLaunchCallOp.erase(); 436 } 437 438 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 439 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 440 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 441 } 442