1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to a wrapper library which manages the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/Pass/Pass.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 
28 using namespace mlir;
29 
// Runtime functions that the generated code calls (and that the pass declares
// as external functions when they are not already present).
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Callee names of the launch calls this pass matches.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attributes on the `vulkanLaunch` call that carry the SPIR-V data.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global holding the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
43 
44 namespace {
45 
46 /// A pass to convert vulkan launch call op into a sequence of Vulkan
47 /// runtime calls in the following order:
48 ///
49 /// * initVulkan           -- initializes vulkan runtime
50 /// * bindMemRef           -- binds memref
51 /// * setBinaryShader      -- sets the binary shader data
52 /// * setEntryPoint        -- sets the entry point name
53 /// * setNumWorkGroups     -- sets the number of a local workgroups
54 /// * runOnVulkan          -- runs vulkan runtime
55 /// * deinitVulkan         -- deinitializes vulkan runtime
56 ///
57 class VulkanLaunchFuncToVulkanCallsPass
58     : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
59 private:
60   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
61 
62   llvm::LLVMContext &getLLVMContext() {
63     return getLLVMDialect()->getLLVMContext();
64   }
65 
66   void initializeCachedTypes() {
67     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
68     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
69     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
70     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
71     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
72     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
73     initializeMemRefTypes();
74   }
75 
76   void initializeMemRefTypes() {
77     // According to the MLIR doc memref argument is converted into a
78     // pointer-to-struct argument of type:
79     // template <typename Elem, size_t Rank>
80     // struct {
81     //   Elem *allocated;
82     //   Elem *aligned;
83     //   int64_t offset;
84     //   int64_t sizes[Rank]; // omitted when rank == 0
85     //   int64_t strides[Rank]; // omitted when rank == 0
86     // };
87     auto llvmPtrToFloatType = getFloatType().getPointerTo();
88     auto llvmArrayOneElementSizeType =
89         LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
90 
91     // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
92     llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
93         llvmDialect,
94         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
95          llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
96   }
97 
98   LLVM::LLVMType getFloatType() { return llvmFloatType; }
99   LLVM::LLVMType getVoidType() { return llvmVoidType; }
100   LLVM::LLVMType getPointerType() { return llvmPointerType; }
101   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
102   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
103   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
104 
105   /// Creates a LLVM global for the given `name`.
106   Value createEntryPointNameConstant(StringRef name, Location loc,
107                                      OpBuilder &builder);
108 
109   /// Declares all needed runtime functions.
110   void declareVulkanFunctions(Location loc);
111 
112   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
113   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
114     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
115             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
116   }
117 
118   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
119   /// op.
120   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
121     return (callOp.callee() &&
122             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
123             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
124   }
125 
126   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
127   /// runtime calls.
128   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
129 
130   /// Creates call to `bindMemRef` for each memref operand.
131   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
132                              Value vulkanRuntime);
133 
134   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
135   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
136 
137 public:
138   void runOnModule() override;
139 
140 private:
141   LLVM::LLVMDialect *llvmDialect;
142   LLVM::LLVMType llvmFloatType;
143   LLVM::LLVMType llvmVoidType;
144   LLVM::LLVMType llvmPointerType;
145   LLVM::LLVMType llvmInt32Type;
146   LLVM::LLVMType llvmInt64Type;
147   LLVM::LLVMType llvmMemRef1DFloat;
148 
149   // TODO: Use an associative array to support multiple vulkan launch calls.
150   std::pair<StringAttr, StringAttr> spirvAttributes;
151 };
152 
153 } // anonymous namespace
154 
155 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
156   initializeCachedTypes();
157 
158   // Collect SPIR-V attributes such as `spirv_blob` and
159   // `spirv_entry_point_name`.
160   getModule().walk([this](LLVM::CallOp op) {
161     if (isVulkanLaunchCallOp(op))
162       collectSPIRVAttributes(op);
163   });
164 
165   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
166   getModule().walk([this](LLVM::CallOp op) {
167     if (isCInterfaceVulkanLaunchCallOp(op))
168       translateVulkanLaunchCall(op);
169   });
170 }
171 
172 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
173     LLVM::CallOp vulkanLaunchCallOp) {
174   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
175   // for the given vulkan launch call.
176   auto spirvBlobAttr =
177       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
178   if (!spirvBlobAttr) {
179     vulkanLaunchCallOp.emitError()
180         << "missing " << kSPIRVBlobAttrName << " attribute";
181     return signalPassFailure();
182   }
183 
184   auto spirvEntryPointNameAttr =
185       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
186   if (!spirvEntryPointNameAttr) {
187     vulkanLaunchCallOp.emitError()
188         << "missing " << kSPIRVEntryPointAttrName << " attribute";
189     return signalPassFailure();
190   }
191 
192   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
193 }
194 
195 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
196     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
197   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
198       gpu::LaunchOp::kNumConfigOperands)
199     return;
200   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
201   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
202 
203   // Create LLVM constant for the descriptor set index.
204   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
205   // pass does.
206   Value descriptorSet = builder.create<LLVM::ConstantOp>(
207       loc, getInt32Type(), builder.getI32IntegerAttr(0));
208 
209   for (auto en :
210        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
211            gpu::LaunchOp::kNumConfigOperands))) {
212     // Create LLVM constant for the descriptor binding index.
213     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
214         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
215     // Create call to `bindMemRef`.
216     builder.create<LLVM::CallOp>(
217         loc, ArrayRef<Type>{getVoidType()},
218         // TODO: Add support for memref with other ranks.
219         builder.getSymbolRefAttr(kBindMemRef1DFloat),
220         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
221                         en.value()});
222   }
223 }
224 
225 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
226   ModuleOp module = getModule();
227   OpBuilder builder(module.getBody()->getTerminator());
228 
229   if (!module.lookupSymbol(kSetEntryPoint)) {
230     builder.create<LLVM::LLVMFuncOp>(
231         loc, kSetEntryPoint,
232         LLVM::LLVMType::getFunctionTy(getVoidType(),
233                                       {getPointerType(), getPointerType()},
234                                       /*isVarArg=*/false));
235   }
236 
237   if (!module.lookupSymbol(kSetNumWorkGroups)) {
238     builder.create<LLVM::LLVMFuncOp>(
239         loc, kSetNumWorkGroups,
240         LLVM::LLVMType::getFunctionTy(
241             getVoidType(),
242             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
243             /*isVarArg=*/false));
244   }
245 
246   if (!module.lookupSymbol(kSetBinaryShader)) {
247     builder.create<LLVM::LLVMFuncOp>(
248         loc, kSetBinaryShader,
249         LLVM::LLVMType::getFunctionTy(
250             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
251             /*isVarArg=*/false));
252   }
253 
254   if (!module.lookupSymbol(kRunOnVulkan)) {
255     builder.create<LLVM::LLVMFuncOp>(
256         loc, kRunOnVulkan,
257         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
258                                       /*isVarArg=*/false));
259   }
260 
261   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
262     builder.create<LLVM::LLVMFuncOp>(
263         loc, kBindMemRef1DFloat,
264         LLVM::LLVMType::getFunctionTy(getVoidType(),
265                                       {getPointerType(), getInt32Type(),
266                                        getInt32Type(),
267                                        getMemRef1DFloat().getPointerTo()},
268                                       /*isVarArg=*/false));
269   }
270 
271   if (!module.lookupSymbol(kInitVulkan)) {
272     builder.create<LLVM::LLVMFuncOp>(
273         loc, kInitVulkan,
274         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
275                                       /*isVarArg=*/false));
276   }
277 
278   if (!module.lookupSymbol(kDeinitVulkan)) {
279     builder.create<LLVM::LLVMFuncOp>(
280         loc, kDeinitVulkan,
281         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
282                                       /*isVarArg=*/false));
283   }
284 }
285 
286 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
287     StringRef name, Location loc, OpBuilder &builder) {
288   SmallString<16> shaderName(name.begin(), name.end());
289   // Append `\0` to follow C style string given that LLVM::createGlobalString()
290   // won't handle this directly for us.
291   shaderName.push_back('\0');
292 
293   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
294   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
295                                   shaderName, LLVM::Linkage::Internal,
296                                   getLLVMDialect());
297 }
298 
299 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
300     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
301   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
302   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
303   // Create call to `initVulkan`.
304   auto initVulkanCall = builder.create<LLVM::CallOp>(
305       loc, ArrayRef<Type>{getPointerType()},
306       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
307   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
308   // need to pass that pointer to each Vulkan runtime call.
309   auto vulkanRuntime = initVulkanCall.getResult(0);
310 
311   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
312   // that data to runtime call.
313   Value ptrToSPIRVBinary = LLVM::createGlobalString(
314       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
315       LLVM::Linkage::Internal, getLLVMDialect());
316 
317   // Create LLVM constant for the size of SPIR-V binary shader.
318   Value binarySize = builder.create<LLVM::ConstantOp>(
319       loc, getInt32Type(),
320       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
321 
322   // Create call to `bindMemRef` for each memref operand.
323   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
324 
325   // Create call to `setBinaryShader` runtime function with the given pointer to
326   // SPIR-V binary and binary size.
327   builder.create<LLVM::CallOp>(
328       loc, ArrayRef<Type>{getVoidType()},
329       builder.getSymbolRefAttr(kSetBinaryShader),
330       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
331   // Create LLVM global with entry point name.
332   Value entryPointName = createEntryPointNameConstant(
333       spirvAttributes.second.getValue(), loc, builder);
334   // Create call to `setEntryPoint` runtime function with the given pointer to
335   // entry point name.
336   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
337                                builder.getSymbolRefAttr(kSetEntryPoint),
338                                ArrayRef<Value>{vulkanRuntime, entryPointName});
339 
340   // Create number of local workgroup for each dimension.
341   builder.create<LLVM::CallOp>(
342       loc, ArrayRef<Type>{getVoidType()},
343       builder.getSymbolRefAttr(kSetNumWorkGroups),
344       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
345                       cInterfaceVulkanLaunchCallOp.getOperand(1),
346                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
347 
348   // Create call to `runOnVulkan` runtime function.
349   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
350                                builder.getSymbolRefAttr(kRunOnVulkan),
351                                ArrayRef<Value>{vulkanRuntime});
352 
353   // Create call to 'deinitVulkan' runtime function.
354   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
355                                builder.getSymbolRefAttr(kDeinitVulkan),
356                                ArrayRef<Value>{vulkanRuntime});
357 
358   // Declare runtime functions.
359   declareVulkanFunctions(loc);
360 
361   cInterfaceVulkanLaunchCallOp.erase();
362 }
363 
364 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
365 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
366   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
367 }
368 
369 static PassRegistration<VulkanLaunchFuncToVulkanCallsPass>
370     pass("launch-func-to-vulkan",
371          "Convert vulkanLaunch external call to Vulkan runtime external calls");
372