1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass that converts gpu.launch_func ops into a
// sequence of Vulkan runtime calls. The Vulkan runtime API surface is huge, so
// rather than exposing a separate external function in IR for each Vulkan API,
// we expose a few external functions implemented by wrapper libraries that
// manage the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
21 #include "mlir/Dialect/SPIRV/Serialization.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/StandardTypes.h"
28 #include "mlir/Pass/Pass.h"
29 
30 #include "llvm/ADT/SmallString.h"
31 
32 using namespace mlir;
33 
// Names of the external runtime entry points the generated code calls, and the
// symbol name used for the global holding the serialized SPIR-V module.
static constexpr char kSetBinaryShader[] = "setBinaryShader";
static constexpr char kSetEntryPoint[] = "setEntryPoint";
static constexpr char kSetNumWorkGroups[] = "setNumWorkGroups";
static constexpr char kRunOnVulkan[] = "runOnVulkan";
static constexpr char kSPIRVBinary[] = "SPIRV_BIN";
39 
40 namespace {
41 
42 /// A pass to convert gpu.launch_func operation into a sequence of Vulkan
43 /// runtime calls.
44 ///
45 /// * setBinaryShader      -- sets the binary shader data
46 /// * setEntryPoint        -- sets the entry point name
47 /// * setNumWorkGroups     -- sets the number of a local workgroups
48 /// * runOnVulkan          -- runs vulkan runtime
49 ///
50 class GpuLaunchFuncToVulkanCalssPass
51     : public ModulePass<GpuLaunchFuncToVulkanCalssPass> {
52 private:
53   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
54 
55   llvm::LLVMContext &getLLVMContext() {
56     return getLLVMDialect()->getLLVMContext();
57   }
58 
59   void initializeCachedTypes() {
60     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
61     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
62     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
63     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
64   }
65 
66   LLVM::LLVMType getVoidType() { return llvmVoidType; }
67   LLVM::LLVMType getPointerType() { return llvmPointerType; }
68   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
69 
70   /// Creates a SPIR-V binary shader from the given `module` using
71   /// `spirv::serialize` function.
72   LogicalResult createBinaryShader(ModuleOp module,
73                                    std::vector<char> &binaryShader);
74 
75   /// Creates a LLVM global for the given `name`.
76   Value createEntryPointNameConstant(StringRef name, Location loc,
77                                      OpBuilder &builder);
78 
79   /// Creates a LLVM constant for each dimension of local workgroup and
80   /// populates the given `numWorkGroups`.
81   LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
82                                     mlir::gpu::LaunchFuncOp launchOp,
83                                     SmallVectorImpl<Value> &numWorkGroups);
84 
85   /// Declares all needed runtime functions.
86   void declareVulkanFunctions(Location loc);
87 
88   /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls
89   void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
90 
91 public:
92   void runOnModule() override;
93 
94 private:
95   LLVM::LLVMDialect *llvmDialect;
96   LLVM::LLVMType llvmVoidType;
97   LLVM::LLVMType llvmPointerType;
98   LLVM::LLVMType llvmInt32Type;
99 };
100 
101 } // anonymous namespace
102 
103 void GpuLaunchFuncToVulkanCalssPass::runOnModule() {
104   initializeCachedTypes();
105 
106   getModule().walk(
107       [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
108 
109   // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
110   for (auto gpuModule :
111        llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
112     gpuModule.erase();
113 
114   for (auto spirvModule :
115        llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
116     spirvModule.erase();
117 }
118 
119 void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
120   ModuleOp module = getModule();
121   OpBuilder builder(module.getBody()->getTerminator());
122 
123   if (!module.lookupSymbol(kSetEntryPoint)) {
124     builder.create<LLVM::LLVMFuncOp>(
125         loc, kSetEntryPoint,
126         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
127                                       /*isVarArg=*/false));
128   }
129 
130   if (!module.lookupSymbol(kSetNumWorkGroups)) {
131     builder.create<LLVM::LLVMFuncOp>(
132         loc, kSetNumWorkGroups,
133         LLVM::LLVMType::getFunctionTy(
134             getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()},
135             /*isVarArg=*/false));
136   }
137 
138   if (!module.lookupSymbol(kSetBinaryShader)) {
139     builder.create<LLVM::LLVMFuncOp>(
140         loc, kSetBinaryShader,
141         LLVM::LLVMType::getFunctionTy(getVoidType(),
142                                       {getPointerType(), getInt32Type()},
143                                       /*isVarArg=*/false));
144   }
145 
146   if (!module.lookupSymbol(kRunOnVulkan)) {
147     builder.create<LLVM::LLVMFuncOp>(
148         loc, kRunOnVulkan,
149         LLVM::LLVMType::getFunctionTy(getVoidType(), {},
150                                       /*isVarArg=*/false));
151   }
152 }
153 
154 Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
155     StringRef name, Location loc, OpBuilder &builder) {
156   SmallString<16> shaderName(name.begin(), name.end());
157   // Append `\0` to follow C style string given that LLVM::createGlobalString()
158   // won't handle this directly for us.
159   shaderName.push_back('\0');
160 
161   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
162   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
163                                   shaderName, LLVM::Linkage::Internal,
164                                   getLLVMDialect());
165 }
166 
167 LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
168     ModuleOp module, std::vector<char> &binaryShader) {
169   bool done = false;
170   SmallVector<uint32_t, 0> binary;
171   for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
172     if (done)
173       return spirvModule.emitError("should only contain one 'spv.module' op");
174     done = true;
175 
176     if (failed(spirv::serialize(spirvModule, binary)))
177       return failure();
178   }
179 
180   binaryShader.resize(binary.size() * sizeof(uint32_t));
181   std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
182               binaryShader.size());
183   return success();
184 }
185 
186 LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
187     Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
188     SmallVectorImpl<Value> &numWorkGroups) {
189   for (auto index : llvm::seq(0, 3)) {
190     auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
191         launchOp.getOperand(index).getDefiningOp());
192 
193     if (!numWorkGroupDimConstant)
194       return failure();
195 
196     auto numWorkGroupDimValue =
197         numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
198     numWorkGroups.push_back(builder.create<LLVM::ConstantOp>(
199         loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue)));
200   }
201 
202   return success();
203 }
204 
205 void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
206     mlir::gpu::LaunchFuncOp launchOp) {
207   ModuleOp module = getModule();
208   OpBuilder builder(launchOp);
209   Location loc = launchOp.getLoc();
210 
211   // Serialize `spirv::Module` into binary form.
212   std::vector<char> binary;
213   if (failed(
214           GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary)))
215     return signalPassFailure();
216 
217   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
218   // that data to runtime call.
219   Value ptrToSPIRVBinary = LLVM::createGlobalString(
220       loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()),
221       LLVM::Linkage::Internal, getLLVMDialect());
222   // Create LLVM constant for the size of SPIR-V binary shader.
223   Value binarySize = builder.create<LLVM::ConstantOp>(
224       loc, getInt32Type(), builder.getI32IntegerAttr(binary.size()));
225   // Create call to `setBinaryShader` runtime function with the given pointer to
226   // SPIR-V binary and binary size.
227   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
228                                builder.getSymbolRefAttr(kSetBinaryShader),
229                                ArrayRef<Value>{ptrToSPIRVBinary, binarySize});
230 
231   // Create LLVM global with entry point name.
232   Value entryPointName =
233       createEntryPointNameConstant(launchOp.kernel(), loc, builder);
234   // Create call to `setEntryPoint` runtime function with the given pointer to
235   // entry point name.
236   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
237                                builder.getSymbolRefAttr(kSetEntryPoint),
238                                ArrayRef<Value>{entryPointName});
239 
240   // Create number of local workgroup for each dimension.
241   SmallVector<Value, 3> numWorkGroups;
242   if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups)))
243     return signalPassFailure();
244 
245   // Create call `setNumWorkGroups` runtime function with the given numbers of
246   // local workgroup.
247   builder.create<LLVM::CallOp>(
248       loc, ArrayRef<Type>{getVoidType()},
249       builder.getSymbolRefAttr(kSetNumWorkGroups),
250       ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]});
251 
252   // Create call to `runOnVulkan` runtime function.
253   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
254                                builder.getSymbolRefAttr(kRunOnVulkan),
255                                ArrayRef<Value>{});
256 
257   // Declare runtime functions.
258   declareVulkanFunctions(loc);
259 
260   launchOp.erase();
261 }
262 
263 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
264 mlir::createConvertGpuLaunchFuncToVulkanCallsPass() {
265   return std::make_unique<GpuLaunchFuncToVulkanCalssPass>();
266 }
267 
268 static PassRegistration<GpuLaunchFuncToVulkanCalssPass>
269     pass("launch-func-to-vulkan",
270          "Convert gpu.launch_func op to Vulkan runtime calls");
271