1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/Pass/Pass.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 
// Names of the external Vulkan runtime wrapper functions that the generated
// code calls.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";

// Callee names that identify the launch calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attributes on the `vulkanLaunch` call that carry the SPIR-V module.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
46 
47 namespace {
48 
49 /// A pass to convert vulkan launch call op into a sequence of Vulkan
50 /// runtime calls in the following order:
51 ///
52 /// * initVulkan           -- initializes vulkan runtime
53 /// * bindMemRef           -- binds memref
54 /// * setBinaryShader      -- sets the binary shader data
55 /// * setEntryPoint        -- sets the entry point name
56 /// * setNumWorkGroups     -- sets the number of a local workgroups
57 /// * runOnVulkan          -- runs vulkan runtime
58 /// * deinitVulkan         -- deinitializes vulkan runtime
59 ///
60 class VulkanLaunchFuncToVulkanCallsPass
61     : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
62 private:
63 /// Include the generated pass utilities.
64 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
65 #include "mlir/Conversion/Passes.h.inc"
66 
67   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
68 
69   llvm::LLVMContext &getLLVMContext() {
70     return getLLVMDialect()->getLLVMContext();
71   }
72 
73   void initializeCachedTypes() {
74     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
75     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
76     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
77     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
78     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
79     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
80     llvmMemRef1DFloat = getMemRefType(1);
81     llvmMemRef2DFloat = getMemRefType(2);
82     llvmMemRef3DFloat = getMemRefType(3);
83   }
84 
85   LLVM::LLVMType getMemRefType(uint32_t rank) {
86     // According to the MLIR doc memref argument is converted into a
87     // pointer-to-struct argument of type:
88     // template <typename Elem, size_t Rank>
89     // struct {
90     //   Elem *allocated;
91     //   Elem *aligned;
92     //   int64_t offset;
93     //   int64_t sizes[Rank]; // omitted when rank == 0
94     //   int64_t strides[Rank]; // omitted when rank == 0
95     // };
96     auto llvmPtrToFloatType = getFloatType().getPointerTo();
97     auto llvmArrayRankElementSizeType =
98         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
99 
100     // Create a type
101     // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
102     return LLVM::LLVMType::getStructTy(
103         llvmDialect,
104         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
105          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
106   }
107 
108   LLVM::LLVMType getFloatType() { return llvmFloatType; }
109   LLVM::LLVMType getVoidType() { return llvmVoidType; }
110   LLVM::LLVMType getPointerType() { return llvmPointerType; }
111   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
112   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
113   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
114   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
115   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
116 
117   /// Creates a LLVM global for the given `name`.
118   Value createEntryPointNameConstant(StringRef name, Location loc,
119                                      OpBuilder &builder);
120 
121   /// Declares all needed runtime functions.
122   void declareVulkanFunctions(Location loc);
123 
124   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
125   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
126     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
127             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
128   }
129 
130   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
131   /// op.
132   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
133     return (callOp.callee() &&
134             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
135             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
136   }
137 
138   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
139   /// runtime calls.
140   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
141 
142   /// Creates call to `bindMemRef` for each memref operand.
143   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
144                              Value vulkanRuntime);
145 
146   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
147   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
148 
149   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
150   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
151 
152 public:
153   void runOnModule() override;
154 
155 private:
156   LLVM::LLVMDialect *llvmDialect;
157   LLVM::LLVMType llvmFloatType;
158   LLVM::LLVMType llvmVoidType;
159   LLVM::LLVMType llvmPointerType;
160   LLVM::LLVMType llvmInt32Type;
161   LLVM::LLVMType llvmInt64Type;
162   LLVM::LLVMType llvmMemRef1DFloat;
163   LLVM::LLVMType llvmMemRef2DFloat;
164   LLVM::LLVMType llvmMemRef3DFloat;
165 
166   // TODO: Use an associative array to support multiple vulkan launch calls.
167   std::pair<StringAttr, StringAttr> spirvAttributes;
168 };
169 
170 } // anonymous namespace
171 
172 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
173   initializeCachedTypes();
174 
175   // Collect SPIR-V attributes such as `spirv_blob` and
176   // `spirv_entry_point_name`.
177   getModule().walk([this](LLVM::CallOp op) {
178     if (isVulkanLaunchCallOp(op))
179       collectSPIRVAttributes(op);
180   });
181 
182   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
183   getModule().walk([this](LLVM::CallOp op) {
184     if (isCInterfaceVulkanLaunchCallOp(op))
185       translateVulkanLaunchCall(op);
186   });
187 }
188 
189 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
190     LLVM::CallOp vulkanLaunchCallOp) {
191   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
192   // for the given vulkan launch call.
193   auto spirvBlobAttr =
194       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
195   if (!spirvBlobAttr) {
196     vulkanLaunchCallOp.emitError()
197         << "missing " << kSPIRVBlobAttrName << " attribute";
198     return signalPassFailure();
199   }
200 
201   auto spirvEntryPointNameAttr =
202       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
203   if (!spirvEntryPointNameAttr) {
204     vulkanLaunchCallOp.emitError()
205         << "missing " << kSPIRVEntryPointAttrName << " attribute";
206     return signalPassFailure();
207   }
208 
209   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
210 }
211 
212 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
213     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
214   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
215       gpu::LaunchOp::kNumConfigOperands)
216     return;
217   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
218   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
219 
220   // Create LLVM constant for the descriptor set index.
221   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
222   // pass does.
223   Value descriptorSet = builder.create<LLVM::ConstantOp>(
224       loc, getInt32Type(), builder.getI32IntegerAttr(0));
225 
226   for (auto en :
227        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
228            gpu::LaunchOp::kNumConfigOperands))) {
229     // Create LLVM constant for the descriptor binding index.
230     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
231         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
232 
233     auto ptrToMemRefDescriptor = en.value();
234     uint32_t rank = 0;
235     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
236       cInterfaceVulkanLaunchCallOp.emitError()
237           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
238       return signalPassFailure();
239     }
240 
241     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
242     // Create call to `bindMemRef`.
243     builder.create<LLVM::CallOp>(
244         loc, ArrayRef<Type>{getVoidType()},
245         builder.getSymbolRefAttr(
246             StringRef(symbolName.data(), symbolName.size())),
247         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
248                         ptrToMemRefDescriptor});
249   }
250 }
251 
252 LogicalResult
253 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
254                                                     uint32_t &rank) {
255   auto llvmPtrDescriptorTy =
256       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
257   if (!llvmPtrDescriptorTy)
258     return failure();
259 
260   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
261   // template <typename Elem, size_t Rank>
262   // struct {
263   //   Elem *allocated;
264   //   Elem *aligned;
265   //   int64_t offset;
266   //   int64_t sizes[Rank]; // omitted when rank == 0
267   //   int64_t strides[Rank]; // omitted when rank == 0
268   // };
269   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
270     return failure();
271   if (llvmDescriptorTy.getStructNumElements() == 3) {
272     rank = 0;
273     return success();
274   }
275 
276   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
277   return success();
278 }
279 
280 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
281   ModuleOp module = getModule();
282   OpBuilder builder(module.getBody()->getTerminator());
283 
284   if (!module.lookupSymbol(kSetEntryPoint)) {
285     builder.create<LLVM::LLVMFuncOp>(
286         loc, kSetEntryPoint,
287         LLVM::LLVMType::getFunctionTy(getVoidType(),
288                                       {getPointerType(), getPointerType()},
289                                       /*isVarArg=*/false));
290   }
291 
292   if (!module.lookupSymbol(kSetNumWorkGroups)) {
293     builder.create<LLVM::LLVMFuncOp>(
294         loc, kSetNumWorkGroups,
295         LLVM::LLVMType::getFunctionTy(
296             getVoidType(),
297             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
298             /*isVarArg=*/false));
299   }
300 
301   if (!module.lookupSymbol(kSetBinaryShader)) {
302     builder.create<LLVM::LLVMFuncOp>(
303         loc, kSetBinaryShader,
304         LLVM::LLVMType::getFunctionTy(
305             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
306             /*isVarArg=*/false));
307   }
308 
309   if (!module.lookupSymbol(kRunOnVulkan)) {
310     builder.create<LLVM::LLVMFuncOp>(
311         loc, kRunOnVulkan,
312         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
313                                       /*isVarArg=*/false));
314   }
315 
316   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
317     builder.create<LLVM::LLVMFuncOp>(
318         loc, kBindMemRef1DFloat,
319         LLVM::LLVMType::getFunctionTy(getVoidType(),
320                                       {getPointerType(), getInt32Type(),
321                                        getInt32Type(),
322                                        getMemRef1DFloat().getPointerTo()},
323                                       /*isVarArg=*/false));
324   }
325 
326   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
327     builder.create<LLVM::LLVMFuncOp>(
328         loc, kBindMemRef2DFloat,
329         LLVM::LLVMType::getFunctionTy(getVoidType(),
330                                       {getPointerType(), getInt32Type(),
331                                        getInt32Type(),
332                                        getMemRef2DFloat().getPointerTo()},
333                                       /*isVarArg=*/false));
334   }
335 
336   if (!module.lookupSymbol(kBindMemRef3DFloat)) {
337     builder.create<LLVM::LLVMFuncOp>(
338         loc, kBindMemRef3DFloat,
339         LLVM::LLVMType::getFunctionTy(getVoidType(),
340                                       {getPointerType(), getInt32Type(),
341                                        getInt32Type(),
342                                        getMemRef3DFloat().getPointerTo()},
343                                       /*isVarArg=*/false));
344   }
345 
346   if (!module.lookupSymbol(kInitVulkan)) {
347     builder.create<LLVM::LLVMFuncOp>(
348         loc, kInitVulkan,
349         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
350                                       /*isVarArg=*/false));
351   }
352 
353   if (!module.lookupSymbol(kDeinitVulkan)) {
354     builder.create<LLVM::LLVMFuncOp>(
355         loc, kDeinitVulkan,
356         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
357                                       /*isVarArg=*/false));
358   }
359 }
360 
361 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
362     StringRef name, Location loc, OpBuilder &builder) {
363   SmallString<16> shaderName(name.begin(), name.end());
364   // Append `\0` to follow C style string given that LLVM::createGlobalString()
365   // won't handle this directly for us.
366   shaderName.push_back('\0');
367 
368   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
369   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
370                                   shaderName, LLVM::Linkage::Internal,
371                                   getLLVMDialect());
372 }
373 
374 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
375     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
376   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
377   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
378   // Create call to `initVulkan`.
379   auto initVulkanCall = builder.create<LLVM::CallOp>(
380       loc, ArrayRef<Type>{getPointerType()},
381       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
382   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
383   // need to pass that pointer to each Vulkan runtime call.
384   auto vulkanRuntime = initVulkanCall.getResult(0);
385 
386   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
387   // that data to runtime call.
388   Value ptrToSPIRVBinary = LLVM::createGlobalString(
389       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
390       LLVM::Linkage::Internal, getLLVMDialect());
391 
392   // Create LLVM constant for the size of SPIR-V binary shader.
393   Value binarySize = builder.create<LLVM::ConstantOp>(
394       loc, getInt32Type(),
395       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
396 
397   // Create call to `bindMemRef` for each memref operand.
398   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
399 
400   // Create call to `setBinaryShader` runtime function with the given pointer to
401   // SPIR-V binary and binary size.
402   builder.create<LLVM::CallOp>(
403       loc, ArrayRef<Type>{getVoidType()},
404       builder.getSymbolRefAttr(kSetBinaryShader),
405       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
406   // Create LLVM global with entry point name.
407   Value entryPointName = createEntryPointNameConstant(
408       spirvAttributes.second.getValue(), loc, builder);
409   // Create call to `setEntryPoint` runtime function with the given pointer to
410   // entry point name.
411   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
412                                builder.getSymbolRefAttr(kSetEntryPoint),
413                                ArrayRef<Value>{vulkanRuntime, entryPointName});
414 
415   // Create number of local workgroup for each dimension.
416   builder.create<LLVM::CallOp>(
417       loc, ArrayRef<Type>{getVoidType()},
418       builder.getSymbolRefAttr(kSetNumWorkGroups),
419       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
420                       cInterfaceVulkanLaunchCallOp.getOperand(1),
421                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
422 
423   // Create call to `runOnVulkan` runtime function.
424   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
425                                builder.getSymbolRefAttr(kRunOnVulkan),
426                                ArrayRef<Value>{vulkanRuntime});
427 
428   // Create call to 'deinitVulkan' runtime function.
429   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
430                                builder.getSymbolRefAttr(kDeinitVulkan),
431                                ArrayRef<Value>{vulkanRuntime});
432 
433   // Declare runtime functions.
434   declareVulkanFunctions(loc);
435 
436   cInterfaceVulkanLaunchCallOp.erase();
437 }
438 
439 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
440 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
441   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
442 }
443