1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 using namespace mlir; 29 30 static constexpr const char *kCInterfaceVulkanLaunch = 31 "_mlir_ciface_vulkanLaunch"; 32 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 33 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 34 static constexpr const char *kInitVulkan = "initVulkan"; 35 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 36 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 37 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 38 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 39 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 40 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 41 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 42 43 namespace { 44 45 /// A pass to convert vulkan launch call op into a sequence of Vulkan 46 /// runtime calls in the following order: 47 /// 48 /// * initVulkan -- initializes vulkan runtime 49 /// * bindMemRef -- binds memref 50 /// * setBinaryShader -- sets the binary shader data 51 /// * setEntryPoint -- sets the entry point name 52 /// * setNumWorkGroups -- sets the number of a local workgroups 53 /// * runOnVulkan -- runs vulkan runtime 54 /// * deinitVulkan -- deinitializes vulkan runtime 55 /// 56 class VulkanLaunchFuncToVulkanCallsPass 57 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 58 VulkanLaunchFuncToVulkanCallsPass> { 59 private: 60 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 61 62 llvm::LLVMContext &getLLVMContext() { 63 return getLLVMDialect()->getLLVMContext(); 64 } 65 66 void initializeCachedTypes() { 67 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 68 llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); 69 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 70 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 71 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 72 llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); 73 } 74 75 LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { 76 // According to the MLIR doc memref argument is converted into a 77 // pointer-to-struct argument of type: 78 // template <typename Elem, size_t Rank> 79 // struct { 80 // Elem *allocated; 81 // Elem *aligned; 82 // int64_t offset; 83 // int64_t sizes[Rank]; // omitted when rank == 0 84 // int64_t strides[Rank]; // omitted when rank == 0 85 // }; 86 auto llvmPtrToElementType = elemenType.getPointerTo(); 87 auto llvmArrayRankElementSizeType = 88 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 89 90 // Create a type 91 // `!llvm<"{ `element-type`*, `element-type`*, i64, 92 // [`rank` x i64], [`rank` x i64]}">`. 93 return LLVM::LLVMType::getStructTy( 94 llvmDialect, 95 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 96 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 97 } 98 99 LLVM::LLVMType getVoidType() { return llvmVoidType; } 100 LLVM::LLVMType getPointerType() { return llvmPointerType; } 101 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 102 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 103 104 /// Creates a LLVM global for the given `name`. 105 Value createEntryPointNameConstant(StringRef name, Location loc, 106 OpBuilder &builder); 107 108 /// Declares all needed runtime functions. 109 void declareVulkanFunctions(Location loc); 110 111 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 112 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 113 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 114 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 115 } 116 117 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 118 /// op. 119 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 120 return (callOp.callee() && 121 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 122 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 123 } 124 125 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 126 /// runtime calls. 127 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 128 129 /// Creates call to `bindMemRef` for each memref operand. 130 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 131 Value vulkanRuntime); 132 133 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 134 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 135 136 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 137 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 138 uint32_t &rank, LLVM::LLVMType &type); 139 140 /// Returns a string representation from the given `type`. 141 StringRef stringifyType(LLVM::LLVMType type) { 142 if (type.isFloatTy()) 143 return "Float"; 144 if (type.isHalfTy()) 145 return "Half"; 146 if (type.isIntegerTy(32)) 147 return "Int32"; 148 if (type.isIntegerTy(16)) 149 return "Int16"; 150 if (type.isIntegerTy(8)) 151 return "Int8"; 152 153 llvm_unreachable("unsupported type"); 154 } 155 156 public: 157 void runOnOperation() override; 158 159 private: 160 LLVM::LLVMDialect *llvmDialect; 161 LLVM::LLVMType llvmFloatType; 162 LLVM::LLVMType llvmVoidType; 163 LLVM::LLVMType llvmPointerType; 164 LLVM::LLVMType llvmInt32Type; 165 LLVM::LLVMType llvmInt64Type; 166 167 // TODO: Use an associative array to support multiple vulkan launch calls. 168 std::pair<StringAttr, StringAttr> spirvAttributes; 169 /// The number of vulkan launch configuration operands, placed at the leading 170 /// positions of the operand list. 171 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 172 }; 173 174 } // anonymous namespace 175 176 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 177 initializeCachedTypes(); 178 179 // Collect SPIR-V attributes such as `spirv_blob` and 180 // `spirv_entry_point_name`. 181 getOperation().walk([this](LLVM::CallOp op) { 182 if (isVulkanLaunchCallOp(op)) 183 collectSPIRVAttributes(op); 184 }); 185 186 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 187 getOperation().walk([this](LLVM::CallOp op) { 188 if (isCInterfaceVulkanLaunchCallOp(op)) 189 translateVulkanLaunchCall(op); 190 }); 191 } 192 193 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 194 LLVM::CallOp vulkanLaunchCallOp) { 195 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 196 // for the given vulkan launch call. 197 auto spirvBlobAttr = 198 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 199 if (!spirvBlobAttr) { 200 vulkanLaunchCallOp.emitError() 201 << "missing " << kSPIRVBlobAttrName << " attribute"; 202 return signalPassFailure(); 203 } 204 205 auto spirvEntryPointNameAttr = 206 vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 207 if (!spirvEntryPointNameAttr) { 208 vulkanLaunchCallOp.emitError() 209 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 210 return signalPassFailure(); 211 } 212 213 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 214 } 215 216 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 217 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 218 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 219 kVulkanLaunchNumConfigOperands) 220 return; 221 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 222 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 223 224 // Create LLVM constant for the descriptor set index. 225 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 226 // pass does. 227 Value descriptorSet = builder.create<LLVM::ConstantOp>( 228 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 229 230 for (auto en : 231 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 232 kVulkanLaunchNumConfigOperands))) { 233 // Create LLVM constant for the descriptor binding index. 234 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 235 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 236 237 auto ptrToMemRefDescriptor = en.value(); 238 uint32_t rank = 0; 239 LLVM::LLVMType type; 240 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 241 cInterfaceVulkanLaunchCallOp.emitError() 242 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 243 return signalPassFailure(); 244 } 245 246 auto symbolName = 247 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 248 // Special case for fp16 type. Since it is not a supported type in C we use 249 // int16_t and bitcast the descriptor. 250 if (type.isHalfTy()) { 251 auto memRefTy = 252 getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect)); 253 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>( 254 loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor); 255 } 256 // Create call to `bindMemRef`. 257 builder.create<LLVM::CallOp>( 258 loc, ArrayRef<Type>{getVoidType()}, 259 builder.getSymbolRefAttr( 260 StringRef(symbolName.data(), symbolName.size())), 261 ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, 262 ptrToMemRefDescriptor}); 263 } 264 } 265 266 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 267 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { 268 auto llvmPtrDescriptorTy = 269 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>(); 270 if (!llvmPtrDescriptorTy) 271 return failure(); 272 273 auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); 274 // template <typename Elem, size_t Rank> 275 // struct { 276 // Elem *allocated; 277 // Elem *aligned; 278 // int64_t offset; 279 // int64_t sizes[Rank]; // omitted when rank == 0 280 // int64_t strides[Rank]; // omitted when rank == 0 281 // }; 282 if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) 283 return failure(); 284 285 type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); 286 if (llvmDescriptorTy.getStructNumElements() == 3) { 287 rank = 0; 288 return success(); 289 } 290 rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); 291 return success(); 292 } 293 294 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 295 ModuleOp module = getOperation(); 296 OpBuilder builder(module.getBody()->getTerminator()); 297 298 if (!module.lookupSymbol(kSetEntryPoint)) { 299 builder.create<LLVM::LLVMFuncOp>( 300 loc, kSetEntryPoint, 301 LLVM::LLVMType::getFunctionTy(getVoidType(), 302 {getPointerType(), getPointerType()}, 303 /*isVarArg=*/false)); 304 } 305 306 if (!module.lookupSymbol(kSetNumWorkGroups)) { 307 builder.create<LLVM::LLVMFuncOp>( 308 loc, kSetNumWorkGroups, 309 LLVM::LLVMType::getFunctionTy( 310 getVoidType(), 311 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 312 /*isVarArg=*/false)); 313 } 314 315 if (!module.lookupSymbol(kSetBinaryShader)) { 316 builder.create<LLVM::LLVMFuncOp>( 317 loc, kSetBinaryShader, 318 LLVM::LLVMType::getFunctionTy( 319 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 320 /*isVarArg=*/false)); 321 } 322 323 if (!module.lookupSymbol(kRunOnVulkan)) { 324 builder.create<LLVM::LLVMFuncOp>( 325 loc, kRunOnVulkan, 326 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 327 /*isVarArg=*/false)); 328 } 329 330 for (unsigned i = 1; i <= 3; i++) { 331 for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect), 332 LLVM::LLVMType::getInt32Ty(llvmDialect), 333 LLVM::LLVMType::getInt16Ty(llvmDialect), 334 LLVM::LLVMType::getInt8Ty(llvmDialect), 335 LLVM::LLVMType::getHalfTy(llvmDialect)}) { 336 std::string fnName = "bindMemRef" + std::to_string(i) + "D" + 337 std::string(stringifyType(type)); 338 if (type.isHalfTy()) 339 type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect)); 340 if (!module.lookupSymbol(fnName)) { 341 auto fnType = LLVM::LLVMType::getFunctionTy( 342 getVoidType(), 343 {getPointerType(), getInt32Type(), getInt32Type(), 344 getMemRefType(i, type).getPointerTo()}, 345 /*isVarArg=*/false); 346 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); 347 } 348 } 349 } 350 351 if (!module.lookupSymbol(kInitVulkan)) { 352 builder.create<LLVM::LLVMFuncOp>( 353 loc, kInitVulkan, 354 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 355 /*isVarArg=*/false)); 356 } 357 358 if (!module.lookupSymbol(kDeinitVulkan)) { 359 builder.create<LLVM::LLVMFuncOp>( 360 loc, kDeinitVulkan, 361 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 362 /*isVarArg=*/false)); 363 } 364 } 365 366 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 367 StringRef name, Location loc, OpBuilder &builder) { 368 SmallString<16> shaderName(name.begin(), name.end()); 369 // Append `\0` to follow C style string given that LLVM::createGlobalString() 370 // won't handle this directly for us. 371 shaderName.push_back('\0'); 372 373 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 374 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 375 shaderName, LLVM::Linkage::Internal, 376 getLLVMDialect()); 377 } 378 379 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 380 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 381 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 382 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 383 // Create call to `initVulkan`. 384 auto initVulkanCall = builder.create<LLVM::CallOp>( 385 loc, ArrayRef<Type>{getPointerType()}, 386 builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); 387 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 388 // need to pass that pointer to each Vulkan runtime call. 389 auto vulkanRuntime = initVulkanCall.getResult(0); 390 391 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 392 // that data to runtime call. 393 Value ptrToSPIRVBinary = LLVM::createGlobalString( 394 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 395 LLVM::Linkage::Internal, getLLVMDialect()); 396 397 // Create LLVM constant for the size of SPIR-V binary shader. 398 Value binarySize = builder.create<LLVM::ConstantOp>( 399 loc, getInt32Type(), 400 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 401 402 // Create call to `bindMemRef` for each memref operand. 403 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 404 405 // Create call to `setBinaryShader` runtime function with the given pointer to 406 // SPIR-V binary and binary size. 407 builder.create<LLVM::CallOp>( 408 loc, ArrayRef<Type>{getVoidType()}, 409 builder.getSymbolRefAttr(kSetBinaryShader), 410 ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 411 // Create LLVM global with entry point name. 412 Value entryPointName = createEntryPointNameConstant( 413 spirvAttributes.second.getValue(), loc, builder); 414 // Create call to `setEntryPoint` runtime function with the given pointer to 415 // entry point name. 416 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 417 builder.getSymbolRefAttr(kSetEntryPoint), 418 ArrayRef<Value>{vulkanRuntime, entryPointName}); 419 420 // Create number of local workgroup for each dimension. 421 builder.create<LLVM::CallOp>( 422 loc, ArrayRef<Type>{getVoidType()}, 423 builder.getSymbolRefAttr(kSetNumWorkGroups), 424 ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 425 cInterfaceVulkanLaunchCallOp.getOperand(1), 426 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 427 428 // Create call to `runOnVulkan` runtime function. 429 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 430 builder.getSymbolRefAttr(kRunOnVulkan), 431 ArrayRef<Value>{vulkanRuntime}); 432 433 // Create call to 'deinitVulkan' runtime function. 434 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 435 builder.getSymbolRefAttr(kDeinitVulkan), 436 ArrayRef<Value>{vulkanRuntime}); 437 438 // Declare runtime functions. 439 declareVulkanFunctions(loc); 440 441 cInterfaceVulkanLaunchCallOp.erase(); 442 } 443 444 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 445 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 446 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 447 } 448