1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert gpu.launch_func op into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 21 #include "mlir/Dialect/SPIRV/Serialization.h" 22 #include "mlir/Dialect/StandardOps/Ops.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include "mlir/Pass/Pass.h" 29 30 #include "llvm/Support/FormatVariadic.h" 31 32 using namespace mlir; 33 34 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 35 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 36 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 37 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 38 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 39 40 namespace { 41 42 /// A pass to convert gpu.launch_func operation into a sequence of Vulkan 43 /// runtime calls. 44 /// 45 /// * setBinaryShader -- sets the binary shader data 46 /// * setEntryPoint -- sets the entry point name 47 /// * setNumWorkGroups -- sets the number of a local workgroups 48 /// * runOnVulkan -- runs vulkan runtime 49 /// 50 class GpuLaunchFuncToVulkanCalssPass 51 : public ModulePass<GpuLaunchFuncToVulkanCalssPass> { 52 private: 53 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 54 55 llvm::LLVMContext &getLLVMContext() { 56 return getLLVMDialect()->getLLVMContext(); 57 } 58 59 void initializeCachedTypes() { 60 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 61 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 62 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 63 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 64 } 65 66 LLVM::LLVMType getVoidType() { return llvmVoidType; } 67 LLVM::LLVMType getPointerType() { return llvmPointerType; } 68 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 69 70 /// Creates a SPIR-V binary shader from the given `module` using 71 /// `spirv::serialize` function. 72 LogicalResult createBinaryShader(ModuleOp module, 73 std::vector<char> &binaryShader); 74 75 /// Creates a LLVM global for the given `name`. 76 Value createEntryPointNameConstant(StringRef name, Location loc, 77 OpBuilder &builder); 78 79 /// Creates a LLVM constant for each dimension of local workgroup and 80 /// populates the given `numWorkGroups`. 81 LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder, 82 mlir::gpu::LaunchFuncOp launchOp, 83 SmallVector<Value, 3> &numWorkGroups); 84 85 /// Declares all needed runtime functions. 86 void declareVulkanFunctions(Location loc); 87 88 /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls 89 void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); 90 91 public: 92 void runOnModule() override; 93 94 private: 95 LLVM::LLVMDialect *llvmDialect; 96 LLVM::LLVMType llvmVoidType; 97 LLVM::LLVMType llvmPointerType; 98 LLVM::LLVMType llvmInt32Type; 99 }; 100 101 } // anonymous namespace 102 103 void GpuLaunchFuncToVulkanCalssPass::runOnModule() { 104 initializeCachedTypes(); 105 106 getModule().walk( 107 [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); 108 109 // Erase `gpu::GPUModuleOp` and `spirv::Module` operations. 110 for (auto gpuModule : 111 llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>())) 112 gpuModule.erase(); 113 114 for (auto spirvModule : 115 llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>())) 116 spirvModule.erase(); 117 } 118 119 void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) { 120 ModuleOp module = getModule(); 121 OpBuilder builder(module.getBody()->getTerminator()); 122 123 if (!module.lookupSymbol(kSetEntryPoint)) { 124 builder.create<LLVM::LLVMFuncOp>( 125 loc, kSetEntryPoint, 126 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 127 /*isVarArg=*/false)); 128 } 129 130 if (!module.lookupSymbol(kSetNumWorkGroups)) { 131 builder.create<LLVM::LLVMFuncOp>( 132 loc, kSetNumWorkGroups, 133 LLVM::LLVMType::getFunctionTy( 134 getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()}, 135 /*isVarArg=*/false)); 136 } 137 138 if (!module.lookupSymbol(kSetBinaryShader)) { 139 builder.create<LLVM::LLVMFuncOp>( 140 loc, kSetBinaryShader, 141 LLVM::LLVMType::getFunctionTy(getVoidType(), 142 {getPointerType(), getInt32Type()}, 143 /*isVarArg=*/false)); 144 } 145 146 if (!module.lookupSymbol(kRunOnVulkan)) { 147 builder.create<LLVM::LLVMFuncOp>( 148 loc, kRunOnVulkan, 149 LLVM::LLVMType::getFunctionTy(getVoidType(), {}, 150 /*isVarArg=*/false)); 151 } 152 } 153 154 Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant( 155 StringRef name, Location loc, OpBuilder &builder) { 156 std::vector<char> shaderName(name.begin(), name.end()); 157 // Append `\0` to follow C style string given that LLVM::createGlobalString() 158 // won't handle this directly for us. 159 shaderName.push_back('\0'); 160 161 std::string entryPointGlobalName = 162 std::string(llvm::formatv("{0}_spv_entry_point_name", name)); 163 return LLVM::createGlobalString( 164 loc, builder, entryPointGlobalName, 165 StringRef(shaderName.data(), shaderName.size()), LLVM::Linkage::Internal, 166 getLLVMDialect()); 167 } 168 169 LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader( 170 ModuleOp module, std::vector<char> &binaryShader) { 171 bool done = false; 172 SmallVector<uint32_t, 0> binary; 173 for (auto spirvModule : module.getOps<spirv::ModuleOp>()) { 174 if (done) { 175 spirvModule.emitError("should only contain one 'spv.module' op"); 176 return failure(); 177 } 178 done = true; 179 if (failed(spirv::serialize(spirvModule, binary))) { 180 return failure(); 181 } 182 } 183 184 binaryShader.resize(binary.size() * sizeof(uint32_t)); 185 std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()), 186 binaryShader.size()); 187 return success(); 188 } 189 190 LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups( 191 Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp, 192 SmallVector<Value, 3> &numWorkGroups) { 193 for (auto index : llvm::seq(0, 3)) { 194 auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>( 195 launchOp.getOperand(index).getDefiningOp()); 196 197 if (!numWorkGroupDimConstant) { 198 return failure(); 199 } 200 201 auto numWorkGroupDimValue = 202 numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt(); 203 numWorkGroups.push_back(builder.create<LLVM::ConstantOp>( 204 loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue))); 205 } 206 207 return success(); 208 } 209 210 // Translates gpu launch op to the sequence of Vulkan runtime calls. 211 void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls( 212 mlir::gpu::LaunchFuncOp launchOp) { 213 ModuleOp module = getModule(); 214 OpBuilder builder(launchOp); 215 Location loc = launchOp.getLoc(); 216 217 // Serialize `spirv::Module` into binary form. 218 std::vector<char> binary; 219 if (failed( 220 GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary))) { 221 return signalPassFailure(); 222 } 223 224 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 225 // that data to runtime call. 226 Value ptrToSPIRVBinary = LLVM::createGlobalString( 227 loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()), 228 LLVM::Linkage::Internal, getLLVMDialect()); 229 // Create LLVM constant for the size of SPIR-V binary shader. 230 Value binarySize = builder.create<LLVM::ConstantOp>( 231 loc, getInt32Type(), builder.getI32IntegerAttr(binary.size())); 232 // Create call to `setBinaryShader` runtime function with the given pointer to 233 // SPIR-V binary and binary size. 234 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 235 builder.getSymbolRefAttr(kSetBinaryShader), 236 ArrayRef<Value>{ptrToSPIRVBinary, binarySize}); 237 238 // Create LLVM global with entry point name. 239 Value entryPointName = 240 createEntryPointNameConstant(launchOp.kernel(), loc, builder); 241 // Create call to `setEntryPoint` runtime function with the given pointer to 242 // entry point name. 243 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 244 builder.getSymbolRefAttr(kSetEntryPoint), 245 ArrayRef<Value>{entryPointName}); 246 247 // Create number of local workgroup for each dimension. 248 SmallVector<Value, 3> numWorkGroups; 249 if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups))) { 250 return signalPassFailure(); 251 } 252 253 // Create call `setNumWorkGroups` runtime function with the given numbers of 254 // local workgroup. 255 builder.create<LLVM::CallOp>( 256 loc, ArrayRef<Type>{getVoidType()}, 257 builder.getSymbolRefAttr(kSetNumWorkGroups), 258 ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]}); 259 260 // Create call to `runOnVulkan` runtime function. 261 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 262 builder.getSymbolRefAttr(kRunOnVulkan), 263 ArrayRef<Value>{}); 264 265 // Declare runtime functions. 266 declareVulkanFunctions(loc); 267 268 launchOp.erase(); 269 } 270 271 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 272 mlir::createConvertGpuLaunchFuncToVulkanCallsPass() { 273 return std::make_unique<GpuLaunchFuncToVulkanCalssPass>(); 274 } 275 276 static PassRegistration<GpuLaunchFuncToVulkanCalssPass> 277 pass("launch-func-to-vulkan", 278 "Convert gpu.launch_func op to Vulkan runtime calls"); 279