1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each runtime entry point,
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
// Names of the Vulkan runtime wrapper functions that the lowered IR calls.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";

// Callees of the launch calls recognized by this pass.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attribute names read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
42 
43 namespace {
44 
45 /// A pass to convert vulkan launch call op into a sequence of Vulkan
46 /// runtime calls in the following order:
47 ///
48 /// * initVulkan           -- initializes vulkan runtime
49 /// * bindMemRef           -- binds memref
50 /// * setBinaryShader      -- sets the binary shader data
51 /// * setEntryPoint        -- sets the entry point name
52 /// * setNumWorkGroups     -- sets the number of a local workgroups
53 /// * runOnVulkan          -- runs vulkan runtime
54 /// * deinitVulkan         -- deinitializes vulkan runtime
55 ///
56 class VulkanLaunchFuncToVulkanCallsPass
57     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
58           VulkanLaunchFuncToVulkanCallsPass> {
59 private:
60   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
61 
62   llvm::LLVMContext &getLLVMContext() {
63     return getLLVMDialect()->getLLVMContext();
64   }
65 
66   void initializeCachedTypes() {
67     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
68     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
69     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
70     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
71     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
72     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
73   }
74 
75   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
76     // According to the MLIR doc memref argument is converted into a
77     // pointer-to-struct argument of type:
78     // template <typename Elem, size_t Rank>
79     // struct {
80     //   Elem *allocated;
81     //   Elem *aligned;
82     //   int64_t offset;
83     //   int64_t sizes[Rank]; // omitted when rank == 0
84     //   int64_t strides[Rank]; // omitted when rank == 0
85     // };
86     auto llvmPtrToElementType = elemenType.getPointerTo();
87     auto llvmArrayRankElementSizeType =
88         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
89 
90     // Create a type
91     // `!llvm<"{ `element-type`*, `element-type`*, i64,
92     // [`rank` x i64], [`rank` x i64]}">`.
93     return LLVM::LLVMType::getStructTy(
94         llvmDialect,
95         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
96          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
97   }
98 
99   LLVM::LLVMType getVoidType() { return llvmVoidType; }
100   LLVM::LLVMType getPointerType() { return llvmPointerType; }
101   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
102   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
103 
104   /// Creates a LLVM global for the given `name`.
105   Value createEntryPointNameConstant(StringRef name, Location loc,
106                                      OpBuilder &builder);
107 
108   /// Declares all needed runtime functions.
109   void declareVulkanFunctions(Location loc);
110 
111   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
112   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
113     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
114             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
115   }
116 
117   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
118   /// op.
119   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
120     return (callOp.callee() &&
121             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
122             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
123   }
124 
125   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
126   /// runtime calls.
127   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
128 
129   /// Creates call to `bindMemRef` for each memref operand.
130   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
131                              Value vulkanRuntime);
132 
133   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
134   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
135 
136   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
137   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
138                                         uint32_t &rank, LLVM::LLVMType &type);
139 
140   /// Returns a string representation from the given `type`.
141   StringRef stringifyType(LLVM::LLVMType type) {
142     if (type.isFloatTy())
143       return "Float";
144     if (type.isHalfTy())
145       return "Half";
146     if (type.isIntegerTy(32))
147       return "Int32";
148     if (type.isIntegerTy(16))
149       return "Int16";
150     if (type.isIntegerTy(8))
151       return "Int8";
152 
153     llvm_unreachable("unsupported type");
154   }
155 
156 public:
157   void runOnOperation() override;
158 
159 private:
160   LLVM::LLVMDialect *llvmDialect;
161   LLVM::LLVMType llvmFloatType;
162   LLVM::LLVMType llvmVoidType;
163   LLVM::LLVMType llvmPointerType;
164   LLVM::LLVMType llvmInt32Type;
165   LLVM::LLVMType llvmInt64Type;
166 
167   // TODO: Use an associative array to support multiple vulkan launch calls.
168   std::pair<StringAttr, StringAttr> spirvAttributes;
169   /// The number of vulkan launch configuration operands, placed at the leading
170   /// positions of the operand list.
171   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
172 };
173 
174 } // anonymous namespace
175 
176 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
177   initializeCachedTypes();
178 
179   // Collect SPIR-V attributes such as `spirv_blob` and
180   // `spirv_entry_point_name`.
181   getOperation().walk([this](LLVM::CallOp op) {
182     if (isVulkanLaunchCallOp(op))
183       collectSPIRVAttributes(op);
184   });
185 
186   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
187   getOperation().walk([this](LLVM::CallOp op) {
188     if (isCInterfaceVulkanLaunchCallOp(op))
189       translateVulkanLaunchCall(op);
190   });
191 }
192 
193 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
194     LLVM::CallOp vulkanLaunchCallOp) {
195   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
196   // for the given vulkan launch call.
197   auto spirvBlobAttr =
198       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
199   if (!spirvBlobAttr) {
200     vulkanLaunchCallOp.emitError()
201         << "missing " << kSPIRVBlobAttrName << " attribute";
202     return signalPassFailure();
203   }
204 
205   auto spirvEntryPointNameAttr =
206       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
207   if (!spirvEntryPointNameAttr) {
208     vulkanLaunchCallOp.emitError()
209         << "missing " << kSPIRVEntryPointAttrName << " attribute";
210     return signalPassFailure();
211   }
212 
213   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
214 }
215 
216 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
217     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
218   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
219       kVulkanLaunchNumConfigOperands)
220     return;
221   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
222   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
223 
224   // Create LLVM constant for the descriptor set index.
225   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
226   // pass does.
227   Value descriptorSet = builder.create<LLVM::ConstantOp>(
228       loc, getInt32Type(), builder.getI32IntegerAttr(0));
229 
230   for (auto en :
231        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
232            kVulkanLaunchNumConfigOperands))) {
233     // Create LLVM constant for the descriptor binding index.
234     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
235         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
236 
237     auto ptrToMemRefDescriptor = en.value();
238     uint32_t rank = 0;
239     LLVM::LLVMType type;
240     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
241       cInterfaceVulkanLaunchCallOp.emitError()
242           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
243       return signalPassFailure();
244     }
245 
246     auto symbolName =
247         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
248     // Special case for fp16 type. Since it is not a supported type in C we use
249     // int16_t and bitcast the descriptor.
250     if (type.isHalfTy()) {
251       auto memRefTy =
252           getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
253       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
254           loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
255     }
256     // Create call to `bindMemRef`.
257     builder.create<LLVM::CallOp>(
258         loc, ArrayRef<Type>{getVoidType()},
259         builder.getSymbolRefAttr(
260             StringRef(symbolName.data(), symbolName.size())),
261         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
262                         ptrToMemRefDescriptor});
263   }
264 }
265 
266 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
267     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
268   auto llvmPtrDescriptorTy =
269       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
270   if (!llvmPtrDescriptorTy)
271     return failure();
272 
273   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
274   // template <typename Elem, size_t Rank>
275   // struct {
276   //   Elem *allocated;
277   //   Elem *aligned;
278   //   int64_t offset;
279   //   int64_t sizes[Rank]; // omitted when rank == 0
280   //   int64_t strides[Rank]; // omitted when rank == 0
281   // };
282   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
283     return failure();
284 
285   type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
286   if (llvmDescriptorTy.getStructNumElements() == 3) {
287     rank = 0;
288     return success();
289   }
290   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
291   return success();
292 }
293 
294 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
295   ModuleOp module = getOperation();
296   OpBuilder builder(module.getBody()->getTerminator());
297 
298   if (!module.lookupSymbol(kSetEntryPoint)) {
299     builder.create<LLVM::LLVMFuncOp>(
300         loc, kSetEntryPoint,
301         LLVM::LLVMType::getFunctionTy(getVoidType(),
302                                       {getPointerType(), getPointerType()},
303                                       /*isVarArg=*/false));
304   }
305 
306   if (!module.lookupSymbol(kSetNumWorkGroups)) {
307     builder.create<LLVM::LLVMFuncOp>(
308         loc, kSetNumWorkGroups,
309         LLVM::LLVMType::getFunctionTy(
310             getVoidType(),
311             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
312             /*isVarArg=*/false));
313   }
314 
315   if (!module.lookupSymbol(kSetBinaryShader)) {
316     builder.create<LLVM::LLVMFuncOp>(
317         loc, kSetBinaryShader,
318         LLVM::LLVMType::getFunctionTy(
319             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
320             /*isVarArg=*/false));
321   }
322 
323   if (!module.lookupSymbol(kRunOnVulkan)) {
324     builder.create<LLVM::LLVMFuncOp>(
325         loc, kRunOnVulkan,
326         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
327                                       /*isVarArg=*/false));
328   }
329 
330   for (unsigned i = 1; i <= 3; i++) {
331     for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
332                                 LLVM::LLVMType::getInt32Ty(llvmDialect),
333                                 LLVM::LLVMType::getInt16Ty(llvmDialect),
334                                 LLVM::LLVMType::getInt8Ty(llvmDialect),
335                                 LLVM::LLVMType::getHalfTy(llvmDialect)}) {
336       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
337                            std::string(stringifyType(type));
338       if (type.isHalfTy())
339         type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
340       if (!module.lookupSymbol(fnName)) {
341         auto fnType = LLVM::LLVMType::getFunctionTy(
342             getVoidType(),
343             {getPointerType(), getInt32Type(), getInt32Type(),
344              getMemRefType(i, type).getPointerTo()},
345             /*isVarArg=*/false);
346         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
347       }
348     }
349   }
350 
351   if (!module.lookupSymbol(kInitVulkan)) {
352     builder.create<LLVM::LLVMFuncOp>(
353         loc, kInitVulkan,
354         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
355                                       /*isVarArg=*/false));
356   }
357 
358   if (!module.lookupSymbol(kDeinitVulkan)) {
359     builder.create<LLVM::LLVMFuncOp>(
360         loc, kDeinitVulkan,
361         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
362                                       /*isVarArg=*/false));
363   }
364 }
365 
366 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
367     StringRef name, Location loc, OpBuilder &builder) {
368   SmallString<16> shaderName(name.begin(), name.end());
369   // Append `\0` to follow C style string given that LLVM::createGlobalString()
370   // won't handle this directly for us.
371   shaderName.push_back('\0');
372 
373   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
374   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
375                                   shaderName, LLVM::Linkage::Internal,
376                                   getLLVMDialect());
377 }
378 
379 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
380     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
381   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
382   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
383   // Create call to `initVulkan`.
384   auto initVulkanCall = builder.create<LLVM::CallOp>(
385       loc, ArrayRef<Type>{getPointerType()},
386       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
387   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
388   // need to pass that pointer to each Vulkan runtime call.
389   auto vulkanRuntime = initVulkanCall.getResult(0);
390 
391   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
392   // that data to runtime call.
393   Value ptrToSPIRVBinary = LLVM::createGlobalString(
394       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
395       LLVM::Linkage::Internal, getLLVMDialect());
396 
397   // Create LLVM constant for the size of SPIR-V binary shader.
398   Value binarySize = builder.create<LLVM::ConstantOp>(
399       loc, getInt32Type(),
400       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
401 
402   // Create call to `bindMemRef` for each memref operand.
403   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
404 
405   // Create call to `setBinaryShader` runtime function with the given pointer to
406   // SPIR-V binary and binary size.
407   builder.create<LLVM::CallOp>(
408       loc, ArrayRef<Type>{getVoidType()},
409       builder.getSymbolRefAttr(kSetBinaryShader),
410       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
411   // Create LLVM global with entry point name.
412   Value entryPointName = createEntryPointNameConstant(
413       spirvAttributes.second.getValue(), loc, builder);
414   // Create call to `setEntryPoint` runtime function with the given pointer to
415   // entry point name.
416   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
417                                builder.getSymbolRefAttr(kSetEntryPoint),
418                                ArrayRef<Value>{vulkanRuntime, entryPointName});
419 
420   // Create number of local workgroup for each dimension.
421   builder.create<LLVM::CallOp>(
422       loc, ArrayRef<Type>{getVoidType()},
423       builder.getSymbolRefAttr(kSetNumWorkGroups),
424       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
425                       cInterfaceVulkanLaunchCallOp.getOperand(1),
426                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
427 
428   // Create call to `runOnVulkan` runtime function.
429   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
430                                builder.getSymbolRefAttr(kRunOnVulkan),
431                                ArrayRef<Value>{vulkanRuntime});
432 
433   // Create call to 'deinitVulkan' runtime function.
434   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
435                                builder.getSymbolRefAttr(kDeinitVulkan),
436                                ArrayRef<Value>{vulkanRuntime});
437 
438   // Declare runtime functions.
439   declareVulkanFunctions(loc);
440 
441   cInterfaceVulkanLaunchCallOp.erase();
442 }
443 
444 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
445 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
446   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
447 }
448