//===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to a wrapper library which manages the
// Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

// Symbol names of the external runtime functions this pass declares and calls.
// They are referenced purely by name, so the spellings must match the
// functions provided by the runtime wrapper library.
// `bindMemRef<rank>D<Float|Int>` entry points, one per supported memref shape.
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
// Callee name of the c-interface launch call that gets replaced by the
// sequence of runtime calls.
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint"; 43 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 44 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 45 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 46 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 47 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 48 49 namespace { 50 51 /// A pass to convert vulkan launch call op into a sequence of Vulkan 52 /// runtime calls in the following order: 53 /// 54 /// * initVulkan -- initializes vulkan runtime 55 /// * bindMemRef -- binds memref 56 /// * setBinaryShader -- sets the binary shader data 57 /// * setEntryPoint -- sets the entry point name 58 /// * setNumWorkGroups -- sets the number of a local workgroups 59 /// * runOnVulkan -- runs vulkan runtime 60 /// * deinitVulkan -- deinitializes vulkan runtime 61 /// 62 class VulkanLaunchFuncToVulkanCallsPass 63 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 64 VulkanLaunchFuncToVulkanCallsPass> { 65 private: 66 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 67 68 llvm::LLVMContext &getLLVMContext() { 69 return getLLVMDialect()->getLLVMContext(); 70 } 71 72 void initializeCachedTypes() { 73 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 74 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 75 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 76 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 77 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 78 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 79 llvmMemRef1DFloat = getMemRefType(1, llvmFloatType); 80 llvmMemRef2DFloat = getMemRefType(2, llvmFloatType); 81 llvmMemRef3DFloat = getMemRefType(3, llvmFloatType); 82 llvmMemRef1DInt = getMemRefType(1, llvmInt32Type); 83 llvmMemRef2DInt = getMemRefType(2, llvmInt32Type); 84 llvmMemRef3DInt = getMemRefType(3, llvmInt32Type); 85 } 86 87 LLVM::LLVMType 
getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { 88 // According to the MLIR doc memref argument is converted into a 89 // pointer-to-struct argument of type: 90 // template <typename Elem, size_t Rank> 91 // struct { 92 // Elem *allocated; 93 // Elem *aligned; 94 // int64_t offset; 95 // int64_t sizes[Rank]; // omitted when rank == 0 96 // int64_t strides[Rank]; // omitted when rank == 0 97 // }; 98 auto llvmPtrToElementType = elemenType.getPointerTo(); 99 auto llvmArrayRankElementSizeType = 100 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 101 102 // Create a type 103 // `!llvm<"{ `element-type`*, `element-type`*, i64, 104 // [`rank` x i64], [`rank` x i64]}">`. 105 return LLVM::LLVMType::getStructTy( 106 llvmDialect, 107 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 108 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 109 } 110 111 LLVM::LLVMType getFloatType() { return llvmFloatType; } 112 LLVM::LLVMType getVoidType() { return llvmVoidType; } 113 LLVM::LLVMType getPointerType() { return llvmPointerType; } 114 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 115 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 116 LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } 117 LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } 118 LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } 119 LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; } 120 LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; } 121 LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; } 122 123 /// Creates a LLVM global for the given `name`. 124 Value createEntryPointNameConstant(StringRef name, Location loc, 125 OpBuilder &builder); 126 127 /// Declares all needed runtime functions. 128 void declareVulkanFunctions(Location loc); 129 130 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 
131 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 132 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 133 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 134 } 135 136 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 137 /// op. 138 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 139 return (callOp.callee() && 140 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 141 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 142 } 143 144 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 145 /// runtime calls. 146 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 147 148 /// Creates call to `bindMemRef` for each memref operand. 149 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 150 Value vulkanRuntime); 151 152 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 153 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 154 155 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 156 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 157 uint32_t &rank, LLVM::LLVMType &type); 158 159 /// Returns a string representation from the given `type`. 
160 StringRef stringifyType(LLVM::LLVMType type) { 161 if (type.isFloatTy()) 162 return "Float"; 163 if (type.isIntegerTy()) 164 return "Int"; 165 166 llvm_unreachable("unsupported type"); 167 } 168 169 public: 170 void runOnOperation() override; 171 172 private: 173 LLVM::LLVMDialect *llvmDialect; 174 LLVM::LLVMType llvmFloatType; 175 LLVM::LLVMType llvmVoidType; 176 LLVM::LLVMType llvmPointerType; 177 LLVM::LLVMType llvmInt32Type; 178 LLVM::LLVMType llvmInt64Type; 179 LLVM::LLVMType llvmMemRef1DFloat; 180 LLVM::LLVMType llvmMemRef2DFloat; 181 LLVM::LLVMType llvmMemRef3DFloat; 182 LLVM::LLVMType llvmMemRef1DInt; 183 LLVM::LLVMType llvmMemRef2DInt; 184 LLVM::LLVMType llvmMemRef3DInt; 185 186 // TODO: Use an associative array to support multiple vulkan launch calls. 187 std::pair<StringAttr, StringAttr> spirvAttributes; 188 /// The number of vulkan launch configuration operands, placed at the leading 189 /// positions of the operand list. 190 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 191 }; 192 193 } // anonymous namespace 194 195 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 196 initializeCachedTypes(); 197 198 // Collect SPIR-V attributes such as `spirv_blob` and 199 // `spirv_entry_point_name`. 200 getOperation().walk([this](LLVM::CallOp op) { 201 if (isVulkanLaunchCallOp(op)) 202 collectSPIRVAttributes(op); 203 }); 204 205 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 206 getOperation().walk([this](LLVM::CallOp op) { 207 if (isCInterfaceVulkanLaunchCallOp(op)) 208 translateVulkanLaunchCall(op); 209 }); 210 } 211 212 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 213 LLVM::CallOp vulkanLaunchCallOp) { 214 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 215 // for the given vulkan launch call. 
216 auto spirvBlobAttr = 217 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 218 if (!spirvBlobAttr) { 219 vulkanLaunchCallOp.emitError() 220 << "missing " << kSPIRVBlobAttrName << " attribute"; 221 return signalPassFailure(); 222 } 223 224 auto spirvEntryPointNameAttr = 225 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 226 if (!spirvEntryPointNameAttr) { 227 vulkanLaunchCallOp.emitError() 228 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 229 return signalPassFailure(); 230 } 231 232 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 233 } 234 235 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 236 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 237 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 238 kVulkanLaunchNumConfigOperands) 239 return; 240 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 241 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 242 243 // Create LLVM constant for the descriptor set index. 244 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 245 // pass does. 246 Value descriptorSet = builder.create<LLVM::ConstantOp>( 247 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 248 249 for (auto en : 250 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 251 kVulkanLaunchNumConfigOperands))) { 252 // Create LLVM constant for the descriptor binding index. 
253 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 254 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 255 256 auto ptrToMemRefDescriptor = en.value(); 257 uint32_t rank = 0; 258 LLVM::LLVMType type; 259 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 260 cInterfaceVulkanLaunchCallOp.emitError() 261 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 262 return signalPassFailure(); 263 } 264 265 auto symbolName = 266 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 267 // Create call to `bindMemRef`. 268 builder.create<LLVM::CallOp>( 269 loc, ArrayRef<Type>{getVoidType()}, 270 builder.getSymbolRefAttr( 271 StringRef(symbolName.data(), symbolName.size())), 272 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 273 ptrToMemRefDescriptor}); 274 } 275 } 276 277 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 278 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { 279 auto llvmPtrDescriptorTy = 280 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 281 if (!llvmPtrDescriptorTy) 282 return failure(); 283 284 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 285 // template <typename Elem, size_t Rank> 286 // struct { 287 // Elem *allocated; 288 // Elem *aligned; 289 // int64_t offset; 290 // int64_t sizes[Rank]; // omitted when rank == 0 291 // int64_t strides[Rank]; // omitted when rank == 0 292 // }; 293 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 294 return failure(); 295 296 type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); 297 if (llvmDescriptorTy.getStructNumElements() == 3) { 298 rank = 0; 299 return success(); 300 } 301 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 302 return success(); 303 } 304 305 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 306 ModuleOp module = getOperation(); 307 OpBuilder 
builder(module.getBody()->getTerminator()); 308 309 if (!module.lookupSymbol(kSetEntryPoint)) { 310 builder.create<LLVM::LLVMFuncOp>( 311 loc, kSetEntryPoint, 312 LLVM::LLVMType::getFunctionTy(getVoidType(), 313 {getPointerType(), getPointerType()}, 314 /*isVarArg=*/false)); 315 } 316 317 if (!module.lookupSymbol(kSetNumWorkGroups)) { 318 builder.create<LLVM::LLVMFuncOp>( 319 loc, kSetNumWorkGroups, 320 LLVM::LLVMType::getFunctionTy( 321 getVoidType(), 322 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 323 /*isVarArg=*/false)); 324 } 325 326 if (!module.lookupSymbol(kSetBinaryShader)) { 327 builder.create<LLVM::LLVMFuncOp>( 328 loc, kSetBinaryShader, 329 LLVM::LLVMType::getFunctionTy( 330 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 331 /*isVarArg=*/false)); 332 } 333 334 if (!module.lookupSymbol(kRunOnVulkan)) { 335 builder.create<LLVM::LLVMFuncOp>( 336 loc, kRunOnVulkan, 337 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 338 /*isVarArg=*/false)); 339 } 340 341 #define CREATE_VULKAN_BIND_FUNC(MemRefType) \ 342 if (!module.lookupSymbol(kBind##MemRefType)) { \ 343 builder.create<LLVM::LLVMFuncOp>( \ 344 loc, kBind##MemRefType, \ 345 LLVM::LLVMType::getFunctionTy(getVoidType(), \ 346 {getPointerType(), getInt32Type(), \ 347 getInt32Type(), \ 348 get##MemRefType().getPointerTo()}, \ 349 /*isVarArg=*/false)); \ 350 } 351 352 CREATE_VULKAN_BIND_FUNC(MemRef1DFloat); 353 CREATE_VULKAN_BIND_FUNC(MemRef2DFloat); 354 CREATE_VULKAN_BIND_FUNC(MemRef3DFloat); 355 CREATE_VULKAN_BIND_FUNC(MemRef1DInt); 356 CREATE_VULKAN_BIND_FUNC(MemRef2DInt); 357 CREATE_VULKAN_BIND_FUNC(MemRef3DInt); 358 359 if (!module.lookupSymbol(kInitVulkan)) { 360 builder.create<LLVM::LLVMFuncOp>( 361 loc, kInitVulkan, 362 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 363 /*isVarArg=*/false)); 364 } 365 366 if (!module.lookupSymbol(kDeinitVulkan)) { 367 builder.create<LLVM::LLVMFuncOp>( 368 loc, kDeinitVulkan, 369 
LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 370 /*isVarArg=*/false)); 371 } 372 } 373 374 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 375 StringRef name, Location loc, OpBuilder &builder) { 376 SmallString<16> shaderName(name.begin(), name.end()); 377 // Append `\0` to follow C style string given that LLVM::createGlobalString() 378 // won't handle this directly for us. 379 shaderName.push_back('\0'); 380 381 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 382 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 383 shaderName, LLVM::Linkage::Internal, 384 getLLVMDialect()); 385 } 386 387 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 388 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 389 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 390 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 391 // Create call to `initVulkan`. 392 auto initVulkanCall = builder.create<LLVM::CallOp>( 393 loc, ArrayRef<Type>{getPointerType()}, 394 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 395 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 396 // need to pass that pointer to each Vulkan runtime call. 397 auto vulkanRuntime = initVulkanCall.getResult(0); 398 399 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 400 // that data to runtime call. 401 Value ptrToSPIRVBinary = LLVM::createGlobalString( 402 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 403 LLVM::Linkage::Internal, getLLVMDialect()); 404 405 // Create LLVM constant for the size of SPIR-V binary shader. 406 Value binarySize = builder.create<LLVM::ConstantOp>( 407 loc, getInt32Type(), 408 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 409 410 // Create call to `bindMemRef` for each memref operand. 
411 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 412 413 // Create call to `setBinaryShader` runtime function with the given pointer to 414 // SPIR-V binary and binary size. 415 builder.create<LLVM::CallOp>( 416 loc, ArrayRef<Type>{getVoidType()}, 417 builder.getSymbolRefAttr(kSetBinaryShader), 418 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 419 // Create LLVM global with entry point name. 420 Value entryPointName = createEntryPointNameConstant( 421 spirvAttributes.second.getValue(), loc, builder); 422 // Create call to `setEntryPoint` runtime function with the given pointer to 423 // entry point name. 424 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 425 builder.getSymbolRefAttr(kSetEntryPoint), 426 ArrayRef<Value>{vulkanRuntime, entryPointName}); 427 428 // Create number of local workgroup for each dimension. 429 builder.create<LLVM::CallOp>( 430 loc, ArrayRef<Type>{getVoidType()}, 431 builder.getSymbolRefAttr(kSetNumWorkGroups), 432 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 433 cInterfaceVulkanLaunchCallOp.getOperand(1), 434 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 435 436 // Create call to `runOnVulkan` runtime function. 437 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 438 builder.getSymbolRefAttr(kRunOnVulkan), 439 ArrayRef<Value>{vulkanRuntime}); 440 441 // Create call to 'deinitVulkan' runtime function. 442 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 443 builder.getSymbolRefAttr(kDeinitVulkan), 444 ArrayRef<Value>{vulkanRuntime}); 445 446 // Declare runtime functions. 447 declareVulkanFunctions(loc); 448 449 cInterfaceVulkanLaunchCallOp.erase(); 450 } 451 452 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 453 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 454 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 455 } 456