1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 using namespace mlir; 29 30 static constexpr const char *kCInterfaceVulkanLaunch = 31 "_mlir_ciface_vulkanLaunch"; 32 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 33 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 34 static constexpr const char *kInitVulkan = "initVulkan"; 35 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 36 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 37 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 38 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 39 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 40 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 41 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 42 43 namespace { 44 45 /// A pass to convert vulkan launch call op into a sequence of Vulkan 46 /// runtime calls in the following order: 47 /// 48 /// * initVulkan -- initializes vulkan runtime 49 /// * bindMemRef -- binds memref 50 /// * setBinaryShader -- sets the binary shader data 51 /// * setEntryPoint -- sets the entry point name 52 /// * setNumWorkGroups -- sets the number of a local workgroups 53 /// * runOnVulkan -- runs vulkan runtime 54 /// * deinitVulkan -- deinitializes vulkan runtime 55 /// 56 class VulkanLaunchFuncToVulkanCallsPass 57 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 58 VulkanLaunchFuncToVulkanCallsPass> { 59 private: 60 void initializeCachedTypes() { 61 llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext()); 62 llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext()); 63 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext()); 64 llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext()); 65 llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext()); 66 } 67 68 LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { 69 // According to the MLIR doc memref argument is converted into a 70 // pointer-to-struct argument of type: 71 // template <typename Elem, size_t Rank> 72 // struct { 73 // Elem *allocated; 74 // Elem *aligned; 75 // int64_t offset; 76 // int64_t sizes[Rank]; // omitted when rank == 0 77 // int64_t strides[Rank]; // omitted when rank == 0 78 // }; 79 auto llvmPtrToElementType = elemenType.getPointerTo(); 80 auto llvmArrayRankElementSizeType = 81 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 82 83 // Create a type 84 // `!llvm<"{ `element-type`*, `element-type`*, i64, 85 // [`rank` x i64], [`rank` x i64]}">`. 86 return LLVM::LLVMType::getStructTy( 87 &getContext(), 88 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 89 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 90 } 91 92 LLVM::LLVMType getVoidType() { return llvmVoidType; } 93 LLVM::LLVMType getPointerType() { return llvmPointerType; } 94 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 95 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 96 97 /// Creates an LLVM global for the given `name`. 98 Value createEntryPointNameConstant(StringRef name, Location loc, 99 OpBuilder &builder); 100 101 /// Declares all needed runtime functions. 102 void declareVulkanFunctions(Location loc); 103 104 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 105 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 106 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 107 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 108 } 109 110 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 111 /// op. 112 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 113 return (callOp.callee() && 114 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 115 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 116 } 117 118 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 119 /// runtime calls. 120 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 121 122 /// Creates call to `bindMemRef` for each memref operand. 123 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 124 Value vulkanRuntime); 125 126 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 127 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 128 129 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 130 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 131 uint32_t &rank, LLVM::LLVMType &type); 132 133 /// Returns a string representation from the given `type`. 134 StringRef stringifyType(LLVM::LLVMType type) { 135 if (type.isFloatTy()) 136 return "Float"; 137 if (type.isHalfTy()) 138 return "Half"; 139 if (type.isIntegerTy(32)) 140 return "Int32"; 141 if (type.isIntegerTy(16)) 142 return "Int16"; 143 if (type.isIntegerTy(8)) 144 return "Int8"; 145 146 llvm_unreachable("unsupported type"); 147 } 148 149 public: 150 void runOnOperation() override; 151 152 private: 153 LLVM::LLVMType llvmFloatType; 154 LLVM::LLVMType llvmVoidType; 155 LLVM::LLVMType llvmPointerType; 156 LLVM::LLVMType llvmInt32Type; 157 LLVM::LLVMType llvmInt64Type; 158 159 // TODO: Use an associative array to support multiple vulkan launch calls. 160 std::pair<StringAttr, StringAttr> spirvAttributes; 161 /// The number of vulkan launch configuration operands, placed at the leading 162 /// positions of the operand list. 163 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 164 }; 165 166 } // anonymous namespace 167 168 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 169 initializeCachedTypes(); 170 171 // Collect SPIR-V attributes such as `spirv_blob` and 172 // `spirv_entry_point_name`. 173 getOperation().walk([this](LLVM::CallOp op) { 174 if (isVulkanLaunchCallOp(op)) 175 collectSPIRVAttributes(op); 176 }); 177 178 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 179 getOperation().walk([this](LLVM::CallOp op) { 180 if (isCInterfaceVulkanLaunchCallOp(op)) 181 translateVulkanLaunchCall(op); 182 }); 183 } 184 185 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 186 LLVM::CallOp vulkanLaunchCallOp) { 187 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 188 // for the given vulkan launch call. 189 auto spirvBlobAttr = 190 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 191 if (!spirvBlobAttr) { 192 vulkanLaunchCallOp.emitError() 193 << "missing " << kSPIRVBlobAttrName << " attribute"; 194 return signalPassFailure(); 195 } 196 197 auto spirvEntryPointNameAttr = 198 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 199 if (!spirvEntryPointNameAttr) { 200 vulkanLaunchCallOp.emitError() 201 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 202 return signalPassFailure(); 203 } 204 205 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 206 } 207 208 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 209 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 210 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 211 kVulkanLaunchNumConfigOperands) 212 return; 213 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 214 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 215 216 // Create LLVM constant for the descriptor set index. 217 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 218 // pass does. 219 Value descriptorSet = builder.create<LLVM::ConstantOp>( 220 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 221 222 for (auto en : 223 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 224 kVulkanLaunchNumConfigOperands))) { 225 // Create LLVM constant for the descriptor binding index. 226 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 227 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 228 229 auto ptrToMemRefDescriptor = en.value(); 230 uint32_t rank = 0; 231 LLVM::LLVMType type; 232 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 233 cInterfaceVulkanLaunchCallOp.emitError() 234 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 235 return signalPassFailure(); 236 } 237 238 auto symbolName = 239 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 240 // Special case for fp16 type. Since it is not a supported type in C we use 241 // int16_t and bitcast the descriptor. 242 if (type.isHalfTy()) { 243 auto memRefTy = 244 getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext())); 245 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>( 246 loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor); 247 } 248 // Create call to `bindMemRef`. 249 builder.create<LLVM::CallOp>( 250 loc, ArrayRef<Type>{getVoidType()}, 251 builder.getSymbolRefAttr( 252 StringRef(symbolName.data(), symbolName.size())), 253 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 254 ptrToMemRefDescriptor}); 255 } 256 } 257 258 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 259 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { 260 auto llvmPtrDescriptorTy = 261 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 262 if (!llvmPtrDescriptorTy) 263 return failure(); 264 265 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 266 // template <typename Elem, size_t Rank> 267 // struct { 268 // Elem *allocated; 269 // Elem *aligned; 270 // int64_t offset; 271 // int64_t sizes[Rank]; // omitted when rank == 0 272 // int64_t strides[Rank]; // omitted when rank == 0 273 // }; 274 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 275 return failure(); 276 277 type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); 278 if (llvmDescriptorTy.getStructNumElements() == 3) { 279 rank = 0; 280 return success(); 281 } 282 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 283 return success(); 284 } 285 286 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 287 ModuleOp module = getOperation(); 288 OpBuilder builder(module.getBody()->getTerminator()); 289 290 if (!module.lookupSymbol(kSetEntryPoint)) { 291 builder.create<LLVM::LLVMFuncOp>( 292 loc, kSetEntryPoint, 293 LLVM::LLVMType::getFunctionTy(getVoidType(), 294 {getPointerType(), getPointerType()}, 295 /*isVarArg=*/false)); 296 } 297 298 if (!module.lookupSymbol(kSetNumWorkGroups)) { 299 builder.create<LLVM::LLVMFuncOp>( 300 loc, kSetNumWorkGroups, 301 LLVM::LLVMType::getFunctionTy( 302 getVoidType(), 303 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 304 /*isVarArg=*/false)); 305 } 306 307 if (!module.lookupSymbol(kSetBinaryShader)) { 308 builder.create<LLVM::LLVMFuncOp>( 309 loc, kSetBinaryShader, 310 LLVM::LLVMType::getFunctionTy( 311 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 312 /*isVarArg=*/false)); 313 } 314 315 if (!module.lookupSymbol(kRunOnVulkan)) { 316 builder.create<LLVM::LLVMFuncOp>( 317 loc, kRunOnVulkan, 318 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 319 /*isVarArg=*/false)); 320 } 321 322 for (unsigned i = 1; i <= 3; i++) { 323 for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()), 324 LLVM::LLVMType::getInt32Ty(&getContext()), 325 LLVM::LLVMType::getInt16Ty(&getContext()), 326 LLVM::LLVMType::getInt8Ty(&getContext()), 327 LLVM::LLVMType::getHalfTy(&getContext())}) { 328 std::string fnName = "bindMemRef" + std::to_string(i) + "D" + 329 std::string(stringifyType(type)); 330 if (type.isHalfTy()) 331 type = LLVM::LLVMType::getInt16Ty(&getContext()); 332 if (!module.lookupSymbol(fnName)) { 333 auto fnType = LLVM::LLVMType::getFunctionTy( 334 getVoidType(), 335 {getPointerType(), getInt32Type(), getInt32Type(), 336 getMemRefType(i, type).getPointerTo()}, 337 /*isVarArg=*/false); 338 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); 339 } 340 } 341 } 342 343 if (!module.lookupSymbol(kInitVulkan)) { 344 builder.create<LLVM::LLVMFuncOp>( 345 loc, kInitVulkan, 346 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 347 /*isVarArg=*/false)); 348 } 349 350 if (!module.lookupSymbol(kDeinitVulkan)) { 351 builder.create<LLVM::LLVMFuncOp>( 352 loc, kDeinitVulkan, 353 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 354 /*isVarArg=*/false)); 355 } 356 } 357 358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 359 StringRef name, Location loc, OpBuilder &builder) { 360 SmallString<16> shaderName(name.begin(), name.end()); 361 // Append `\0` to follow C style string given that LLVM::createGlobalString() 362 // won't handle this directly for us. 363 shaderName.push_back('\0'); 364 365 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 366 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 367 shaderName, LLVM::Linkage::Internal); 368 } 369 370 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 371 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 372 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 373 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 374 // Create call to `initVulkan`. 375 auto initVulkanCall = builder.create<LLVM::CallOp>( 376 loc, ArrayRef<Type>{getPointerType()}, 377 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 378 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 379 // need to pass that pointer to each Vulkan runtime call. 380 auto vulkanRuntime = initVulkanCall.getResult(0); 381 382 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 383 // that data to runtime call. 384 Value ptrToSPIRVBinary = LLVM::createGlobalString( 385 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 386 LLVM::Linkage::Internal); 387 388 // Create LLVM constant for the size of SPIR-V binary shader. 389 Value binarySize = builder.create<LLVM::ConstantOp>( 390 loc, getInt32Type(), 391 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 392 393 // Create call to `bindMemRef` for each memref operand. 394 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 395 396 // Create call to `setBinaryShader` runtime function with the given pointer to 397 // SPIR-V binary and binary size. 398 builder.create<LLVM::CallOp>( 399 loc, ArrayRef<Type>{getVoidType()}, 400 builder.getSymbolRefAttr(kSetBinaryShader), 401 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 402 // Create LLVM global with entry point name. 403 Value entryPointName = createEntryPointNameConstant( 404 spirvAttributes.second.getValue(), loc, builder); 405 // Create call to `setEntryPoint` runtime function with the given pointer to 406 // entry point name. 407 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 408 builder.getSymbolRefAttr(kSetEntryPoint), 409 ArrayRef<Value>{vulkanRuntime, entryPointName}); 410 411 // Create number of local workgroup for each dimension. 412 builder.create<LLVM::CallOp>( 413 loc, ArrayRef<Type>{getVoidType()}, 414 builder.getSymbolRefAttr(kSetNumWorkGroups), 415 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 416 cInterfaceVulkanLaunchCallOp.getOperand(1), 417 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 418 419 // Create call to `runOnVulkan` runtime function. 420 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 421 builder.getSymbolRefAttr(kRunOnVulkan), 422 ArrayRef<Value>{vulkanRuntime}); 423 424 // Create call to 'deinitVulkan' runtime function. 425 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 426 builder.getSymbolRefAttr(kDeinitVulkan), 427 ArrayRef<Value>{vulkanRuntime}); 428 429 // Declare runtime functions. 430 declareVulkanFunctions(loc); 431 432 cInterfaceVulkanLaunchCallOp.erase(); 433 } 434 435 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 436 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 437 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 438 } 439