1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert vulkan launch call into a sequence of
10 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we
11 // don't expose separate external functions in IR for each of them, instead we
12 // expose a few external functions to wrapper libraries which manages Vulkan
13 // runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
// Runtime wrapper functions that the generated IR calls, listed in the order
// in which a translated launch invokes them.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Memref binding entry points, one per (rank, element type) combination the
// pass declares.
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";

// Callee names that identify the launch calls this pass processes.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attribute names read from the launch call, and the name of the global that
// holds the SPIR-V binary.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
48 
49 namespace {
50 
51 /// A pass to convert vulkan launch call op into a sequence of Vulkan
52 /// runtime calls in the following order:
53 ///
54 /// * initVulkan           -- initializes vulkan runtime
55 /// * bindMemRef           -- binds memref
56 /// * setBinaryShader      -- sets the binary shader data
57 /// * setEntryPoint        -- sets the entry point name
58 /// * setNumWorkGroups     -- sets the number of a local workgroups
59 /// * runOnVulkan          -- runs vulkan runtime
60 /// * deinitVulkan         -- deinitializes vulkan runtime
61 ///
62 class VulkanLaunchFuncToVulkanCallsPass
63     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
64           VulkanLaunchFuncToVulkanCallsPass> {
65 private:
66   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
67 
68   llvm::LLVMContext &getLLVMContext() {
69     return getLLVMDialect()->getLLVMContext();
70   }
71 
72   void initializeCachedTypes() {
73     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
74     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
75     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
76     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
77     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
78     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
79     llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
80     llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
81     llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
82     llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
83     llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
84     llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
85   }
86 
87   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
88     // According to the MLIR doc memref argument is converted into a
89     // pointer-to-struct argument of type:
90     // template <typename Elem, size_t Rank>
91     // struct {
92     //   Elem *allocated;
93     //   Elem *aligned;
94     //   int64_t offset;
95     //   int64_t sizes[Rank]; // omitted when rank == 0
96     //   int64_t strides[Rank]; // omitted when rank == 0
97     // };
98     auto llvmPtrToElementType = elemenType.getPointerTo();
99     auto llvmArrayRankElementSizeType =
100         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
101 
102     // Create a type
103     // `!llvm<"{ `element-type`*, `element-type`*, i64,
104     // [`rank` x i64], [`rank` x i64]}">`.
105     return LLVM::LLVMType::getStructTy(
106         llvmDialect,
107         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
108          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
109   }
110 
111   LLVM::LLVMType getFloatType() { return llvmFloatType; }
112   LLVM::LLVMType getVoidType() { return llvmVoidType; }
113   LLVM::LLVMType getPointerType() { return llvmPointerType; }
114   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
115   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
116   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
117   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
118   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
119   LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
120   LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
121   LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
122 
123   /// Creates a LLVM global for the given `name`.
124   Value createEntryPointNameConstant(StringRef name, Location loc,
125                                      OpBuilder &builder);
126 
127   /// Declares all needed runtime functions.
128   void declareVulkanFunctions(Location loc);
129 
130   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
131   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
132     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
133             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
134   }
135 
136   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
137   /// op.
138   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
139     return (callOp.callee() &&
140             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
141             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
142   }
143 
144   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
145   /// runtime calls.
146   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
147 
148   /// Creates call to `bindMemRef` for each memref operand.
149   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
150                              Value vulkanRuntime);
151 
152   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
153   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
154 
155   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
156   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
157                                         uint32_t &rank, LLVM::LLVMType &type);
158 
159   /// Returns a string representation from the given `type`.
160   StringRef stringifyType(LLVM::LLVMType type) {
161     if (type.isFloatTy())
162       return "Float";
163     if (type.isIntegerTy())
164       return "Int";
165 
166     llvm_unreachable("unsupported type");
167   }
168 
169 public:
170   void runOnOperation() override;
171 
172 private:
173   LLVM::LLVMDialect *llvmDialect;
174   LLVM::LLVMType llvmFloatType;
175   LLVM::LLVMType llvmVoidType;
176   LLVM::LLVMType llvmPointerType;
177   LLVM::LLVMType llvmInt32Type;
178   LLVM::LLVMType llvmInt64Type;
179   LLVM::LLVMType llvmMemRef1DFloat;
180   LLVM::LLVMType llvmMemRef2DFloat;
181   LLVM::LLVMType llvmMemRef3DFloat;
182   LLVM::LLVMType llvmMemRef1DInt;
183   LLVM::LLVMType llvmMemRef2DInt;
184   LLVM::LLVMType llvmMemRef3DInt;
185 
186   // TODO: Use an associative array to support multiple vulkan launch calls.
187   std::pair<StringAttr, StringAttr> spirvAttributes;
188   /// The number of vulkan launch configuration operands, placed at the leading
189   /// positions of the operand list.
190   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
191 };
192 
193 } // anonymous namespace
194 
195 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
196   initializeCachedTypes();
197 
198   // Collect SPIR-V attributes such as `spirv_blob` and
199   // `spirv_entry_point_name`.
200   getOperation().walk([this](LLVM::CallOp op) {
201     if (isVulkanLaunchCallOp(op))
202       collectSPIRVAttributes(op);
203   });
204 
205   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
206   getOperation().walk([this](LLVM::CallOp op) {
207     if (isCInterfaceVulkanLaunchCallOp(op))
208       translateVulkanLaunchCall(op);
209   });
210 }
211 
212 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
213     LLVM::CallOp vulkanLaunchCallOp) {
214   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
215   // for the given vulkan launch call.
216   auto spirvBlobAttr =
217       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
218   if (!spirvBlobAttr) {
219     vulkanLaunchCallOp.emitError()
220         << "missing " << kSPIRVBlobAttrName << " attribute";
221     return signalPassFailure();
222   }
223 
224   auto spirvEntryPointNameAttr =
225       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
226   if (!spirvEntryPointNameAttr) {
227     vulkanLaunchCallOp.emitError()
228         << "missing " << kSPIRVEntryPointAttrName << " attribute";
229     return signalPassFailure();
230   }
231 
232   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
233 }
234 
235 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
236     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
237   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
238       kVulkanLaunchNumConfigOperands)
239     return;
240   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
241   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
242 
243   // Create LLVM constant for the descriptor set index.
244   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
245   // pass does.
246   Value descriptorSet = builder.create<LLVM::ConstantOp>(
247       loc, getInt32Type(), builder.getI32IntegerAttr(0));
248 
249   for (auto en :
250        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
251            kVulkanLaunchNumConfigOperands))) {
252     // Create LLVM constant for the descriptor binding index.
253     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
254         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
255 
256     auto ptrToMemRefDescriptor = en.value();
257     uint32_t rank = 0;
258     LLVM::LLVMType type;
259     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
260       cInterfaceVulkanLaunchCallOp.emitError()
261           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
262       return signalPassFailure();
263     }
264 
265     auto symbolName =
266         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
267     // Create call to `bindMemRef`.
268     builder.create<LLVM::CallOp>(
269         loc, ArrayRef<Type>{getVoidType()},
270         builder.getSymbolRefAttr(
271             StringRef(symbolName.data(), symbolName.size())),
272         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
273                         ptrToMemRefDescriptor});
274   }
275 }
276 
277 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
278     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
279   auto llvmPtrDescriptorTy =
280       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
281   if (!llvmPtrDescriptorTy)
282     return failure();
283 
284   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
285   // template <typename Elem, size_t Rank>
286   // struct {
287   //   Elem *allocated;
288   //   Elem *aligned;
289   //   int64_t offset;
290   //   int64_t sizes[Rank]; // omitted when rank == 0
291   //   int64_t strides[Rank]; // omitted when rank == 0
292   // };
293   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
294     return failure();
295 
296   type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
297   if (llvmDescriptorTy.getStructNumElements() == 3) {
298     rank = 0;
299     return success();
300   }
301   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
302   return success();
303 }
304 
305 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
306   ModuleOp module = getOperation();
307   OpBuilder builder(module.getBody()->getTerminator());
308 
309   if (!module.lookupSymbol(kSetEntryPoint)) {
310     builder.create<LLVM::LLVMFuncOp>(
311         loc, kSetEntryPoint,
312         LLVM::LLVMType::getFunctionTy(getVoidType(),
313                                       {getPointerType(), getPointerType()},
314                                       /*isVarArg=*/false));
315   }
316 
317   if (!module.lookupSymbol(kSetNumWorkGroups)) {
318     builder.create<LLVM::LLVMFuncOp>(
319         loc, kSetNumWorkGroups,
320         LLVM::LLVMType::getFunctionTy(
321             getVoidType(),
322             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
323             /*isVarArg=*/false));
324   }
325 
326   if (!module.lookupSymbol(kSetBinaryShader)) {
327     builder.create<LLVM::LLVMFuncOp>(
328         loc, kSetBinaryShader,
329         LLVM::LLVMType::getFunctionTy(
330             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
331             /*isVarArg=*/false));
332   }
333 
334   if (!module.lookupSymbol(kRunOnVulkan)) {
335     builder.create<LLVM::LLVMFuncOp>(
336         loc, kRunOnVulkan,
337         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
338                                       /*isVarArg=*/false));
339   }
340 
341 #define CREATE_VULKAN_BIND_FUNC(MemRefType)                                    \
342   if (!module.lookupSymbol(kBind##MemRefType)) {                               \
343     builder.create<LLVM::LLVMFuncOp>(                                          \
344         loc, kBind##MemRefType,                                                \
345         LLVM::LLVMType::getFunctionTy(getVoidType(),                           \
346                                       {getPointerType(), getInt32Type(),       \
347                                        getInt32Type(),                         \
348                                        get##MemRefType().getPointerTo()},      \
349                                       /*isVarArg=*/false));                    \
350   }
351 
352   CREATE_VULKAN_BIND_FUNC(MemRef1DFloat);
353   CREATE_VULKAN_BIND_FUNC(MemRef2DFloat);
354   CREATE_VULKAN_BIND_FUNC(MemRef3DFloat);
355   CREATE_VULKAN_BIND_FUNC(MemRef1DInt);
356   CREATE_VULKAN_BIND_FUNC(MemRef2DInt);
357   CREATE_VULKAN_BIND_FUNC(MemRef3DInt);
358 
359   if (!module.lookupSymbol(kInitVulkan)) {
360     builder.create<LLVM::LLVMFuncOp>(
361         loc, kInitVulkan,
362         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
363                                       /*isVarArg=*/false));
364   }
365 
366   if (!module.lookupSymbol(kDeinitVulkan)) {
367     builder.create<LLVM::LLVMFuncOp>(
368         loc, kDeinitVulkan,
369         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
370                                       /*isVarArg=*/false));
371   }
372 }
373 
374 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
375     StringRef name, Location loc, OpBuilder &builder) {
376   SmallString<16> shaderName(name.begin(), name.end());
377   // Append `\0` to follow C style string given that LLVM::createGlobalString()
378   // won't handle this directly for us.
379   shaderName.push_back('\0');
380 
381   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
382   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
383                                   shaderName, LLVM::Linkage::Internal,
384                                   getLLVMDialect());
385 }
386 
387 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
388     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
389   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
390   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
391   // Create call to `initVulkan`.
392   auto initVulkanCall = builder.create<LLVM::CallOp>(
393       loc, ArrayRef<Type>{getPointerType()},
394       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
395   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
396   // need to pass that pointer to each Vulkan runtime call.
397   auto vulkanRuntime = initVulkanCall.getResult(0);
398 
399   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
400   // that data to runtime call.
401   Value ptrToSPIRVBinary = LLVM::createGlobalString(
402       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
403       LLVM::Linkage::Internal, getLLVMDialect());
404 
405   // Create LLVM constant for the size of SPIR-V binary shader.
406   Value binarySize = builder.create<LLVM::ConstantOp>(
407       loc, getInt32Type(),
408       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
409 
410   // Create call to `bindMemRef` for each memref operand.
411   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
412 
413   // Create call to `setBinaryShader` runtime function with the given pointer to
414   // SPIR-V binary and binary size.
415   builder.create<LLVM::CallOp>(
416       loc, ArrayRef<Type>{getVoidType()},
417       builder.getSymbolRefAttr(kSetBinaryShader),
418       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
419   // Create LLVM global with entry point name.
420   Value entryPointName = createEntryPointNameConstant(
421       spirvAttributes.second.getValue(), loc, builder);
422   // Create call to `setEntryPoint` runtime function with the given pointer to
423   // entry point name.
424   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
425                                builder.getSymbolRefAttr(kSetEntryPoint),
426                                ArrayRef<Value>{vulkanRuntime, entryPointName});
427 
428   // Create number of local workgroup for each dimension.
429   builder.create<LLVM::CallOp>(
430       loc, ArrayRef<Type>{getVoidType()},
431       builder.getSymbolRefAttr(kSetNumWorkGroups),
432       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
433                       cInterfaceVulkanLaunchCallOp.getOperand(1),
434                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
435 
436   // Create call to `runOnVulkan` runtime function.
437   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
438                                builder.getSymbolRefAttr(kRunOnVulkan),
439                                ArrayRef<Value>{vulkanRuntime});
440 
441   // Create call to 'deinitVulkan' runtime function.
442   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
443                                builder.getSymbolRefAttr(kDeinitVulkan),
444                                ArrayRef<Value>{vulkanRuntime});
445 
446   // Declare runtime functions.
447   declareVulkanFunctions(loc);
448 
449   cInterfaceVulkanLaunchCallOp.erase();
450 }
451 
452 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
453 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
454   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
455 }
456