1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert gpu.launch_func op into a sequence of 10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we 11 // don't expose separate external functions in IR for each of them, instead we 12 // expose a few external functions to wrapper libraries which manages Vulkan 13 // runtime. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 18 #include "mlir/Dialect/GPU/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 21 #include "mlir/Dialect/SPIRV/Serialization.h" 22 #include "mlir/Dialect/StandardOps/Ops.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include "mlir/Pass/Pass.h" 29 30 #include "llvm/ADT/SmallString.h" 31 32 using namespace mlir; 33 34 static constexpr const char *kSetBinaryShader = "setBinaryShader"; 35 static constexpr const char *kSetEntryPoint = "setEntryPoint"; 36 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; 37 static constexpr const char *kRunOnVulkan = "runOnVulkan"; 38 static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; 39 40 namespace { 41 42 /// A pass to convert gpu.launch_func operation into a sequence of Vulkan 43 /// runtime calls. 44 /// 45 /// * setBinaryShader -- sets the binary shader data 46 /// * setEntryPoint -- sets the entry point name 47 /// * setNumWorkGroups -- sets the number of a local workgroups 48 /// * runOnVulkan -- runs vulkan runtime 49 /// 50 class GpuLaunchFuncToVulkanCalssPass 51 : public ModulePass<GpuLaunchFuncToVulkanCalssPass> { 52 private: 53 LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } 54 55 llvm::LLVMContext &getLLVMContext() { 56 return getLLVMDialect()->getLLVMContext(); 57 } 58 59 void initializeCachedTypes() { 60 llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); 61 llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); 62 llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); 63 llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); 64 } 65 66 LLVM::LLVMType getVoidType() { return llvmVoidType; } 67 LLVM::LLVMType getPointerType() { return llvmPointerType; } 68 LLVM::LLVMType getInt32Type() { return llvmInt32Type; } 69 70 /// Creates a SPIR-V binary shader from the given `module` using 71 /// `spirv::serialize` function. 72 LogicalResult createBinaryShader(ModuleOp module, 73 std::vector<char> &binaryShader); 74 75 /// Creates a LLVM global for the given `name`. 76 Value createEntryPointNameConstant(StringRef name, Location loc, 77 OpBuilder &builder); 78 79 /// Creates a LLVM constant for each dimension of local workgroup and 80 /// populates the given `numWorkGroups`. 81 LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder, 82 mlir::gpu::LaunchFuncOp launchOp, 83 SmallVectorImpl<Value> &numWorkGroups); 84 85 /// Declares all needed runtime functions. 86 void declareVulkanFunctions(Location loc); 87 88 /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls 89 void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); 90 91 public: 92 void runOnModule() override; 93 94 private: 95 LLVM::LLVMDialect *llvmDialect; 96 LLVM::LLVMType llvmVoidType; 97 LLVM::LLVMType llvmPointerType; 98 LLVM::LLVMType llvmInt32Type; 99 }; 100 101 } // anonymous namespace 102 103 void GpuLaunchFuncToVulkanCalssPass::runOnModule() { 104 initializeCachedTypes(); 105 106 getModule().walk( 107 [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); 108 109 // Erase `gpu::GPUModuleOp` and `spirv::Module` operations. 110 for (auto gpuModule : 111 llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>())) 112 gpuModule.erase(); 113 114 for (auto spirvModule : 115 llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>())) 116 spirvModule.erase(); 117 } 118 119 void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) { 120 ModuleOp module = getModule(); 121 OpBuilder builder(module.getBody()->getTerminator()); 122 123 if (!module.lookupSymbol(kSetEntryPoint)) { 124 builder.create<LLVM::LLVMFuncOp>( 125 loc, kSetEntryPoint, 126 LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, 127 /*isVarArg=*/false)); 128 } 129 130 if (!module.lookupSymbol(kSetNumWorkGroups)) { 131 builder.create<LLVM::LLVMFuncOp>( 132 loc, kSetNumWorkGroups, 133 LLVM::LLVMType::getFunctionTy( 134 getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()}, 135 /*isVarArg=*/false)); 136 } 137 138 if (!module.lookupSymbol(kSetBinaryShader)) { 139 builder.create<LLVM::LLVMFuncOp>( 140 loc, kSetBinaryShader, 141 LLVM::LLVMType::getFunctionTy(getVoidType(), 142 {getPointerType(), getInt32Type()}, 143 /*isVarArg=*/false)); 144 } 145 146 if (!module.lookupSymbol(kRunOnVulkan)) { 147 builder.create<LLVM::LLVMFuncOp>( 148 loc, kRunOnVulkan, 149 LLVM::LLVMType::getFunctionTy(getVoidType(), {}, 150 /*isVarArg=*/false)); 151 } 152 } 153 154 Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant( 155 StringRef name, Location loc, OpBuilder &builder) { 156 SmallString<16> shaderName(name.begin(), name.end()); 157 // Append `\0` to follow C style string given that LLVM::createGlobalString() 158 // won't handle this directly for us. 159 shaderName.push_back('\0'); 160 161 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); 162 return LLVM::createGlobalString(loc, builder, entryPointGlobalName, 163 shaderName, LLVM::Linkage::Internal, 164 getLLVMDialect()); 165 } 166 167 LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader( 168 ModuleOp module, std::vector<char> &binaryShader) { 169 bool done = false; 170 SmallVector<uint32_t, 0> binary; 171 for (auto spirvModule : module.getOps<spirv::ModuleOp>()) { 172 if (done) 173 return spirvModule.emitError("should only contain one 'spv.module' op"); 174 done = true; 175 176 if (failed(spirv::serialize(spirvModule, binary))) 177 return failure(); 178 } 179 180 binaryShader.resize(binary.size() * sizeof(uint32_t)); 181 std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()), 182 binaryShader.size()); 183 return success(); 184 } 185 186 LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups( 187 Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp, 188 SmallVectorImpl<Value> &numWorkGroups) { 189 for (auto index : llvm::seq(0, 3)) { 190 auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>( 191 launchOp.getOperand(index).getDefiningOp()); 192 193 if (!numWorkGroupDimConstant) 194 return failure(); 195 196 auto numWorkGroupDimValue = 197 numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt(); 198 numWorkGroups.push_back(builder.create<LLVM::ConstantOp>( 199 loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue))); 200 } 201 202 return success(); 203 } 204 205 void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls( 206 mlir::gpu::LaunchFuncOp launchOp) { 207 ModuleOp module = getModule(); 208 OpBuilder builder(launchOp); 209 Location loc = launchOp.getLoc(); 210 211 // Serialize `spirv::Module` into binary form. 212 std::vector<char> binary; 213 if (failed( 214 GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary))) 215 return signalPassFailure(); 216 217 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with 218 // that data to runtime call. 219 Value ptrToSPIRVBinary = LLVM::createGlobalString( 220 loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()), 221 LLVM::Linkage::Internal, getLLVMDialect()); 222 // Create LLVM constant for the size of SPIR-V binary shader. 223 Value binarySize = builder.create<LLVM::ConstantOp>( 224 loc, getInt32Type(), builder.getI32IntegerAttr(binary.size())); 225 // Create call to `setBinaryShader` runtime function with the given pointer to 226 // SPIR-V binary and binary size. 227 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 228 builder.getSymbolRefAttr(kSetBinaryShader), 229 ArrayRef<Value>{ptrToSPIRVBinary, binarySize}); 230 231 // Create LLVM global with entry point name. 232 Value entryPointName = 233 createEntryPointNameConstant(launchOp.kernel(), loc, builder); 234 // Create call to `setEntryPoint` runtime function with the given pointer to 235 // entry point name. 236 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 237 builder.getSymbolRefAttr(kSetEntryPoint), 238 ArrayRef<Value>{entryPointName}); 239 240 // Create number of local workgroup for each dimension. 241 SmallVector<Value, 3> numWorkGroups; 242 if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups))) 243 return signalPassFailure(); 244 245 // Create call `setNumWorkGroups` runtime function with the given numbers of 246 // local workgroup. 247 builder.create<LLVM::CallOp>( 248 loc, ArrayRef<Type>{getVoidType()}, 249 builder.getSymbolRefAttr(kSetNumWorkGroups), 250 ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]}); 251 252 // Create call to `runOnVulkan` runtime function. 253 builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, 254 builder.getSymbolRefAttr(kRunOnVulkan), 255 ArrayRef<Value>{}); 256 257 // Declare runtime functions. 258 declareVulkanFunctions(loc); 259 260 launchOp.erase(); 261 } 262 263 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> 264 mlir::createConvertGpuLaunchFuncToVulkanCallsPass() { 265 return std::make_unique<GpuLaunchFuncToVulkanCalssPass>(); 266 } 267 268 static PassRegistration<GpuLaunchFuncToVulkanCalssPass> 269 pass("launch-func-to-vulkan", 270 "Convert gpu.launch_func op to Vulkan runtime calls"); 271