1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert vulkan launch call into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "../PassDetail.h" 18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/BuiltinOps.h" 23 24 #include "llvm/ADT/SmallString.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 static constexpr const char *kCInterfaceVulkanLaunch = 30 "_mlir_ciface_vulkanLaunch"; 31 static constexpr const char *kDeinitVulkan = "deinitVulkan"; 32 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 33 static constexpr const char *kInitVulkan = "initVulkan"; 34 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 35 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 36 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 37 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 38 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; 39 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; 40 static constexpr const char *kVulkanLaunch = "vulkanLaunch"; 41 42 namespace { 43 44 /// A pass to convert vulkan launch call op into a sequence of Vulkan 45 /// runtime calls in the following order: 46 /// 47 /// * initVulkan -- initializes vulkan runtime 48 /// * bindMemRef -- binds memref 49 /// * setBinaryShader -- sets the binary shader data 50 /// * setEntryPoint -- sets the entry point name 51 /// * setNumWorkGroups -- sets the number of a local workgroups 52 /// * runOnVulkan -- runs vulkan runtime 53 /// * deinitVulkan -- deinitializes vulkan runtime 54 /// 55 class VulkanLaunchFuncToVulkanCallsPass 56 : public ConvertVulkanLaunchFuncToVulkanCallsBase< 57 VulkanLaunchFuncToVulkanCallsPass> { 58 private: 59 void initializeCachedTypes() { 60 llvmFloatType = LLVM::LLVMFloatType::get(&getContext()); 61 llvmVoidType = LLVM::LLVMVoidType::get(&getContext()); 62 llvmPointerType = LLVM::LLVMPointerType::get( 63 LLVM::LLVMIntegerType::get(&getContext(), 8)); 64 llvmInt32Type = LLVM::LLVMIntegerType::get(&getContext(), 32); 65 llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64); 66 } 67 68 Type getMemRefType(uint32_t rank, Type elemenType) { 69 // According to the MLIR doc memref argument is converted into a 70 // pointer-to-struct argument of type: 71 // template <typename Elem, size_t Rank> 72 // struct { 73 // Elem *allocated; 74 // Elem *aligned; 75 // int64_t offset; 76 // int64_t sizes[Rank]; // omitted when rank == 0 77 // int64_t strides[Rank]; // omitted when rank == 0 78 // }; 79 auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType); 80 auto llvmArrayRankElementSizeType = 81 LLVM::LLVMArrayType::get(getInt64Type(), rank); 82 83 // Create a type 84 // `!llvm<"{ `element-type`*, `element-type`*, i64, 85 // [`rank` x i64], [`rank` x i64]}">`. 86 return LLVM::LLVMStructType::getLiteral( 87 &getContext(), 88 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), 89 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); 90 } 91 92 Type getVoidType() { return llvmVoidType; } 93 Type getPointerType() { return llvmPointerType; } 94 Type getInt32Type() { return llvmInt32Type; } 95 Type getInt64Type() { return llvmInt64Type; } 96 97 /// Creates an LLVM global for the given `name`. 98 Value createEntryPointNameConstant(StringRef name, Location loc, 99 OpBuilder &builder); 100 101 /// Declares all needed runtime functions. 102 void declareVulkanFunctions(Location loc); 103 104 /// Checks whether the given LLVM::CallOp is a vulkan launch call op. 105 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { 106 return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && 107 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 108 } 109 110 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call 111 /// op. 112 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { 113 return (callOp.callee() && 114 callOp.callee().getValue() == kCInterfaceVulkanLaunch && 115 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); 116 } 117 118 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan 119 /// runtime calls. 120 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); 121 122 /// Creates call to `bindMemRef` for each memref operand. 123 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, 124 Value vulkanRuntime); 125 126 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. 127 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); 128 129 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. 130 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, 131 uint32_t &rank, Type &type); 132 133 /// Returns a string representation from the given `type`. 134 StringRef stringifyType(Type type) { 135 if (type.isa<LLVM::LLVMFloatType>()) 136 return "Float"; 137 if (type.isa<LLVM::LLVMHalfType>()) 138 return "Half"; 139 if (auto intType = type.dyn_cast<LLVM::LLVMIntegerType>()) { 140 if (intType.getBitWidth() == 32) 141 return "Int32"; 142 if (intType.getBitWidth() == 16) 143 return "Int16"; 144 if (intType.getBitWidth() == 8) 145 return "Int8"; 146 } 147 148 llvm_unreachable("unsupported type"); 149 } 150 151 public: 152 void runOnOperation() override; 153 154 private: 155 Type llvmFloatType; 156 Type llvmVoidType; 157 Type llvmPointerType; 158 Type llvmInt32Type; 159 Type llvmInt64Type; 160 161 // TODO: Use an associative array to support multiple vulkan launch calls. 162 std::pair<StringAttr, StringAttr> spirvAttributes; 163 /// The number of vulkan launch configuration operands, placed at the leading 164 /// positions of the operand list. 165 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; 166 }; 167 168 } // anonymous namespace 169 170 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { 171 initializeCachedTypes(); 172 173 // Collect SPIR-V attributes such as `spirv_blob` and 174 // `spirv_entry_point_name`. 175 getOperation().walk([this](LLVM::CallOp op) { 176 if (isVulkanLaunchCallOp(op)) 177 collectSPIRVAttributes(op); 178 }); 179 180 // Convert vulkan launch call op into a sequence of Vulkan runtime calls. 181 getOperation().walk([this](LLVM::CallOp op) { 182 if (isCInterfaceVulkanLaunchCallOp(op)) 183 translateVulkanLaunchCall(op); 184 }); 185 } 186 187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( 188 LLVM::CallOp vulkanLaunchCallOp) { 189 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes 190 // for the given vulkan launch call. 191 auto spirvBlobAttr = 192 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName); 193 if (!spirvBlobAttr) { 194 vulkanLaunchCallOp.emitError() 195 << "missing " << kSPIRVBlobAttrName << " attribute"; 196 return signalPassFailure(); 197 } 198 199 auto spirvEntryPointNameAttr = 200 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); 201 if (!spirvEntryPointNameAttr) { 202 vulkanLaunchCallOp.emitError() 203 << "missing " << kSPIRVEntryPointAttrName << " attribute"; 204 return signalPassFailure(); 205 } 206 207 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); 208 } 209 210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( 211 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { 212 if (cInterfaceVulkanLaunchCallOp.getNumOperands() == 213 kVulkanLaunchNumConfigOperands) 214 return; 215 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 216 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 217 218 // Create LLVM constant for the descriptor set index. 219 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` 220 // pass does. 221 Value descriptorSet = builder.create<LLVM::ConstantOp>( 222 loc, getInt32Type(), builder.getI32IntegerAttr(0)); 223 224 for (auto en : 225 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( 226 kVulkanLaunchNumConfigOperands))) { 227 // Create LLVM constant for the descriptor binding index. 228 Value descriptorBinding = builder.create<LLVM::ConstantOp>( 229 loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); 230 231 auto ptrToMemRefDescriptor = en.value(); 232 uint32_t rank = 0; 233 Type type; 234 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { 235 cInterfaceVulkanLaunchCallOp.emitError() 236 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); 237 return signalPassFailure(); 238 } 239 240 auto symbolName = 241 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); 242 // Special case for fp16 type. Since it is not a supported type in C we use 243 // int16_t and bitcast the descriptor. 244 if (type.isa<LLVM::LLVMHalfType>()) { 245 auto memRefTy = 246 getMemRefType(rank, LLVM::LLVMIntegerType::get(&getContext(), 16)); 247 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>( 248 loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); 249 } 250 // Create call to `bindMemRef`. 251 builder.create<LLVM::CallOp>( 252 loc, TypeRange{getVoidType()}, 253 builder.getSymbolRefAttr( 254 StringRef(symbolName.data(), symbolName.size())), 255 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding, 256 ptrToMemRefDescriptor}); 257 } 258 } 259 260 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( 261 Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) { 262 auto llvmPtrDescriptorTy = 263 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>(); 264 if (!llvmPtrDescriptorTy) 265 return failure(); 266 267 auto llvmDescriptorTy = 268 llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>(); 269 // template <typename Elem, size_t Rank> 270 // struct { 271 // Elem *allocated; 272 // Elem *aligned; 273 // int64_t offset; 274 // int64_t sizes[Rank]; // omitted when rank == 0 275 // int64_t strides[Rank]; // omitted when rank == 0 276 // }; 277 if (!llvmDescriptorTy) 278 return failure(); 279 280 type = llvmDescriptorTy.getBody()[0] 281 .cast<LLVM::LLVMPointerType>() 282 .getElementType(); 283 if (llvmDescriptorTy.getBody().size() == 3) { 284 rank = 0; 285 return success(); 286 } 287 rank = llvmDescriptorTy.getBody()[3] 288 .cast<LLVM::LLVMArrayType>() 289 .getNumElements(); 290 return success(); 291 } 292 293 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { 294 ModuleOp module = getOperation(); 295 OpBuilder builder(module.getBody()->getTerminator()); 296 297 if (!module.lookupSymbol(kSetEntryPoint)) { 298 builder.create<LLVM::LLVMFuncOp>( 299 loc, kSetEntryPoint, 300 LLVM::LLVMFunctionType::get(getVoidType(), 301 {getPointerType(), getPointerType()})); 302 } 303 304 if (!module.lookupSymbol(kSetNumWorkGroups)) { 305 builder.create<LLVM::LLVMFuncOp>( 306 loc, kSetNumWorkGroups, 307 LLVM::LLVMFunctionType::get(getVoidType(), 308 {getPointerType(), getInt64Type(), 309 getInt64Type(), getInt64Type()})); 310 } 311 312 if (!module.lookupSymbol(kSetBinaryShader)) { 313 builder.create<LLVM::LLVMFuncOp>( 314 loc, kSetBinaryShader, 315 LLVM::LLVMFunctionType::get( 316 getVoidType(), 317 {getPointerType(), getPointerType(), getInt32Type()})); 318 } 319 320 if (!module.lookupSymbol(kRunOnVulkan)) { 321 builder.create<LLVM::LLVMFuncOp>( 322 loc, kRunOnVulkan, 323 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); 324 } 325 326 for (unsigned i = 1; i <= 3; i++) { 327 SmallVector<Type, 5> types{LLVM::LLVMFloatType::get(&getContext()), 328 LLVM::LLVMIntegerType::get(&getContext(), 32), 329 LLVM::LLVMIntegerType::get(&getContext(), 16), 330 LLVM::LLVMIntegerType::get(&getContext(), 8), 331 LLVM::LLVMHalfType::get(&getContext())}; 332 for (auto type : types) { 333 std::string fnName = "bindMemRef" + std::to_string(i) + "D" + 334 std::string(stringifyType(type)); 335 if (type.isa<LLVM::LLVMHalfType>()) 336 type = LLVM::LLVMIntegerType::get(&getContext(), 16); 337 if (!module.lookupSymbol(fnName)) { 338 auto fnType = LLVM::LLVMFunctionType::get( 339 getVoidType(), 340 {getPointerType(), getInt32Type(), getInt32Type(), 341 LLVM::LLVMPointerType::get(getMemRefType(i, type))}, 342 /*isVarArg=*/false); 343 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType); 344 } 345 } 346 } 347 348 if (!module.lookupSymbol(kInitVulkan)) { 349 builder.create<LLVM::LLVMFuncOp>( 350 loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {})); 351 } 352 353 if (!module.lookupSymbol(kDeinitVulkan)) { 354 builder.create<LLVM::LLVMFuncOp>( 355 loc, kDeinitVulkan, 356 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); 357 } 358 } 359 360 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( 361 StringRef name, Location loc, OpBuilder &builder) { 362 SmallString<16> shaderName(name.begin(), name.end()); 363 // Append `\0` to follow C style string given that LLVM::createGlobalString() 364 // won't handle this directly for us. 365 shaderName.push_back('\0'); 366 367 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 368 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 369 shaderName, LLVM::Linkage::Internal); 370 } 371 372 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( 373 LLVM::CallOp cInterfaceVulkanLaunchCallOp) { 374 OpBuilder builder(cInterfaceVulkanLaunchCallOp); 375 Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); 376 // Create call to `initVulkan`. 377 auto initVulkanCall = builder.create<LLVM::CallOp>( 378 loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan), 379 ValueRange{}); 380 // The result of `initVulkan` function is a pointer to Vulkan runtime, we 381 // need to pass that pointer to each Vulkan runtime call. 382 auto vulkanRuntime = initVulkanCall.getResult(0); 383 384 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 385 // that data to runtime call. 386 Value ptrToSPIRVBinary = LLVM::createGlobalString( 387 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), 388 LLVM::Linkage::Internal); 389 390 // Create LLVM constant for the size of SPIR-V binary shader. 391 Value binarySize = builder.create<LLVM::ConstantOp>( 392 loc, getInt32Type(), 393 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); 394 395 // Create call to `bindMemRef` for each memref operand. 396 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); 397 398 // Create call to `setBinaryShader` runtime function with the given pointer to 399 // SPIR-V binary and binary size. 400 builder.create<LLVM::CallOp>( 401 loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader), 402 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); 403 // Create LLVM global with entry point name. 404 Value entryPointName = createEntryPointNameConstant( 405 spirvAttributes.second.getValue(), loc, builder); 406 // Create call to `setEntryPoint` runtime function with the given pointer to 407 // entry point name. 408 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 409 builder.getSymbolRefAttr(kSetEntryPoint), 410 ValueRange{vulkanRuntime, entryPointName}); 411 412 // Create number of local workgroup for each dimension. 413 builder.create<LLVM::CallOp>( 414 loc, TypeRange{getVoidType()}, 415 builder.getSymbolRefAttr(kSetNumWorkGroups), 416 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), 417 cInterfaceVulkanLaunchCallOp.getOperand(1), 418 cInterfaceVulkanLaunchCallOp.getOperand(2)}); 419 420 // Create call to `runOnVulkan` runtime function. 421 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 422 builder.getSymbolRefAttr(kRunOnVulkan), 423 ValueRange{vulkanRuntime}); 424 425 // Create call to 'deinitVulkan' runtime function. 426 builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()}, 427 builder.getSymbolRefAttr(kDeinitVulkan), 428 ValueRange{vulkanRuntime}); 429 430 // Declare runtime functions. 431 declareVulkanFunctions(loc); 432 433 cInterfaceVulkanLaunchCallOp.erase(); 434 } 435 436 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 437 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { 438 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); 439 } 440