1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose separate external functions in IR for each of them; instead
// we expose a few external functions to wrapper libraries that manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/GPU/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/Function.h"
24 #include "mlir/IR/Module.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 
// Names of the launch functions this pass searches for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the attributes read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Name of the LLVM global created to hold the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the runtime wrapper functions that are declared and called.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";

// Per-rank `bindMemRef` wrappers. These must stay in sync with the
// "bindMemRef{0}DFloat" pattern used when the bind calls are created.
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
46 
47 namespace {
48 
49 /// A pass to convert vulkan launch call op into a sequence of Vulkan
50 /// runtime calls in the following order:
51 ///
52 /// * initVulkan           -- initializes vulkan runtime
53 /// * bindMemRef           -- binds memref
54 /// * setBinaryShader      -- sets the binary shader data
55 /// * setEntryPoint        -- sets the entry point name
56 /// * setNumWorkGroups     -- sets the number of a local workgroups
57 /// * runOnVulkan          -- runs vulkan runtime
58 /// * deinitVulkan         -- deinitializes vulkan runtime
59 ///
60 class VulkanLaunchFuncToVulkanCallsPass
61     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
62           VulkanLaunchFuncToVulkanCallsPass> {
63 private:
64   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
65 
66   llvm::LLVMContext &getLLVMContext() {
67     return getLLVMDialect()->getLLVMContext();
68   }
69 
70   void initializeCachedTypes() {
71     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
72     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
73     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
74     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
75     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
76     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
77     llvmMemRef1DFloat = getMemRefType(1);
78     llvmMemRef2DFloat = getMemRefType(2);
79     llvmMemRef3DFloat = getMemRefType(3);
80   }
81 
82   LLVM::LLVMType getMemRefType(uint32_t rank) {
83     // According to the MLIR doc memref argument is converted into a
84     // pointer-to-struct argument of type:
85     // template <typename Elem, size_t Rank>
86     // struct {
87     //   Elem *allocated;
88     //   Elem *aligned;
89     //   int64_t offset;
90     //   int64_t sizes[Rank]; // omitted when rank == 0
91     //   int64_t strides[Rank]; // omitted when rank == 0
92     // };
93     auto llvmPtrToFloatType = getFloatType().getPointerTo();
94     auto llvmArrayRankElementSizeType =
95         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
96 
97     // Create a type
98     // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
99     return LLVM::LLVMType::getStructTy(
100         llvmDialect,
101         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
102          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
103   }
104 
105   LLVM::LLVMType getFloatType() { return llvmFloatType; }
106   LLVM::LLVMType getVoidType() { return llvmVoidType; }
107   LLVM::LLVMType getPointerType() { return llvmPointerType; }
108   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
109   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
110   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
111   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
112   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
113 
114   /// Creates a LLVM global for the given `name`.
115   Value createEntryPointNameConstant(StringRef name, Location loc,
116                                      OpBuilder &builder);
117 
118   /// Declares all needed runtime functions.
119   void declareVulkanFunctions(Location loc);
120 
121   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
122   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
123     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
124             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
125   }
126 
127   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
128   /// op.
129   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
130     return (callOp.callee() &&
131             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
132             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
133   }
134 
135   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
136   /// runtime calls.
137   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
138 
139   /// Creates call to `bindMemRef` for each memref operand.
140   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
141                              Value vulkanRuntime);
142 
143   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
144   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
145 
146   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
147   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
148 
149 public:
150   void runOnOperation() override;
151 
152 private:
153   LLVM::LLVMDialect *llvmDialect;
154   LLVM::LLVMType llvmFloatType;
155   LLVM::LLVMType llvmVoidType;
156   LLVM::LLVMType llvmPointerType;
157   LLVM::LLVMType llvmInt32Type;
158   LLVM::LLVMType llvmInt64Type;
159   LLVM::LLVMType llvmMemRef1DFloat;
160   LLVM::LLVMType llvmMemRef2DFloat;
161   LLVM::LLVMType llvmMemRef3DFloat;
162 
163   // TODO: Use an associative array to support multiple vulkan launch calls.
164   std::pair<StringAttr, StringAttr> spirvAttributes;
165 };
166 
167 } // anonymous namespace
168 
169 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
170   initializeCachedTypes();
171 
172   // Collect SPIR-V attributes such as `spirv_blob` and
173   // `spirv_entry_point_name`.
174   getOperation().walk([this](LLVM::CallOp op) {
175     if (isVulkanLaunchCallOp(op))
176       collectSPIRVAttributes(op);
177   });
178 
179   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
180   getOperation().walk([this](LLVM::CallOp op) {
181     if (isCInterfaceVulkanLaunchCallOp(op))
182       translateVulkanLaunchCall(op);
183   });
184 }
185 
186 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
187     LLVM::CallOp vulkanLaunchCallOp) {
188   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
189   // for the given vulkan launch call.
190   auto spirvBlobAttr =
191       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
192   if (!spirvBlobAttr) {
193     vulkanLaunchCallOp.emitError()
194         << "missing " << kSPIRVBlobAttrName << " attribute";
195     return signalPassFailure();
196   }
197 
198   auto spirvEntryPointNameAttr =
199       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
200   if (!spirvEntryPointNameAttr) {
201     vulkanLaunchCallOp.emitError()
202         << "missing " << kSPIRVEntryPointAttrName << " attribute";
203     return signalPassFailure();
204   }
205 
206   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
207 }
208 
209 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
210     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
211   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
212       gpu::LaunchOp::kNumConfigOperands)
213     return;
214   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
215   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
216 
217   // Create LLVM constant for the descriptor set index.
218   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
219   // pass does.
220   Value descriptorSet = builder.create<LLVM::ConstantOp>(
221       loc, getInt32Type(), builder.getI32IntegerAttr(0));
222 
223   for (auto en :
224        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
225            gpu::LaunchOp::kNumConfigOperands))) {
226     // Create LLVM constant for the descriptor binding index.
227     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
228         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
229 
230     auto ptrToMemRefDescriptor = en.value();
231     uint32_t rank = 0;
232     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
233       cInterfaceVulkanLaunchCallOp.emitError()
234           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
235       return signalPassFailure();
236     }
237 
238     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
239     // Create call to `bindMemRef`.
240     builder.create<LLVM::CallOp>(
241         loc, ArrayRef<Type>{getVoidType()},
242         builder.getSymbolRefAttr(
243             StringRef(symbolName.data(), symbolName.size())),
244         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
245                         ptrToMemRefDescriptor});
246   }
247 }
248 
249 LogicalResult
250 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
251                                                     uint32_t &rank) {
252   auto llvmPtrDescriptorTy =
253       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
254   if (!llvmPtrDescriptorTy)
255     return failure();
256 
257   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
258   // template <typename Elem, size_t Rank>
259   // struct {
260   //   Elem *allocated;
261   //   Elem *aligned;
262   //   int64_t offset;
263   //   int64_t sizes[Rank]; // omitted when rank == 0
264   //   int64_t strides[Rank]; // omitted when rank == 0
265   // };
266   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
267     return failure();
268   if (llvmDescriptorTy.getStructNumElements() == 3) {
269     rank = 0;
270     return success();
271   }
272 
273   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
274   return success();
275 }
276 
277 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
278   ModuleOp module = getOperation();
279   OpBuilder builder(module.getBody()->getTerminator());
280 
281   if (!module.lookupSymbol(kSetEntryPoint)) {
282     builder.create<LLVM::LLVMFuncOp>(
283         loc, kSetEntryPoint,
284         LLVM::LLVMType::getFunctionTy(getVoidType(),
285                                       {getPointerType(), getPointerType()},
286                                       /*isVarArg=*/false));
287   }
288 
289   if (!module.lookupSymbol(kSetNumWorkGroups)) {
290     builder.create<LLVM::LLVMFuncOp>(
291         loc, kSetNumWorkGroups,
292         LLVM::LLVMType::getFunctionTy(
293             getVoidType(),
294             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
295             /*isVarArg=*/false));
296   }
297 
298   if (!module.lookupSymbol(kSetBinaryShader)) {
299     builder.create<LLVM::LLVMFuncOp>(
300         loc, kSetBinaryShader,
301         LLVM::LLVMType::getFunctionTy(
302             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
303             /*isVarArg=*/false));
304   }
305 
306   if (!module.lookupSymbol(kRunOnVulkan)) {
307     builder.create<LLVM::LLVMFuncOp>(
308         loc, kRunOnVulkan,
309         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
310                                       /*isVarArg=*/false));
311   }
312 
313   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
314     builder.create<LLVM::LLVMFuncOp>(
315         loc, kBindMemRef1DFloat,
316         LLVM::LLVMType::getFunctionTy(getVoidType(),
317                                       {getPointerType(), getInt32Type(),
318                                        getInt32Type(),
319                                        getMemRef1DFloat().getPointerTo()},
320                                       /*isVarArg=*/false));
321   }
322 
323   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
324     builder.create<LLVM::LLVMFuncOp>(
325         loc, kBindMemRef2DFloat,
326         LLVM::LLVMType::getFunctionTy(getVoidType(),
327                                       {getPointerType(), getInt32Type(),
328                                        getInt32Type(),
329                                        getMemRef2DFloat().getPointerTo()},
330                                       /*isVarArg=*/false));
331   }
332 
333   if (!module.lookupSymbol(kBindMemRef3DFloat)) {
334     builder.create<LLVM::LLVMFuncOp>(
335         loc, kBindMemRef3DFloat,
336         LLVM::LLVMType::getFunctionTy(getVoidType(),
337                                       {getPointerType(), getInt32Type(),
338                                        getInt32Type(),
339                                        getMemRef3DFloat().getPointerTo()},
340                                       /*isVarArg=*/false));
341   }
342 
343   if (!module.lookupSymbol(kInitVulkan)) {
344     builder.create<LLVM::LLVMFuncOp>(
345         loc, kInitVulkan,
346         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
347                                       /*isVarArg=*/false));
348   }
349 
350   if (!module.lookupSymbol(kDeinitVulkan)) {
351     builder.create<LLVM::LLVMFuncOp>(
352         loc, kDeinitVulkan,
353         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
354                                       /*isVarArg=*/false));
355   }
356 }
357 
358 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
359     StringRef name, Location loc, OpBuilder &builder) {
360   SmallString<16> shaderName(name.begin(), name.end());
361   // Append `\0` to follow C style string given that LLVM::createGlobalString()
362   // won't handle this directly for us.
363   shaderName.push_back('\0');
364 
365   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
366   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
367                                   shaderName, LLVM::Linkage::Internal,
368                                   getLLVMDialect());
369 }
370 
371 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
372     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
373   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
374   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
375   // Create call to `initVulkan`.
376   auto initVulkanCall = builder.create<LLVM::CallOp>(
377       loc, ArrayRef<Type>{getPointerType()},
378       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
379   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
380   // need to pass that pointer to each Vulkan runtime call.
381   auto vulkanRuntime = initVulkanCall.getResult(0);
382 
383   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
384   // that data to runtime call.
385   Value ptrToSPIRVBinary = LLVM::createGlobalString(
386       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
387       LLVM::Linkage::Internal, getLLVMDialect());
388 
389   // Create LLVM constant for the size of SPIR-V binary shader.
390   Value binarySize = builder.create<LLVM::ConstantOp>(
391       loc, getInt32Type(),
392       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
393 
394   // Create call to `bindMemRef` for each memref operand.
395   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
396 
397   // Create call to `setBinaryShader` runtime function with the given pointer to
398   // SPIR-V binary and binary size.
399   builder.create<LLVM::CallOp>(
400       loc, ArrayRef<Type>{getVoidType()},
401       builder.getSymbolRefAttr(kSetBinaryShader),
402       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
403   // Create LLVM global with entry point name.
404   Value entryPointName = createEntryPointNameConstant(
405       spirvAttributes.second.getValue(), loc, builder);
406   // Create call to `setEntryPoint` runtime function with the given pointer to
407   // entry point name.
408   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
409                                builder.getSymbolRefAttr(kSetEntryPoint),
410                                ArrayRef<Value>{vulkanRuntime, entryPointName});
411 
412   // Create number of local workgroup for each dimension.
413   builder.create<LLVM::CallOp>(
414       loc, ArrayRef<Type>{getVoidType()},
415       builder.getSymbolRefAttr(kSetNumWorkGroups),
416       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
417                       cInterfaceVulkanLaunchCallOp.getOperand(1),
418                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
419 
420   // Create call to `runOnVulkan` runtime function.
421   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
422                                builder.getSymbolRefAttr(kRunOnVulkan),
423                                ArrayRef<Value>{vulkanRuntime});
424 
425   // Create call to 'deinitVulkan' runtime function.
426   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
427                                builder.getSymbolRefAttr(kDeinitVulkan),
428                                ArrayRef<Value>{vulkanRuntime});
429 
430   // Declare runtime functions.
431   declareVulkanFunctions(loc);
432 
433   cInterfaceVulkanLaunchCallOp.erase();
434 }
435 
436 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
437 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
438   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
439 }
440