1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert gpu.launch_func op into a sequence of
// Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we
// don't expose separate external functions in IR for each of them; instead, we
// expose a few external functions to wrapper libraries which manage the Vulkan
// runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
21 #include "mlir/Dialect/SPIRV/Serialization.h"
22 #include "mlir/Dialect/StandardOps/Ops.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/StandardTypes.h"
28 #include "mlir/Pass/Pass.h"
29 
30 #include "llvm/Support/FormatVariadic.h"
31 
32 using namespace mlir;
33 
// Symbol names of the external runtime functions that this pass declares in
// the module and emits calls to.
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
// Name of the LLVM global that holds the serialized SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
39 
40 namespace {
41 
42 /// A pass to convert gpu.launch_func operation into a sequence of Vulkan
43 /// runtime calls.
44 ///
45 /// * setBinaryShader      -- sets the binary shader data
46 /// * setEntryPoint        -- sets the entry point name
47 /// * setNumWorkGroups     -- sets the number of a local workgroups
48 /// * runOnVulkan          -- runs vulkan runtime
49 ///
50 class GpuLaunchFuncToVulkanCalssPass
51     : public ModulePass<GpuLaunchFuncToVulkanCalssPass> {
52 private:
53   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
54 
55   llvm::LLVMContext &getLLVMContext() {
56     return getLLVMDialect()->getLLVMContext();
57   }
58 
59   void initializeCachedTypes() {
60     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
61     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
62     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
63     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
64   }
65 
66   LLVM::LLVMType getVoidType() { return llvmVoidType; }
67   LLVM::LLVMType getPointerType() { return llvmPointerType; }
68   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
69 
70   /// Creates a SPIR-V binary shader from the given `module` using
71   /// `spirv::serialize` function.
72   LogicalResult createBinaryShader(ModuleOp module,
73                                    std::vector<char> &binaryShader);
74 
75   /// Creates a LLVM global for the given `name`.
76   Value createEntryPointNameConstant(StringRef name, Location loc,
77                                      OpBuilder &builder);
78 
79   /// Creates a LLVM constant for each dimension of local workgroup and
80   /// populates the given `numWorkGroups`.
81   LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
82                                     mlir::gpu::LaunchFuncOp launchOp,
83                                     SmallVector<Value, 3> &numWorkGroups);
84 
85   /// Declares all needed runtime functions.
86   void declareVulkanFunctions(Location loc);
87 
88   /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls
89   void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
90 
91 public:
92   void runOnModule() override;
93 
94 private:
95   LLVM::LLVMDialect *llvmDialect;
96   LLVM::LLVMType llvmVoidType;
97   LLVM::LLVMType llvmPointerType;
98   LLVM::LLVMType llvmInt32Type;
99 };
100 
101 } // anonymous namespace
102 
103 void GpuLaunchFuncToVulkanCalssPass::runOnModule() {
104   initializeCachedTypes();
105 
106   getModule().walk(
107       [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
108 
109   // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
110   for (auto gpuModule :
111        llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
112     gpuModule.erase();
113 
114   for (auto spirvModule :
115        llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
116     spirvModule.erase();
117 }
118 
119 void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
120   ModuleOp module = getModule();
121   OpBuilder builder(module.getBody()->getTerminator());
122 
123   if (!module.lookupSymbol(kSetEntryPoint)) {
124     builder.create<LLVM::LLVMFuncOp>(
125         loc, kSetEntryPoint,
126         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
127                                       /*isVarArg=*/false));
128   }
129 
130   if (!module.lookupSymbol(kSetNumWorkGroups)) {
131     builder.create<LLVM::LLVMFuncOp>(
132         loc, kSetNumWorkGroups,
133         LLVM::LLVMType::getFunctionTy(
134             getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()},
135             /*isVarArg=*/false));
136   }
137 
138   if (!module.lookupSymbol(kSetBinaryShader)) {
139     builder.create<LLVM::LLVMFuncOp>(
140         loc, kSetBinaryShader,
141         LLVM::LLVMType::getFunctionTy(getVoidType(),
142                                       {getPointerType(), getInt32Type()},
143                                       /*isVarArg=*/false));
144   }
145 
146   if (!module.lookupSymbol(kRunOnVulkan)) {
147     builder.create<LLVM::LLVMFuncOp>(
148         loc, kRunOnVulkan,
149         LLVM::LLVMType::getFunctionTy(getVoidType(), {},
150                                       /*isVarArg=*/false));
151   }
152 }
153 
154 Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
155     StringRef name, Location loc, OpBuilder &builder) {
156   std::vector<char> shaderName(name.begin(), name.end());
157   // Append `\0` to follow C style string given that LLVM::createGlobalString()
158   // won't handle this directly for us.
159   shaderName.push_back('\0');
160 
161   std::string entryPointGlobalName =
162       std::string(llvm::formatv("{0}_spv_entry_point_name", name));
163   return LLVM::createGlobalString(
164       loc, builder, entryPointGlobalName,
165       StringRef(shaderName.data(), shaderName.size()), LLVM::Linkage::Internal,
166       getLLVMDialect());
167 }
168 
169 LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
170     ModuleOp module, std::vector<char> &binaryShader) {
171   bool done = false;
172   SmallVector<uint32_t, 0> binary;
173   for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
174     if (done) {
175       spirvModule.emitError("should only contain one 'spv.module' op");
176       return failure();
177     }
178     done = true;
179     if (failed(spirv::serialize(spirvModule, binary))) {
180       return failure();
181     }
182   }
183 
184   binaryShader.resize(binary.size() * sizeof(uint32_t));
185   std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
186               binaryShader.size());
187   return success();
188 }
189 
190 LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
191     Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
192     SmallVector<Value, 3> &numWorkGroups) {
193   for (auto index : llvm::seq(0, 3)) {
194     auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
195         launchOp.getOperand(index).getDefiningOp());
196 
197     if (!numWorkGroupDimConstant) {
198       return failure();
199     }
200 
201     auto numWorkGroupDimValue =
202         numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
203     numWorkGroups.push_back(builder.create<LLVM::ConstantOp>(
204         loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue)));
205   }
206 
207   return success();
208 }
209 
210 // Translates gpu launch op to the sequence of Vulkan runtime calls.
211 void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
212     mlir::gpu::LaunchFuncOp launchOp) {
213   ModuleOp module = getModule();
214   OpBuilder builder(launchOp);
215   Location loc = launchOp.getLoc();
216 
217   // Serialize `spirv::Module` into binary form.
218   std::vector<char> binary;
219   if (failed(
220           GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary))) {
221     return signalPassFailure();
222   }
223 
224   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
225   // that data to runtime call.
226   Value ptrToSPIRVBinary = LLVM::createGlobalString(
227       loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()),
228       LLVM::Linkage::Internal, getLLVMDialect());
229   // Create LLVM constant for the size of SPIR-V binary shader.
230   Value binarySize = builder.create<LLVM::ConstantOp>(
231       loc, getInt32Type(), builder.getI32IntegerAttr(binary.size()));
232   // Create call to `setBinaryShader` runtime function with the given pointer to
233   // SPIR-V binary and binary size.
234   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
235                                builder.getSymbolRefAttr(kSetBinaryShader),
236                                ArrayRef<Value>{ptrToSPIRVBinary, binarySize});
237 
238   // Create LLVM global with entry point name.
239   Value entryPointName =
240       createEntryPointNameConstant(launchOp.kernel(), loc, builder);
241   // Create call to `setEntryPoint` runtime function with the given pointer to
242   // entry point name.
243   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
244                                builder.getSymbolRefAttr(kSetEntryPoint),
245                                ArrayRef<Value>{entryPointName});
246 
247   // Create number of local workgroup for each dimension.
248   SmallVector<Value, 3> numWorkGroups;
249   if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups))) {
250     return signalPassFailure();
251   }
252 
253   // Create call `setNumWorkGroups` runtime function with the given numbers of
254   // local workgroup.
255   builder.create<LLVM::CallOp>(
256       loc, ArrayRef<Type>{getVoidType()},
257       builder.getSymbolRefAttr(kSetNumWorkGroups),
258       ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]});
259 
260   // Create call to `runOnVulkan` runtime function.
261   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
262                                builder.getSymbolRefAttr(kRunOnVulkan),
263                                ArrayRef<Value>{});
264 
265   // Declare runtime functions.
266   declareVulkanFunctions(loc);
267 
268   launchOp.erase();
269 }
270 
/// Creates a new instance of the pass that converts gpu.launch_func ops into
/// Vulkan runtime calls. The caller owns the returned pass.
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
mlir::createConvertGpuLaunchFuncToVulkanCallsPass() {
  return std::make_unique<GpuLaunchFuncToVulkanCalssPass>();
}
275 
// Registers the pass under the `launch-func-to-vulkan` argument together with
// its description.
static PassRegistration<GpuLaunchFuncToVulkanCalssPass>
    pass("launch-func-to-vulkan",
         "Convert gpu.launch_func op to Vulkan runtime calls");
279