1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/BuiltinOps.h" 23 24 #include "llvm/ADT/SmallString.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 static constexpr const char *kCInterfaceVulkanLaunch = 30 "_mlir_ciface_vulkanLaunch"; 31 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 32 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 33 static constexpr const char *kInitVulkan = "initVulkan"; 34 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 35 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 36 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 37 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 38 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 39 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 40 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 41 42 namespace { 43 44 /// A pass to convert vulkan launch call op into a sequence of Vulkan 45 /// runtime calls in the following order: 46 /// 47 /// * initVulkan -- initializes vulkan runtime 48 /// * bindMemRef -- binds memref 49 /// * setBinaryShader -- sets the binary shader data 50 /// * setEntryPoint -- sets the entry point name 51 /// * setNumWorkGroups -- sets the number of a local workgroups 52 /// * runOnVulkan -- runs vulkan runtime 53 /// * deinitVulkan -- deinitializes vulkan runtime 54 /// 55 class VulkanLaunchFuncToVulkanCallsPass 56 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 57 VulkanLaunchFuncToVulkanCallsPass> { 58 private: 59 void initializeCachedTypes() { 60 llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext()); 61 llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext()); 62 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext()); 63 llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext()); 64 llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext()); 65 } 66 67 LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { 68 // According to the MLIR doc memref argument is converted into a 69 // pointer-to-struct argument of type: 70 // template <typename Elem, size_t Rank> 71 // struct { 72 // Elem *allocated; 73 // Elem *aligned; 74 // int64_t offset; 75 // int64_t sizes[Rank]; // omitted when rank == 0 76 // int64_t strides[Rank]; // omitted when rank == 0 77 // }; 78 auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType); 79 auto llvmArrayRankElementSizeType = 80 LLVM::LLVMType::getArrayTy(getInt64Type(), rank); 81 82 // Create a type 83 // `!llvm<"{ `element-type`*, `element-type`*, i64, 84 // [`rank` x i64], [`rank` x i64]}">`. 85 return LLVM::LLVMType::getStructTy( 86 &getContext(), 87 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 88 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 89 } 90 91 LLVM::LLVMType getVoidType() { return llvmVoidType; } 92 LLVM::LLVMType getPointerType() { return llvmPointerType; } 93 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 94 LLVM::LLVMType getInt64Type() { return llvmInt64Type; } 95 96 /// Creates an LLVM global for the given `name`. 97 Value createEntryPointNameConstant(StringRef name, Location loc, 98 OpBuilder &builder); 99 100 /// Declares all needed runtime functions. 101 void declareVulkanFunctions(Location loc); 102 103 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 104 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 105 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 106 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 107 } 108 109 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 110 /// op. 111 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 112 return (callOp.callee() && 113 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 114 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 115 } 116 117 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 118 /// runtime calls. 119 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 120 121 /// Creates call to `bindMemRef` for each memref operand. 122 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 123 Value vulkanRuntime); 124 125 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 126 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 127 128 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 129 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 130 uint32_t &rank, LLVM::LLVMType &type); 131 132 /// Returns a string representation from the given `type`. 133 StringRef stringifyType(LLVM::LLVMType type) { 134 if (type.isa<LLVM::LLVMFloatType>()) 135 return "Float"; 136 if (type.isa<LLVM::LLVMHalfType>()) 137 return "Half"; 138 if (auto intType = type.dyn_cast<LLVM::LLVMIntegerType>()) { 139 if (intType.getBitWidth() == 32) 140 return "Int32"; 141 if (intType.getBitWidth() == 16) 142 return "Int16"; 143 if (intType.getBitWidth() == 8) 144 return "Int8"; 145 } 146 147 llvm_unreachable("unsupported type"); 148 } 149 150 public: 151 void runOnOperation() override; 152 153 private: 154 LLVM::LLVMType llvmFloatType; 155 LLVM::LLVMType llvmVoidType; 156 LLVM::LLVMType llvmPointerType; 157 LLVM::LLVMType llvmInt32Type; 158 LLVM::LLVMType llvmInt64Type; 159 160 // TODO: Use an associative array to support multiple vulkan launch calls. 161 std::pair<StringAttr, StringAttr> spirvAttributes; 162 /// The number of vulkan launch configuration operands, placed at the leading 163 /// positions of the operand list. 164 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 165 }; 166 167 } // anonymous namespace 168 169 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 170 initializeCachedTypes(); 171 172 // Collect SPIR-V attributes such as `spirv_blob` and 173 // `spirv_entry_point_name`. 174 getOperation().walk([this](LLVM::CallOp op) { 175 if (isVulkanLaunchCallOp(op)) 176 collectSPIRVAttributes(op); 177 }); 178 179 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 180 getOperation().walk([this](LLVM::CallOp op) { 181 if (isCInterfaceVulkanLaunchCallOp(op)) 182 translateVulkanLaunchCall(op); 183 }); 184 } 185 186 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 187 LLVM::CallOp vulkanLaunchCallOp) { 188 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 189 // for the given vulkan launch call. 190 auto spirvBlobAttr = 191 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 192 if (!spirvBlobAttr) { 193 vulkanLaunchCallOp.emitError() 194 << "missing " << kSPIRVBlobAttrName << " attribute"; 195 return signalPassFailure(); 196 } 197 198 auto spirvEntryPointNameAttr = 199 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 200 if (!spirvEntryPointNameAttr) { 201 vulkanLaunchCallOp.emitError() 202 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 203 return signalPassFailure(); 204 } 205 206 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 207 } 208 209 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 210 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 211 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 212 kVulkanLaunchNumConfigOperands) 213 return; 214 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 215 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 216 217 // Create LLVM constant for the descriptor set index. 218 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 219 // pass does. 220 Value descriptorSet = builder.create<LLVM::ConstantOp>( 221 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 222 223 for (auto en : 224 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 225 kVulkanLaunchNumConfigOperands))) { 226 // Create LLVM constant for the descriptor binding index. 227 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 228 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 229 230 auto ptrToMemRefDescriptor = en.value(); 231 uint32_t rank = 0; 232 LLVM::LLVMType type; 233 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 234 cInterfaceVulkanLaunchCallOp.emitError() 235 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 236 return signalPassFailure(); 237 } 238 239 auto symbolName = 240 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 241 // Special case for fp16 type. Since it is not a supported type in C we use 242 // int16_t and bitcast the descriptor. 243 if (type.isa<LLVM::LLVMHalfType>()) { 244 auto memRefTy = 245 getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext())); 246 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>( 247 loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); 248 } 249 // Create call to `bindMemRef`. 250 builder.create<LLVM::CallOp>( 251 loc, TypeRange{getVoidType()}, 252 builder.getSymbolRefAttr( 253 StringRef(symbolName.data(), symbolName.size())), 254 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding, 255 ptrToMemRefDescriptor}); 256 } 257 } 258 259 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 260 Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { 261 auto llvmPtrDescriptorTy = 262 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>(); 263 if (!llvmPtrDescriptorTy) 264 return failure(); 265 266 auto llvmDescriptorTy = 267 llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>(); 268 // template <typename Elem, size_t Rank> 269 // struct { 270 // Elem *allocated; 271 // Elem *aligned; 272 // int64_t offset; 273 // int64_t sizes[Rank]; // omitted when rank == 0 274 // int64_t strides[Rank]; // omitted when rank == 0 275 // }; 276 if (!llvmDescriptorTy) 277 return failure(); 278 279 type = llvmDescriptorTy.getBody()[0] 280 .cast<LLVM::LLVMPointerType>() 281 .getElementType(); 282 if (llvmDescriptorTy.getBody().size() == 3) { 283 rank = 0; 284 return success(); 285 } 286 rank = llvmDescriptorTy.getBody()[3] 287 .cast<LLVM::LLVMArrayType>() 288 .getNumElements(); 289 return success(); 290 } 291 292 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 293 ModuleOp module = getOperation(); 294 OpBuilder builder(module.getBody()->getTerminator()); 295 296 if (!module.lookupSymbol(kSetEntryPoint)) { 297 builder.create<LLVM::LLVMFuncOp>( 298 loc, kSetEntryPoint, 299 LLVM::LLVMType::getFunctionTy(getVoidType(), 300 {getPointerType(), getPointerType()}, 301 /*isVarArg=*/false)); 302 } 303 304 if (!module.lookupSymbol(kSetNumWorkGroups)) { 305 builder.create<LLVM::LLVMFuncOp>( 306 loc, kSetNumWorkGroups, 307 LLVM::LLVMType::getFunctionTy( 308 getVoidType(), 309 {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, 310 /*isVarArg=*/false)); 311 } 312 313 if (!module.lookupSymbol(kSetBinaryShader)) { 314 builder.create<LLVM::LLVMFuncOp>( 315 loc, kSetBinaryShader, 316 LLVM::LLVMType::getFunctionTy( 317 getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, 318 /*isVarArg=*/false)); 319 } 320 321 if (!module.lookupSymbol(kRunOnVulkan)) { 322 builder.create<LLVM::LLVMFuncOp>( 323 loc, kRunOnVulkan, 324 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 325 /*isVarArg=*/false)); 326 } 327 328 for (unsigned i = 1; i <= 3; i++) { 329 for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()), 330 LLVM::LLVMType::getInt32Ty(&getContext()), 331 LLVM::LLVMType::getInt16Ty(&getContext()), 332 LLVM::LLVMType::getInt8Ty(&getContext()), 333 LLVM::LLVMType::getHalfTy(&getContext())}) { 334 std::string fnName = "bindMemRef" + std::to_string(i) + "D" + 335 std::string(stringifyType(type)); 336 if (type.isa<LLVM::LLVMHalfType>()) 337 type = LLVM::LLVMType::getInt16Ty(&getContext()); 338 if (!module.lookupSymbol(fnName)) { 339 auto fnType = LLVM::LLVMType::getFunctionTy( 340 getVoidType(), 341 {getPointerType(), getInt32Type(), getInt32Type(), 342 LLVM::LLVMPointerType::get(getMemRefType(i, type))}, 343 /*isVarArg=*/false); 344 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); 345 } 346 } 347 } 348 349 if (!module.lookupSymbol(kInitVulkan)) { 350 builder.create<LLVM::LLVMFuncOp>( 351 loc, kInitVulkan, 352 LLVM::LLVMType::getFunctionTy(getPointerType(), {}, 353 /*isVarArg=*/false)); 354 } 355 356 if (!module.lookupSymbol(kDeinitVulkan)) { 357 builder.create<LLVM::LLVMFuncOp>( 358 loc, kDeinitVulkan, 359 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 360 /*isVarArg=*/false)); 361 } 362 } 363 364 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 365 StringRef name, Location loc, OpBuilder &builder) { 366 SmallString<16> shaderName(name.begin(), name.end()); 367 // Append `\0` to follow C style string given that LLVM::createGlobalString() 368 // won't handle this directly for us. 369 shaderName.push_back('\0'); 370 371 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 372 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 373 shaderName, LLVM::Linkage::Internal); 374 } 375 376 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 377 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 378 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 379 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 380 // Create call to `initVulkan`. 381 auto initVulkanCall = builder.create<LLVM::CallOp>( 382 loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan), 383 ValueRange{}); 384 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 385 // need to pass that pointer to each Vulkan runtime call. 386 auto vulkanRuntime = initVulkanCall.getResult(0); 387 388 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 389 // that data to runtime call. 390 Value ptrToSPIRVBinary = LLVM::createGlobalString( 391 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 392 LLVM::Linkage::Internal); 393 394 // Create LLVM constant for the size of SPIR-V binary shader. 395 Value binarySize = builder.create<LLVM::ConstantOp>( 396 loc, getInt32Type(), 397 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 398 399 // Create call to `bindMemRef` for each memref operand. 400 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 401 402 // Create call to `setBinaryShader` runtime function with the given pointer to 403 // SPIR-V binary and binary size. 404 builder.create<LLVM::CallOp>( 405 loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader), 406 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 407 // Create LLVM global with entry point name. 408 Value entryPointName = createEntryPointNameConstant( 409 spirvAttributes.second.getValue(), loc, builder); 410 // Create call to `setEntryPoint` runtime function with the given pointer to 411 // entry point name. 412 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 413 builder.getSymbolRefAttr(kSetEntryPoint), 414 ValueRange{vulkanRuntime, entryPointName}); 415 416 // Create number of local workgroup for each dimension. 417 builder.create<LLVM::CallOp>( 418 loc, TypeRange{getVoidType()}, 419 builder.getSymbolRefAttr(kSetNumWorkGroups), 420 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 421 cInterfaceVulkanLaunchCallOp.getOperand(1), 422 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 423 424 // Create call to `runOnVulkan` runtime function. 425 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 426 builder.getSymbolRefAttr(kRunOnVulkan), 427 ValueRange{vulkanRuntime}); 428 429 // Create call to 'deinitVulkan' runtime function. 430 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 431 builder.getSymbolRefAttr(kDeinitVulkan), 432 ValueRange{vulkanRuntime}); 433 434 // Declare runtime functions. 435 declareVulkanFunctions(loc); 436 437 cInterfaceVulkanLaunchCallOp.erase(); 438 } 439 440 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 441 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 442 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 443 } 444