1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so we
// currently don't expose a separate external function in IR for each entry
// point; instead we expose a few external functions to wrapper libraries that
// manage the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
// Names of the launch functions this pass looks for in the input IR.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the Vulkan runtime wrapper functions that the pass emits calls to.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the attributes read from the vulkan launch call op.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
42 
43 namespace {
44 
45 /// A pass to convert vulkan launch call op into a sequence of Vulkan
46 /// runtime calls in the following order:
47 ///
48 /// * initVulkan           -- initializes vulkan runtime
49 /// * bindMemRef           -- binds memref
50 /// * setBinaryShader      -- sets the binary shader data
51 /// * setEntryPoint        -- sets the entry point name
52 /// * setNumWorkGroups     -- sets the number of a local workgroups
53 /// * runOnVulkan          -- runs vulkan runtime
54 /// * deinitVulkan         -- deinitializes vulkan runtime
55 ///
56 class VulkanLaunchFuncToVulkanCallsPass
57     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
58           VulkanLaunchFuncToVulkanCallsPass> {
59 private:
60   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
61 
62   void initializeCachedTypes() {
63     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
64     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
65     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
66     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
67     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
68     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
69   }
70 
71   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
72     // According to the MLIR doc memref argument is converted into a
73     // pointer-to-struct argument of type:
74     // template <typename Elem, size_t Rank>
75     // struct {
76     //   Elem *allocated;
77     //   Elem *aligned;
78     //   int64_t offset;
79     //   int64_t sizes[Rank]; // omitted when rank == 0
80     //   int64_t strides[Rank]; // omitted when rank == 0
81     // };
82     auto llvmPtrToElementType = elemenType.getPointerTo();
83     auto llvmArrayRankElementSizeType =
84         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
85 
86     // Create a type
87     // `!llvm<"{ `element-type`*, `element-type`*, i64,
88     // [`rank` x i64], [`rank` x i64]}">`.
89     return LLVM::LLVMType::getStructTy(
90         llvmDialect,
91         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
92          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
93   }
94 
95   LLVM::LLVMType getVoidType() { return llvmVoidType; }
96   LLVM::LLVMType getPointerType() { return llvmPointerType; }
97   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
98   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
99 
100   /// Creates an LLVM global for the given `name`.
101   Value createEntryPointNameConstant(StringRef name, Location loc,
102                                      OpBuilder &builder);
103 
104   /// Declares all needed runtime functions.
105   void declareVulkanFunctions(Location loc);
106 
107   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
108   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
109     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
110             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
111   }
112 
113   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
114   /// op.
115   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
116     return (callOp.callee() &&
117             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
118             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
119   }
120 
121   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
122   /// runtime calls.
123   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
124 
125   /// Creates call to `bindMemRef` for each memref operand.
126   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
127                              Value vulkanRuntime);
128 
129   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
130   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
131 
132   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
133   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
134                                         uint32_t &rank, LLVM::LLVMType &type);
135 
136   /// Returns a string representation from the given `type`.
137   StringRef stringifyType(LLVM::LLVMType type) {
138     if (type.isFloatTy())
139       return "Float";
140     if (type.isHalfTy())
141       return "Half";
142     if (type.isIntegerTy(32))
143       return "Int32";
144     if (type.isIntegerTy(16))
145       return "Int16";
146     if (type.isIntegerTy(8))
147       return "Int8";
148 
149     llvm_unreachable("unsupported type");
150   }
151 
152 public:
153   void runOnOperation() override;
154 
155 private:
156   LLVM::LLVMDialect *llvmDialect;
157   LLVM::LLVMType llvmFloatType;
158   LLVM::LLVMType llvmVoidType;
159   LLVM::LLVMType llvmPointerType;
160   LLVM::LLVMType llvmInt32Type;
161   LLVM::LLVMType llvmInt64Type;
162 
163   // TODO: Use an associative array to support multiple vulkan launch calls.
164   std::pair<StringAttr, StringAttr> spirvAttributes;
165   /// The number of vulkan launch configuration operands, placed at the leading
166   /// positions of the operand list.
167   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
168 };
169 
170 } // anonymous namespace
171 
172 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
173   initializeCachedTypes();
174 
175   // Collect SPIR-V attributes such as `spirv_blob` and
176   // `spirv_entry_point_name`.
177   getOperation().walk([this](LLVM::CallOp op) {
178     if (isVulkanLaunchCallOp(op))
179       collectSPIRVAttributes(op);
180   });
181 
182   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
183   getOperation().walk([this](LLVM::CallOp op) {
184     if (isCInterfaceVulkanLaunchCallOp(op))
185       translateVulkanLaunchCall(op);
186   });
187 }
188 
189 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
190     LLVM::CallOp vulkanLaunchCallOp) {
191   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
192   // for the given vulkan launch call.
193   auto spirvBlobAttr =
194       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
195   if (!spirvBlobAttr) {
196     vulkanLaunchCallOp.emitError()
197         << "missing " << kSPIRVBlobAttrName << " attribute";
198     return signalPassFailure();
199   }
200 
201   auto spirvEntryPointNameAttr =
202       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
203   if (!spirvEntryPointNameAttr) {
204     vulkanLaunchCallOp.emitError()
205         << "missing " << kSPIRVEntryPointAttrName << " attribute";
206     return signalPassFailure();
207   }
208 
209   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
210 }
211 
212 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
213     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
214   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
215       kVulkanLaunchNumConfigOperands)
216     return;
217   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
218   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
219 
220   // Create LLVM constant for the descriptor set index.
221   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
222   // pass does.
223   Value descriptorSet = builder.create<LLVM::ConstantOp>(
224       loc, getInt32Type(), builder.getI32IntegerAttr(0));
225 
226   for (auto en :
227        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
228            kVulkanLaunchNumConfigOperands))) {
229     // Create LLVM constant for the descriptor binding index.
230     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
231         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
232 
233     auto ptrToMemRefDescriptor = en.value();
234     uint32_t rank = 0;
235     LLVM::LLVMType type;
236     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
237       cInterfaceVulkanLaunchCallOp.emitError()
238           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
239       return signalPassFailure();
240     }
241 
242     auto symbolName =
243         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
244     // Special case for fp16 type. Since it is not a supported type in C we use
245     // int16_t and bitcast the descriptor.
246     if (type.isHalfTy()) {
247       auto memRefTy =
248           getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
249       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
250           loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
251     }
252     // Create call to `bindMemRef`.
253     builder.create<LLVM::CallOp>(
254         loc, ArrayRef<Type>{getVoidType()},
255         builder.getSymbolRefAttr(
256             StringRef(symbolName.data(), symbolName.size())),
257         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
258                         ptrToMemRefDescriptor});
259   }
260 }
261 
262 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
263     Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
264   auto llvmPtrDescriptorTy =
265       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
266   if (!llvmPtrDescriptorTy)
267     return failure();
268 
269   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
270   // template <typename Elem, size_t Rank>
271   // struct {
272   //   Elem *allocated;
273   //   Elem *aligned;
274   //   int64_t offset;
275   //   int64_t sizes[Rank]; // omitted when rank == 0
276   //   int64_t strides[Rank]; // omitted when rank == 0
277   // };
278   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
279     return failure();
280 
281   type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
282   if (llvmDescriptorTy.getStructNumElements() == 3) {
283     rank = 0;
284     return success();
285   }
286   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
287   return success();
288 }
289 
290 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
291   ModuleOp module = getOperation();
292   OpBuilder builder(module.getBody()->getTerminator());
293 
294   if (!module.lookupSymbol(kSetEntryPoint)) {
295     builder.create<LLVM::LLVMFuncOp>(
296         loc, kSetEntryPoint,
297         LLVM::LLVMType::getFunctionTy(getVoidType(),
298                                       {getPointerType(), getPointerType()},
299                                       /*isVarArg=*/false));
300   }
301 
302   if (!module.lookupSymbol(kSetNumWorkGroups)) {
303     builder.create<LLVM::LLVMFuncOp>(
304         loc, kSetNumWorkGroups,
305         LLVM::LLVMType::getFunctionTy(
306             getVoidType(),
307             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
308             /*isVarArg=*/false));
309   }
310 
311   if (!module.lookupSymbol(kSetBinaryShader)) {
312     builder.create<LLVM::LLVMFuncOp>(
313         loc, kSetBinaryShader,
314         LLVM::LLVMType::getFunctionTy(
315             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
316             /*isVarArg=*/false));
317   }
318 
319   if (!module.lookupSymbol(kRunOnVulkan)) {
320     builder.create<LLVM::LLVMFuncOp>(
321         loc, kRunOnVulkan,
322         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
323                                       /*isVarArg=*/false));
324   }
325 
326   for (unsigned i = 1; i <= 3; i++) {
327     for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
328                                 LLVM::LLVMType::getInt32Ty(llvmDialect),
329                                 LLVM::LLVMType::getInt16Ty(llvmDialect),
330                                 LLVM::LLVMType::getInt8Ty(llvmDialect),
331                                 LLVM::LLVMType::getHalfTy(llvmDialect)}) {
332       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
333                            std::string(stringifyType(type));
334       if (type.isHalfTy())
335         type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
336       if (!module.lookupSymbol(fnName)) {
337         auto fnType = LLVM::LLVMType::getFunctionTy(
338             getVoidType(),
339             {getPointerType(), getInt32Type(), getInt32Type(),
340              getMemRefType(i, type).getPointerTo()},
341             /*isVarArg=*/false);
342         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
343       }
344     }
345   }
346 
347   if (!module.lookupSymbol(kInitVulkan)) {
348     builder.create<LLVM::LLVMFuncOp>(
349         loc, kInitVulkan,
350         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
351                                       /*isVarArg=*/false));
352   }
353 
354   if (!module.lookupSymbol(kDeinitVulkan)) {
355     builder.create<LLVM::LLVMFuncOp>(
356         loc, kDeinitVulkan,
357         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
358                                       /*isVarArg=*/false));
359   }
360 }
361 
362 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
363     StringRef name, Location loc, OpBuilder &builder) {
364   SmallString<16> shaderName(name.begin(), name.end());
365   // Append `\0` to follow C style string given that LLVM::createGlobalString()
366   // won't handle this directly for us.
367   shaderName.push_back('\0');
368 
369   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
370   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
371                                   shaderName, LLVM::Linkage::Internal,
372                                   getLLVMDialect());
373 }
374 
375 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
376     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
377   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
378   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
379   // Create call to `initVulkan`.
380   auto initVulkanCall = builder.create<LLVM::CallOp>(
381       loc, ArrayRef<Type>{getPointerType()},
382       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
383   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
384   // need to pass that pointer to each Vulkan runtime call.
385   auto vulkanRuntime = initVulkanCall.getResult(0);
386 
387   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
388   // that data to runtime call.
389   Value ptrToSPIRVBinary = LLVM::createGlobalString(
390       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
391       LLVM::Linkage::Internal, getLLVMDialect());
392 
393   // Create LLVM constant for the size of SPIR-V binary shader.
394   Value binarySize = builder.create<LLVM::ConstantOp>(
395       loc, getInt32Type(),
396       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
397 
398   // Create call to `bindMemRef` for each memref operand.
399   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
400 
401   // Create call to `setBinaryShader` runtime function with the given pointer to
402   // SPIR-V binary and binary size.
403   builder.create<LLVM::CallOp>(
404       loc, ArrayRef<Type>{getVoidType()},
405       builder.getSymbolRefAttr(kSetBinaryShader),
406       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
407   // Create LLVM global with entry point name.
408   Value entryPointName = createEntryPointNameConstant(
409       spirvAttributes.second.getValue(), loc, builder);
410   // Create call to `setEntryPoint` runtime function with the given pointer to
411   // entry point name.
412   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
413                                builder.getSymbolRefAttr(kSetEntryPoint),
414                                ArrayRef<Value>{vulkanRuntime, entryPointName});
415 
416   // Create number of local workgroup for each dimension.
417   builder.create<LLVM::CallOp>(
418       loc, ArrayRef<Type>{getVoidType()},
419       builder.getSymbolRefAttr(kSetNumWorkGroups),
420       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
421                       cInterfaceVulkanLaunchCallOp.getOperand(1),
422                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
423 
424   // Create call to `runOnVulkan` runtime function.
425   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
426                                builder.getSymbolRefAttr(kRunOnVulkan),
427                                ArrayRef<Value>{vulkanRuntime});
428 
429   // Create call to 'deinitVulkan' runtime function.
430   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
431                                builder.getSymbolRefAttr(kDeinitVulkan),
432                                ArrayRef<Value>{vulkanRuntime});
433 
434   // Declare runtime functions.
435   declareVulkanFunctions(loc);
436 
437   cInterfaceVulkanLaunchCallOp.erase();
438 }
439 
440 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
441 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
442   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
443 }
444