//===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert vulkan launch call into a sequence of // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we // don't expose separate external functions in IR for each of them, instead we // expose a few external functions to wrapper libraries which manages Vulkan // runtime. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat"; static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt"; static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt"; static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt"; static constexpr const char *kCInterfaceVulkanLaunch = "_mlir_ciface_vulkanLaunch"; static constexpr const char *kDeinitVulkan = "deinitVulkan"; static constexpr const char *kRunOnVulkan = "runOnVulkan"; static constexpr const char *kInitVulkan = "initVulkan"; static constexpr const char *kSetBinaryShader = "setBinaryShader"; static constexpr const char *kSetEntryPoint = "setEntryPoint"; static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; static constexpr const char *kVulkanLaunch = "vulkanLaunch"; namespace { /// A pass to convert vulkan launch call op into a sequence of Vulkan /// runtime calls in the following order: /// /// * initVulkan -- initializes vulkan runtime /// * bindMemRef -- binds memref /// * setBinaryShader -- sets the binary shader data /// * setEntryPoint -- sets the entry point name /// * setNumWorkGroups -- sets the number of a local workgroups /// * runOnVulkan -- runs vulkan runtime /// * deinitVulkan -- deinitializes vulkan runtime /// class VulkanLaunchFuncToVulkanCallsPass : public ConvertVulkanLaunchFuncToVulkanCallsBase< VulkanLaunchFuncToVulkanCallsPass> { private: LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } llvm::LLVMContext &getLLVMContext() { return getLLVMDialect()->getLLVMContext(); } void initializeCachedTypes() { llvmDialect = getContext().getRegisteredDialect(); llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); llvmMemRef1DFloat = getMemRefType(1, llvmFloatType); llvmMemRef2DFloat = getMemRefType(2, llvmFloatType); llvmMemRef3DFloat = getMemRefType(3, llvmFloatType); llvmMemRef1DInt = getMemRefType(1, llvmInt32Type); llvmMemRef2DInt = getMemRefType(2, llvmInt32Type); llvmMemRef3DInt = getMemRefType(3, llvmInt32Type); } LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { // According to the MLIR doc memref argument is converted into a // pointer-to-struct argument of type: // template // struct { // Elem *allocated; // Elem *aligned; // int64_t offset; // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; auto llvmPtrToElementType = elemenType.getPointerTo(); auto llvmArrayRankElementSizeType = LLVM::LLVMType::getArrayTy(getInt64Type(), rank); // Create a type // `!llvm<"{ `element-type`*, `element-type`*, i64, // [`rank` x i64], [`rank` x i64]}">`. return LLVM::LLVMType::getStructTy( llvmDialect, {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); } LLVM::LLVMType getFloatType() { return llvmFloatType; } LLVM::LLVMType getVoidType() { return llvmVoidType; } LLVM::LLVMType getPointerType() { return llvmPointerType; } LLVM::LLVMType getInt32Type() { return llvmInt32Type; } LLVM::LLVMType getInt64Type() { return llvmInt64Type; } LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; } LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; } LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; } LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; } /// Creates a LLVM global for the given `name`. Value createEntryPointNameConstant(StringRef name, Location loc, OpBuilder &builder); /// Declares all needed runtime functions. void declareVulkanFunctions(Location loc); /// Checks whether the given LLVM::CallOp is a vulkan launch call op. bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); } /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call /// op. bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) { return (callOp.callee() && callOp.callee().getValue() == kCInterfaceVulkanLaunch && callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands); } /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan /// runtime calls. void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); /// Creates call to `bindMemRef` for each memref operand. void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime); /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type); /// Returns a string representation from the given `type`. StringRef stringifyType(LLVM::LLVMType type) { if (type.isFloatTy()) return "Float"; if (type.isIntegerTy()) return "Int"; llvm_unreachable("unsupported type"); } public: void runOnOperation() override; private: LLVM::LLVMDialect *llvmDialect; LLVM::LLVMType llvmFloatType; LLVM::LLVMType llvmVoidType; LLVM::LLVMType llvmPointerType; LLVM::LLVMType llvmInt32Type; LLVM::LLVMType llvmInt64Type; LLVM::LLVMType llvmMemRef1DFloat; LLVM::LLVMType llvmMemRef2DFloat; LLVM::LLVMType llvmMemRef3DFloat; LLVM::LLVMType llvmMemRef1DInt; LLVM::LLVMType llvmMemRef2DInt; LLVM::LLVMType llvmMemRef3DInt; // TODO: Use an associative array to support multiple vulkan launch calls. std::pair spirvAttributes; /// The number of vulkan launch configuration operands, placed at the leading /// positions of the operand list. static constexpr unsigned kVulkanLaunchNumConfigOperands = 3; }; } // anonymous namespace void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { initializeCachedTypes(); // Collect SPIR-V attributes such as `spirv_blob` and // `spirv_entry_point_name`. getOperation().walk([this](LLVM::CallOp op) { if (isVulkanLaunchCallOp(op)) collectSPIRVAttributes(op); }); // Convert vulkan launch call op into a sequence of Vulkan runtime calls. getOperation().walk([this](LLVM::CallOp op) { if (isCInterfaceVulkanLaunchCallOp(op)) translateVulkanLaunchCall(op); }); } void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( LLVM::CallOp vulkanLaunchCallOp) { // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes // for the given vulkan launch call. auto spirvBlobAttr = vulkanLaunchCallOp.getAttrOfType(kSPIRVBlobAttrName); if (!spirvBlobAttr) { vulkanLaunchCallOp.emitError() << "missing " << kSPIRVBlobAttrName << " attribute"; return signalPassFailure(); } auto spirvEntryPointNameAttr = vulkanLaunchCallOp.getAttrOfType(kSPIRVEntryPointAttrName); if (!spirvEntryPointNameAttr) { vulkanLaunchCallOp.emitError() << "missing " << kSPIRVEntryPointAttrName << " attribute"; return signalPassFailure(); } spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); } void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) { if (cInterfaceVulkanLaunchCallOp.getNumOperands() == kVulkanLaunchNumConfigOperands) return; OpBuilder builder(cInterfaceVulkanLaunchCallOp); Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); // Create LLVM constant for the descriptor set index. // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` // pass does. Value descriptorSet = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(0)); for (auto en : llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front( kVulkanLaunchNumConfigOperands))) { // Create LLVM constant for the descriptor binding index. Value descriptorBinding = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); auto ptrToMemRefDescriptor = en.value(); uint32_t rank = 0; LLVM::LLVMType type; if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { cInterfaceVulkanLaunchCallOp.emitError() << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); return signalPassFailure(); } auto symbolName = llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); // Create call to `bindMemRef`. builder.create( loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr( StringRef(symbolName.data(), symbolName.size())), ArrayRef{vulkanRuntime, descriptorSet, descriptorBinding, ptrToMemRefDescriptor}); } } LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { auto llvmPtrDescriptorTy = ptrToMemRefDescriptor.getType().dyn_cast(); if (!llvmPtrDescriptorTy) return failure(); auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); // template // struct { // Elem *allocated; // Elem *aligned; // int64_t offset; // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) return failure(); type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); if (llvmDescriptorTy.getStructNumElements() == 3) { rank = 0; return success(); } rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); return success(); } void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { ModuleOp module = getOperation(); OpBuilder builder(module.getBody()->getTerminator()); if (!module.lookupSymbol(kSetEntryPoint)) { builder.create( loc, kSetEntryPoint, LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType(), getPointerType()}, /*isVarArg=*/false)); } if (!module.lookupSymbol(kSetNumWorkGroups)) { builder.create( loc, kSetNumWorkGroups, LLVM::LLVMType::getFunctionTy( getVoidType(), {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, /*isVarArg=*/false)); } if (!module.lookupSymbol(kSetBinaryShader)) { builder.create( loc, kSetBinaryShader, LLVM::LLVMType::getFunctionTy( getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, /*isVarArg=*/false)); } if (!module.lookupSymbol(kRunOnVulkan)) { builder.create( loc, kRunOnVulkan, LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, /*isVarArg=*/false)); } #define CREATE_VULKAN_BIND_FUNC(MemRefType) \ if (!module.lookupSymbol(kBind##MemRefType)) { \ builder.create( \ loc, kBind##MemRefType, \ LLVM::LLVMType::getFunctionTy(getVoidType(), \ {getPointerType(), getInt32Type(), \ getInt32Type(), \ get##MemRefType().getPointerTo()}, \ /*isVarArg=*/false)); \ } CREATE_VULKAN_BIND_FUNC(MemRef1DFloat); CREATE_VULKAN_BIND_FUNC(MemRef2DFloat); CREATE_VULKAN_BIND_FUNC(MemRef3DFloat); CREATE_VULKAN_BIND_FUNC(MemRef1DInt); CREATE_VULKAN_BIND_FUNC(MemRef2DInt); CREATE_VULKAN_BIND_FUNC(MemRef3DInt); if (!module.lookupSymbol(kInitVulkan)) { builder.create( loc, kInitVulkan, LLVM::LLVMType::getFunctionTy(getPointerType(), {}, /*isVarArg=*/false)); } if (!module.lookupSymbol(kDeinitVulkan)) { builder.create( loc, kDeinitVulkan, LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, /*isVarArg=*/false)); } } Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( StringRef name, Location loc, OpBuilder &builder) { SmallString<16> shaderName(name.begin(), name.end()); // Append `\0` to follow C style string given that LLVM::createGlobalString() // won't handle this directly for us. shaderName.push_back('\0'); std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); return LLVM::createGlobalString(loc, builder, entryPointGlobalName, shaderName, LLVM::Linkage::Internal, getLLVMDialect()); } void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( LLVM::CallOp cInterfaceVulkanLaunchCallOp) { OpBuilder builder(cInterfaceVulkanLaunchCallOp); Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); // Create call to `initVulkan`. auto initVulkanCall = builder.create( loc, ArrayRef{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan), ArrayRef{}); // The result of `initVulkan` function is a pointer to Vulkan runtime, we // need to pass that pointer to each Vulkan runtime call. auto vulkanRuntime = initVulkanCall.getResult(0); // Create LLVM global with SPIR-V binary data, so we can pass a pointer with // that data to runtime call. Value ptrToSPIRVBinary = LLVM::createGlobalString( loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), LLVM::Linkage::Internal, getLLVMDialect()); // Create LLVM constant for the size of SPIR-V binary shader. Value binarySize = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); // Create call to `bindMemRef` for each memref operand. createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime); // Create call to `setBinaryShader` runtime function with the given pointer to // SPIR-V binary and binary size. builder.create( loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader), ArrayRef{vulkanRuntime, ptrToSPIRVBinary, binarySize}); // Create LLVM global with entry point name. Value entryPointName = createEntryPointNameConstant( spirvAttributes.second.getValue(), loc, builder); // Create call to `setEntryPoint` runtime function with the given pointer to // entry point name. builder.create(loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kSetEntryPoint), ArrayRef{vulkanRuntime, entryPointName}); // Create number of local workgroup for each dimension. builder.create( loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kSetNumWorkGroups), ArrayRef{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), cInterfaceVulkanLaunchCallOp.getOperand(1), cInterfaceVulkanLaunchCallOp.getOperand(2)}); // Create call to `runOnVulkan` runtime function. builder.create(loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kRunOnVulkan), ArrayRef{vulkanRuntime}); // Create call to 'deinitVulkan' runtime function. builder.create(loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kDeinitVulkan), ArrayRef{vulkanRuntime}); // Declare runtime functions. declareVulkanFunctions(loc); cInterfaceVulkanLaunchCallOp.erase(); } std::unique_ptr> mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { return std::make_unique(); }