1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
// Symbol names of the Vulkan runtime wrapper functions that this pass calls
// and declares.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";

// Callee names of the launch calls this pass looks for: the plain call that
// carries the SPIR-V attributes, and its C-interface counterpart that gets
// translated into runtime calls.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attribute names read from the launch call, plus the name of the LLVM global
// that holds the SPIR-V binary.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
42 
43 namespace {
44 
45 /// A pass to convert vulkan launch call op into a sequence of Vulkan
46 /// runtime calls in the following order:
47 ///
48 /// * initVulkan           -- initializes vulkan runtime
49 /// * bindMemRef           -- binds memref
50 /// * setBinaryShader      -- sets the binary shader data
51 /// * setEntryPoint        -- sets the entry point name
52 /// * setNumWorkGroups     -- sets the number of a local workgroups
53 /// * runOnVulkan          -- runs vulkan runtime
54 /// * deinitVulkan         -- deinitializes vulkan runtime
55 ///
56 class VulkanLaunchFuncToVulkanCallsPass
57     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
58           VulkanLaunchFuncToVulkanCallsPass> {
59 private:
60   void initializeCachedTypes() {
61     llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext());
62     llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext());
63     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext());
64     llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext());
65     llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext());
66   }
67 
68   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
69     // According to the MLIR doc memref argument is converted into a
70     // pointer-to-struct argument of type:
71     // template <typename Elem, size_t Rank>
72     // struct {
73     //   Elem *allocated;
74     //   Elem *aligned;
75     //   int64_t offset;
76     //   int64_t sizes[Rank]; // omitted when rank == 0
77     //   int64_t strides[Rank]; // omitted when rank == 0
78     // };
79     auto llvmPtrToElementType = elemenType.getPointerTo();
80     auto llvmArrayRankElementSizeType =
81         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
82 
83     // Create a type
84     // `!llvm<"{ `element-type`*, `element-type`*, i64,
85     // [`rank` x i64], [`rank` x i64]}">`.
86     return LLVM::LLVMType::getStructTy(
87         &getContext(),
88         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
89          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90   }
91 
92   LLVM::LLVMType getVoidType() { return llvmVoidType; }
93   LLVM::LLVMType getPointerType() { return llvmPointerType; }
94   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
95   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
96 
97   /// Creates an LLVM global for the given `name`.
98   Value createEntryPointNameConstant(StringRef name, Location loc,
99                                      OpBuilder &builder);
100 
101   /// Declares all needed runtime functions.
102   void declareVulkanFunctions(Location loc);
103 
104   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
105   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
107             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108   }
109 
110   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111   /// op.
112   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113     return (callOp.callee() &&
114             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
115             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
116   }
117 
118   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
119   /// runtime calls.
120   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
121 
122   /// Creates call to `bindMemRef` for each memref operand.
123   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124                              Value vulkanRuntime);
125 
126   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128 
129   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
130   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131                                         uint32_t &rank, LLVM::LLVMType &type);
132 
133   /// Returns a string representation from the given `type`.
134   StringRef stringifyType(LLVM::LLVMType type) {
135     if (type.isFloatTy())
136       return "Float";
137     if (type.isHalfTy())
138       return "Half";
139     if (type.isIntegerTy(32))
140       return "Int32";
141     if (type.isIntegerTy(16))
142       return "Int16";
143     if (type.isIntegerTy(8))
144       return "Int8";
145 
146     llvm_unreachable("unsupported type");
147   }
148 
149 public:
150   void runOnOperation() override;
151 
152 private:
153   LLVM::LLVMType llvmFloatType;
154   LLVM::LLVMType llvmVoidType;
155   LLVM::LLVMType llvmPointerType;
156   LLVM::LLVMType llvmInt32Type;
157   LLVM::LLVMType llvmInt64Type;
158 
159   // TODO: Use an associative array to support multiple vulkan launch calls.
160   std::pair<StringAttr, StringAttr> spirvAttributes;
161   /// The number of vulkan launch configuration operands, placed at the leading
162   /// positions of the operand list.
163   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
164 };
165 
166 } // anonymous namespace
167 
168 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
169   initializeCachedTypes();
170 
171   // Collect SPIR-V attributes such as `spirv_blob` and
172   // `spirv_entry_point_name`.
173   getOperation().walk([this](LLVM::CallOp op) {
174     if (isVulkanLaunchCallOp(op))
175       collectSPIRVAttributes(op);
176   });
177 
178   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
179   getOperation().walk([this](LLVM::CallOp op) {
180     if (isCInterfaceVulkanLaunchCallOp(op))
181       translateVulkanLaunchCall(op);
182   });
183 }
184 
185 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
186     LLVM::CallOp vulkanLaunchCallOp) {
187   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
188   // for the given vulkan launch call.
189   auto spirvBlobAttr =
190       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
191   if (!spirvBlobAttr) {
192     vulkanLaunchCallOp.emitError()
193         << "missing " << kSPIRVBlobAttrName << " attribute";
194     return signalPassFailure();
195   }
196 
197   auto spirvEntryPointNameAttr =
198       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
199   if (!spirvEntryPointNameAttr) {
200     vulkanLaunchCallOp.emitError()
201         << "missing " << kSPIRVEntryPointAttrName << " attribute";
202     return signalPassFailure();
203   }
204 
205   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
206 }
207 
208 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
209     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
210   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
211       kVulkanLaunchNumConfigOperands)
212     return;
213   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
214   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
215 
216   // Create LLVM constant for the descriptor set index.
217   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
218   // pass does.
219   Value descriptorSet = builder.create<LLVM::ConstantOp>(
220       loc, getInt32Type(), builder.getI32IntegerAttr(0));
221 
222   for (auto en :
223        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
224            kVulkanLaunchNumConfigOperands))) {
225     // Create LLVM constant for the descriptor binding index.
226     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
227         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
228 
229     auto ptrToMemRefDescriptor = en.value();
230     uint32_t rank = 0;
231     LLVM::LLVMType type;
232     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
233       cInterfaceVulkanLaunchCallOp.emitError()
234           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
235       return signalPassFailure();
236     }
237 
238     auto symbolName =
239         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
240     // Special case for fp16 type. Since it is not a supported type in C we use
241     // int16_t and bitcast the descriptor.
242     if (type.isHalfTy()) {
243       auto memRefTy =
244           getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext()));
245       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
246           loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
247     }
248     // Create call to `bindMemRef`.
249     builder.create<LLVM::CallOp>(
250         loc, ArrayRef<Type>{getVoidType()},
251         builder.getSymbolRefAttr(
252             StringRef(symbolName.data(), symbolName.size())),
253         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
254                         ptrToMemRefDescriptor});
255   }
256 }
257 
258 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
259     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
260   auto llvmPtrDescriptorTy =
261       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
262   if (!llvmPtrDescriptorTy)
263     return failure();
264 
265   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
266   // template <typename Elem, size_t Rank>
267   // struct {
268   //   Elem *allocated;
269   //   Elem *aligned;
270   //   int64_t offset;
271   //   int64_t sizes[Rank]; // omitted when rank == 0
272   //   int64_t strides[Rank]; // omitted when rank == 0
273   // };
274   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
275     return failure();
276 
277   type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
278   if (llvmDescriptorTy.getStructNumElements() == 3) {
279     rank = 0;
280     return success();
281   }
282   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
283   return success();
284 }
285 
286 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
287   ModuleOp module = getOperation();
288   OpBuilder builder(module.getBody()->getTerminator());
289 
290   if (!module.lookupSymbol(kSetEntryPoint)) {
291     builder.create<LLVM::LLVMFuncOp>(
292         loc, kSetEntryPoint,
293         LLVM::LLVMType::getFunctionTy(getVoidType(),
294                                       {getPointerType(), getPointerType()},
295                                       /*isVarArg=*/false));
296   }
297 
298   if (!module.lookupSymbol(kSetNumWorkGroups)) {
299     builder.create<LLVM::LLVMFuncOp>(
300         loc, kSetNumWorkGroups,
301         LLVM::LLVMType::getFunctionTy(
302             getVoidType(),
303             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
304             /*isVarArg=*/false));
305   }
306 
307   if (!module.lookupSymbol(kSetBinaryShader)) {
308     builder.create<LLVM::LLVMFuncOp>(
309         loc, kSetBinaryShader,
310         LLVM::LLVMType::getFunctionTy(
311             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
312             /*isVarArg=*/false));
313   }
314 
315   if (!module.lookupSymbol(kRunOnVulkan)) {
316     builder.create<LLVM::LLVMFuncOp>(
317         loc, kRunOnVulkan,
318         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
319                                       /*isVarArg=*/false));
320   }
321 
322   for (unsigned i = 1; i <= 3; i++) {
323     for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()),
324                                 LLVM::LLVMType::getInt32Ty(&getContext()),
325                                 LLVM::LLVMType::getInt16Ty(&getContext()),
326                                 LLVM::LLVMType::getInt8Ty(&getContext()),
327                                 LLVM::LLVMType::getHalfTy(&getContext())}) {
328       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
329                            std::string(stringifyType(type));
330       if (type.isHalfTy())
331         type = LLVM::LLVMType::getInt16Ty(&getContext());
332       if (!module.lookupSymbol(fnName)) {
333         auto fnType = LLVM::LLVMType::getFunctionTy(
334             getVoidType(),
335             {getPointerType(), getInt32Type(), getInt32Type(),
336              getMemRefType(i, type).getPointerTo()},
337             /*isVarArg=*/false);
338         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
339       }
340     }
341   }
342 
343   if (!module.lookupSymbol(kInitVulkan)) {
344     builder.create<LLVM::LLVMFuncOp>(
345         loc, kInitVulkan,
346         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
347                                       /*isVarArg=*/false));
348   }
349 
350   if (!module.lookupSymbol(kDeinitVulkan)) {
351     builder.create<LLVM::LLVMFuncOp>(
352         loc, kDeinitVulkan,
353         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
354                                       /*isVarArg=*/false));
355   }
356 }
357 
358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
359     StringRef name, Location loc, OpBuilder &builder) {
360   SmallString<16> shaderName(name.begin(), name.end());
361   // Append `\0` to follow C style string given that LLVM::createGlobalString()
362   // won't handle this directly for us.
363   shaderName.push_back('\0');
364 
365   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
366   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
367                                   shaderName, LLVM::Linkage::Internal);
368 }
369 
370 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
371     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
372   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
373   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
374   // Create call to `initVulkan`.
375   auto initVulkanCall = builder.create<LLVM::CallOp>(
376       loc, ArrayRef<Type>{getPointerType()},
377       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
378   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
379   // need to pass that pointer to each Vulkan runtime call.
380   auto vulkanRuntime = initVulkanCall.getResult(0);
381 
382   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
383   // that data to runtime call.
384   Value ptrToSPIRVBinary = LLVM::createGlobalString(
385       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
386       LLVM::Linkage::Internal);
387 
388   // Create LLVM constant for the size of SPIR-V binary shader.
389   Value binarySize = builder.create<LLVM::ConstantOp>(
390       loc, getInt32Type(),
391       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
392 
393   // Create call to `bindMemRef` for each memref operand.
394   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
395 
396   // Create call to `setBinaryShader` runtime function with the given pointer to
397   // SPIR-V binary and binary size.
398   builder.create<LLVM::CallOp>(
399       loc, ArrayRef<Type>{getVoidType()},
400       builder.getSymbolRefAttr(kSetBinaryShader),
401       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
402   // Create LLVM global with entry point name.
403   Value entryPointName = createEntryPointNameConstant(
404       spirvAttributes.second.getValue(), loc, builder);
405   // Create call to `setEntryPoint` runtime function with the given pointer to
406   // entry point name.
407   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
408                                builder.getSymbolRefAttr(kSetEntryPoint),
409                                ArrayRef<Value>{vulkanRuntime, entryPointName});
410 
411   // Create number of local workgroup for each dimension.
412   builder.create<LLVM::CallOp>(
413       loc, ArrayRef<Type>{getVoidType()},
414       builder.getSymbolRefAttr(kSetNumWorkGroups),
415       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
416                       cInterfaceVulkanLaunchCallOp.getOperand(1),
417                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
418 
419   // Create call to `runOnVulkan` runtime function.
420   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
421                                builder.getSymbolRefAttr(kRunOnVulkan),
422                                ArrayRef<Value>{vulkanRuntime});
423 
424   // Create call to 'deinitVulkan' runtime function.
425   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
426                                builder.getSymbolRefAttr(kDeinitVulkan),
427                                ArrayRef<Value>{vulkanRuntime});
428 
429   // Declare runtime functions.
430   declareVulkanFunctions(loc);
431 
432   cInterfaceVulkanLaunchCallOp.erase();
433 }
434 
435 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
436 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
437   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
438 }
439