1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead,
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/Pass/Pass.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 
// Names of the external Vulkan runtime wrapper functions that this pass
// declares and calls.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";

// Callee names used to recognize launch calls: `vulkanLaunch` calls carry the
// SPIR-V attributes, and `_mlir_ciface_vulkanLaunch` calls are the ones that
// get rewritten into runtime calls.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attribute names read from a `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global that stores the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
45 
46 namespace {
47 
48 /// A pass to convert vulkan launch call op into a sequence of Vulkan
49 /// runtime calls in the following order:
50 ///
51 /// * initVulkan           -- initializes vulkan runtime
52 /// * bindMemRef           -- binds memref
53 /// * setBinaryShader      -- sets the binary shader data
54 /// * setEntryPoint        -- sets the entry point name
55 /// * setNumWorkGroups     -- sets the number of a local workgroups
56 /// * runOnVulkan          -- runs vulkan runtime
57 /// * deinitVulkan         -- deinitializes vulkan runtime
58 ///
59 class VulkanLaunchFuncToVulkanCallsPass
60     : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
61 private:
62 /// Include the generated pass utilities.
63 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
64 #include "mlir/Conversion/Passes.h.inc"
65 
66   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
67 
68   llvm::LLVMContext &getLLVMContext() {
69     return getLLVMDialect()->getLLVMContext();
70   }
71 
72   void initializeCachedTypes() {
73     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
74     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
75     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
76     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
77     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
78     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
79     initializeMemRefTypes();
80   }
81 
82   void initializeMemRefTypes() {
83     // According to the MLIR doc memref argument is converted into a
84     // pointer-to-struct argument of type:
85     // template <typename Elem, size_t Rank>
86     // struct {
87     //   Elem *allocated;
88     //   Elem *aligned;
89     //   int64_t offset;
90     //   int64_t sizes[Rank]; // omitted when rank == 0
91     //   int64_t strides[Rank]; // omitted when rank == 0
92     // };
93     auto llvmPtrToFloatType = getFloatType().getPointerTo();
94     auto llvmArrayOneElementSizeType =
95         LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
96     auto llvmArrayTwoElementSizeType =
97         LLVM::LLVMType::getArrayTy(getInt64Type(), 2);
98 
99     // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
100     llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
101         llvmDialect,
102         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
103          llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
104 
105     // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`.
106     llvmMemRef2DFloat = LLVM::LLVMType::getStructTy(
107         llvmDialect,
108         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
109          llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType});
110   }
111 
112   LLVM::LLVMType getFloatType() { return llvmFloatType; }
113   LLVM::LLVMType getVoidType() { return llvmVoidType; }
114   LLVM::LLVMType getPointerType() { return llvmPointerType; }
115   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
116   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
117   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
118   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
119 
120   /// Creates a LLVM global for the given `name`.
121   Value createEntryPointNameConstant(StringRef name, Location loc,
122                                      OpBuilder &builder);
123 
124   /// Declares all needed runtime functions.
125   void declareVulkanFunctions(Location loc);
126 
127   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
128   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
129     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
130             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
131   }
132 
133   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
134   /// op.
135   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
136     return (callOp.callee() &&
137             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
138             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
139   }
140 
141   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
142   /// runtime calls.
143   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
144 
145   /// Creates call to `bindMemRef` for each memref operand.
146   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
147                              Value vulkanRuntime);
148 
149   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
150   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
151 
152   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
153   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
154 
155 public:
156   void runOnModule() override;
157 
158 private:
159   LLVM::LLVMDialect *llvmDialect;
160   LLVM::LLVMType llvmFloatType;
161   LLVM::LLVMType llvmVoidType;
162   LLVM::LLVMType llvmPointerType;
163   LLVM::LLVMType llvmInt32Type;
164   LLVM::LLVMType llvmInt64Type;
165   LLVM::LLVMType llvmMemRef1DFloat;
166   LLVM::LLVMType llvmMemRef2DFloat;
167 
168   // TODO: Use an associative array to support multiple vulkan launch calls.
169   std::pair<StringAttr, StringAttr> spirvAttributes;
170 };
171 
172 } // anonymous namespace
173 
174 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
175   initializeCachedTypes();
176 
177   // Collect SPIR-V attributes such as `spirv_blob` and
178   // `spirv_entry_point_name`.
179   getModule().walk([this](LLVM::CallOp op) {
180     if (isVulkanLaunchCallOp(op))
181       collectSPIRVAttributes(op);
182   });
183 
184   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
185   getModule().walk([this](LLVM::CallOp op) {
186     if (isCInterfaceVulkanLaunchCallOp(op))
187       translateVulkanLaunchCall(op);
188   });
189 }
190 
191 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
192     LLVM::CallOp vulkanLaunchCallOp) {
193   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
194   // for the given vulkan launch call.
195   auto spirvBlobAttr =
196       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
197   if (!spirvBlobAttr) {
198     vulkanLaunchCallOp.emitError()
199         << "missing " << kSPIRVBlobAttrName << " attribute";
200     return signalPassFailure();
201   }
202 
203   auto spirvEntryPointNameAttr =
204       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
205   if (!spirvEntryPointNameAttr) {
206     vulkanLaunchCallOp.emitError()
207         << "missing " << kSPIRVEntryPointAttrName << " attribute";
208     return signalPassFailure();
209   }
210 
211   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
212 }
213 
214 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
215     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
216   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
217       gpu::LaunchOp::kNumConfigOperands)
218     return;
219   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
220   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
221 
222   // Create LLVM constant for the descriptor set index.
223   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
224   // pass does.
225   Value descriptorSet = builder.create<LLVM::ConstantOp>(
226       loc, getInt32Type(), builder.getI32IntegerAttr(0));
227 
228   for (auto en :
229        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
230            gpu::LaunchOp::kNumConfigOperands))) {
231     // Create LLVM constant for the descriptor binding index.
232     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
233         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
234 
235     auto ptrToMemRefDescriptor = en.value();
236     uint32_t rank = 0;
237     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
238       cInterfaceVulkanLaunchCallOp.emitError()
239           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
240       return signalPassFailure();
241     }
242 
243     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
244     // Create call to `bindMemRef`.
245     builder.create<LLVM::CallOp>(
246         loc, ArrayRef<Type>{getVoidType()},
247         builder.getSymbolRefAttr(
248             StringRef(symbolName.data(), symbolName.size())),
249         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
250                         ptrToMemRefDescriptor});
251   }
252 }
253 
254 LogicalResult
255 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
256                                                     uint32_t &rank) {
257   auto llvmPtrDescriptorTy =
258       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
259   if (!llvmPtrDescriptorTy)
260     return failure();
261 
262   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
263   // template <typename Elem, size_t Rank>
264   // struct {
265   //   Elem *allocated;
266   //   Elem *aligned;
267   //   int64_t offset;
268   //   int64_t sizes[Rank]; // omitted when rank == 0
269   //   int64_t strides[Rank]; // omitted when rank == 0
270   // };
271   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
272     return failure();
273   if (llvmDescriptorTy.getStructNumElements() == 3) {
274     rank = 0;
275     return success();
276   }
277 
278   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
279   return success();
280 }
281 
282 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
283   ModuleOp module = getModule();
284   OpBuilder builder(module.getBody()->getTerminator());
285 
286   if (!module.lookupSymbol(kSetEntryPoint)) {
287     builder.create<LLVM::LLVMFuncOp>(
288         loc, kSetEntryPoint,
289         LLVM::LLVMType::getFunctionTy(getVoidType(),
290                                       {getPointerType(), getPointerType()},
291                                       /*isVarArg=*/false));
292   }
293 
294   if (!module.lookupSymbol(kSetNumWorkGroups)) {
295     builder.create<LLVM::LLVMFuncOp>(
296         loc, kSetNumWorkGroups,
297         LLVM::LLVMType::getFunctionTy(
298             getVoidType(),
299             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
300             /*isVarArg=*/false));
301   }
302 
303   if (!module.lookupSymbol(kSetBinaryShader)) {
304     builder.create<LLVM::LLVMFuncOp>(
305         loc, kSetBinaryShader,
306         LLVM::LLVMType::getFunctionTy(
307             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
308             /*isVarArg=*/false));
309   }
310 
311   if (!module.lookupSymbol(kRunOnVulkan)) {
312     builder.create<LLVM::LLVMFuncOp>(
313         loc, kRunOnVulkan,
314         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
315                                       /*isVarArg=*/false));
316   }
317 
318   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
319     builder.create<LLVM::LLVMFuncOp>(
320         loc, kBindMemRef1DFloat,
321         LLVM::LLVMType::getFunctionTy(getVoidType(),
322                                       {getPointerType(), getInt32Type(),
323                                        getInt32Type(),
324                                        getMemRef1DFloat().getPointerTo()},
325                                       /*isVarArg=*/false));
326   }
327 
328   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
329     builder.create<LLVM::LLVMFuncOp>(
330         loc, kBindMemRef2DFloat,
331         LLVM::LLVMType::getFunctionTy(getVoidType(),
332                                       {getPointerType(), getInt32Type(),
333                                        getInt32Type(),
334                                        getMemRef2DFloat().getPointerTo()},
335                                       /*isVarArg=*/false));
336   }
337 
338   if (!module.lookupSymbol(kInitVulkan)) {
339     builder.create<LLVM::LLVMFuncOp>(
340         loc, kInitVulkan,
341         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
342                                       /*isVarArg=*/false));
343   }
344 
345   if (!module.lookupSymbol(kDeinitVulkan)) {
346     builder.create<LLVM::LLVMFuncOp>(
347         loc, kDeinitVulkan,
348         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
349                                       /*isVarArg=*/false));
350   }
351 }
352 
353 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
354     StringRef name, Location loc, OpBuilder &builder) {
355   SmallString<16> shaderName(name.begin(), name.end());
356   // Append `\0` to follow C style string given that LLVM::createGlobalString()
357   // won't handle this directly for us.
358   shaderName.push_back('\0');
359 
360   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
361   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
362                                   shaderName, LLVM::Linkage::Internal,
363                                   getLLVMDialect());
364 }
365 
366 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
367     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
368   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
369   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
370   // Create call to `initVulkan`.
371   auto initVulkanCall = builder.create<LLVM::CallOp>(
372       loc, ArrayRef<Type>{getPointerType()},
373       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
374   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
375   // need to pass that pointer to each Vulkan runtime call.
376   auto vulkanRuntime = initVulkanCall.getResult(0);
377 
378   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
379   // that data to runtime call.
380   Value ptrToSPIRVBinary = LLVM::createGlobalString(
381       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
382       LLVM::Linkage::Internal, getLLVMDialect());
383 
384   // Create LLVM constant for the size of SPIR-V binary shader.
385   Value binarySize = builder.create<LLVM::ConstantOp>(
386       loc, getInt32Type(),
387       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
388 
389   // Create call to `bindMemRef` for each memref operand.
390   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
391 
392   // Create call to `setBinaryShader` runtime function with the given pointer to
393   // SPIR-V binary and binary size.
394   builder.create<LLVM::CallOp>(
395       loc, ArrayRef<Type>{getVoidType()},
396       builder.getSymbolRefAttr(kSetBinaryShader),
397       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
398   // Create LLVM global with entry point name.
399   Value entryPointName = createEntryPointNameConstant(
400       spirvAttributes.second.getValue(), loc, builder);
401   // Create call to `setEntryPoint` runtime function with the given pointer to
402   // entry point name.
403   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
404                                builder.getSymbolRefAttr(kSetEntryPoint),
405                                ArrayRef<Value>{vulkanRuntime, entryPointName});
406 
407   // Create number of local workgroup for each dimension.
408   builder.create<LLVM::CallOp>(
409       loc, ArrayRef<Type>{getVoidType()},
410       builder.getSymbolRefAttr(kSetNumWorkGroups),
411       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
412                       cInterfaceVulkanLaunchCallOp.getOperand(1),
413                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
414 
415   // Create call to `runOnVulkan` runtime function.
416   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
417                                builder.getSymbolRefAttr(kRunOnVulkan),
418                                ArrayRef<Value>{vulkanRuntime});
419 
420   // Create call to 'deinitVulkan' runtime function.
421   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
422                                builder.getSymbolRefAttr(kDeinitVulkan),
423                                ArrayRef<Value>{vulkanRuntime});
424 
425   // Declare runtime functions.
426   declareVulkanFunctions(loc);
427 
428   cInterfaceVulkanLaunchCallOp.erase();
429 }
430 
431 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
432 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
433   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
434 }
435