1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each of its entry points,
// we expose a few external functions to wrapper libraries that manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 
// Callee names that identify the launch calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attributes that carry the SPIR-V module on the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Runtime functions declared by this pass and called from the lowered IR.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
45 
46 namespace {
47 
/// A pass to convert a Vulkan launch call op into a sequence of Vulkan
/// runtime calls in the following order:
///
/// * initVulkan           -- initializes the Vulkan runtime
/// * bindMemRef           -- binds a memref (one call per memref operand)
/// * setBinaryShader      -- sets the binary shader data
/// * setEntryPoint        -- sets the entry point name
/// * setNumWorkGroups     -- sets the number of workgroups in each dimension
/// * runOnVulkan          -- runs the Vulkan runtime
/// * deinitVulkan         -- deinitializes the Vulkan runtime
///
59 class VulkanLaunchFuncToVulkanCallsPass
60     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
61           VulkanLaunchFuncToVulkanCallsPass> {
62 private:
63   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
64 
65   llvm::LLVMContext &getLLVMContext() {
66     return getLLVMDialect()->getLLVMContext();
67   }
68 
69   void initializeCachedTypes() {
70     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
71     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
72     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
73     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
74     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
75     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
76     llvmMemRef1DFloat = getMemRefType(1);
77     llvmMemRef2DFloat = getMemRefType(2);
78     llvmMemRef3DFloat = getMemRefType(3);
79   }
80 
81   LLVM::LLVMType getMemRefType(uint32_t rank) {
82     // According to the MLIR doc memref argument is converted into a
83     // pointer-to-struct argument of type:
84     // template <typename Elem, size_t Rank>
85     // struct {
86     //   Elem *allocated;
87     //   Elem *aligned;
88     //   int64_t offset;
89     //   int64_t sizes[Rank]; // omitted when rank == 0
90     //   int64_t strides[Rank]; // omitted when rank == 0
91     // };
92     auto llvmPtrToFloatType = getFloatType().getPointerTo();
93     auto llvmArrayRankElementSizeType =
94         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
95 
96     // Create a type
97     // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
98     return LLVM::LLVMType::getStructTy(
99         llvmDialect,
100         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
101          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
102   }
103 
104   LLVM::LLVMType getFloatType() { return llvmFloatType; }
105   LLVM::LLVMType getVoidType() { return llvmVoidType; }
106   LLVM::LLVMType getPointerType() { return llvmPointerType; }
107   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
108   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
109   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
110   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
111   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
112 
113   /// Creates a LLVM global for the given `name`.
114   Value createEntryPointNameConstant(StringRef name, Location loc,
115                                      OpBuilder &builder);
116 
117   /// Declares all needed runtime functions.
118   void declareVulkanFunctions(Location loc);
119 
120   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
121   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
122     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
123             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
124   }
125 
126   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
127   /// op.
128   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
129     return (callOp.callee() &&
130             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
131             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
132   }
133 
134   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
135   /// runtime calls.
136   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
137 
138   /// Creates call to `bindMemRef` for each memref operand.
139   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
140                              Value vulkanRuntime);
141 
142   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
143   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
144 
145   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
146   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
147 
148 public:
149   void runOnOperation() override;
150 
151 private:
152   LLVM::LLVMDialect *llvmDialect;
153   LLVM::LLVMType llvmFloatType;
154   LLVM::LLVMType llvmVoidType;
155   LLVM::LLVMType llvmPointerType;
156   LLVM::LLVMType llvmInt32Type;
157   LLVM::LLVMType llvmInt64Type;
158   LLVM::LLVMType llvmMemRef1DFloat;
159   LLVM::LLVMType llvmMemRef2DFloat;
160   LLVM::LLVMType llvmMemRef3DFloat;
161 
162   // TODO: Use an associative array to support multiple vulkan launch calls.
163   std::pair<StringAttr, StringAttr> spirvAttributes;
164   /// The number of vulkan launch configuration operands, placed at the leading
165   /// positions of the operand list.
166   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
167 };
168 
169 } // anonymous namespace
170 
171 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
172   initializeCachedTypes();
173 
174   // Collect SPIR-V attributes such as `spirv_blob` and
175   // `spirv_entry_point_name`.
176   getOperation().walk([this](LLVM::CallOp op) {
177     if (isVulkanLaunchCallOp(op))
178       collectSPIRVAttributes(op);
179   });
180 
181   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
182   getOperation().walk([this](LLVM::CallOp op) {
183     if (isCInterfaceVulkanLaunchCallOp(op))
184       translateVulkanLaunchCall(op);
185   });
186 }
187 
188 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
189     LLVM::CallOp vulkanLaunchCallOp) {
190   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
191   // for the given vulkan launch call.
192   auto spirvBlobAttr =
193       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
194   if (!spirvBlobAttr) {
195     vulkanLaunchCallOp.emitError()
196         << "missing " << kSPIRVBlobAttrName << " attribute";
197     return signalPassFailure();
198   }
199 
200   auto spirvEntryPointNameAttr =
201       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
202   if (!spirvEntryPointNameAttr) {
203     vulkanLaunchCallOp.emitError()
204         << "missing " << kSPIRVEntryPointAttrName << " attribute";
205     return signalPassFailure();
206   }
207 
208   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
209 }
210 
211 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
212     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
213   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
214       kVulkanLaunchNumConfigOperands)
215     return;
216   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
217   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
218 
219   // Create LLVM constant for the descriptor set index.
220   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
221   // pass does.
222   Value descriptorSet = builder.create<LLVM::ConstantOp>(
223       loc, getInt32Type(), builder.getI32IntegerAttr(0));
224 
225   for (auto en :
226        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
227            kVulkanLaunchNumConfigOperands))) {
228     // Create LLVM constant for the descriptor binding index.
229     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
230         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
231 
232     auto ptrToMemRefDescriptor = en.value();
233     uint32_t rank = 0;
234     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
235       cInterfaceVulkanLaunchCallOp.emitError()
236           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
237       return signalPassFailure();
238     }
239 
240     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
241     // Create call to `bindMemRef`.
242     builder.create<LLVM::CallOp>(
243         loc, ArrayRef<Type>{getVoidType()},
244         builder.getSymbolRefAttr(
245             StringRef(symbolName.data(), symbolName.size())),
246         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
247                         ptrToMemRefDescriptor});
248   }
249 }
250 
251 LogicalResult
252 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
253                                                     uint32_t &rank) {
254   auto llvmPtrDescriptorTy =
255       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
256   if (!llvmPtrDescriptorTy)
257     return failure();
258 
259   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
260   // template <typename Elem, size_t Rank>
261   // struct {
262   //   Elem *allocated;
263   //   Elem *aligned;
264   //   int64_t offset;
265   //   int64_t sizes[Rank]; // omitted when rank == 0
266   //   int64_t strides[Rank]; // omitted when rank == 0
267   // };
268   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
269     return failure();
270   if (llvmDescriptorTy.getStructNumElements() == 3) {
271     rank = 0;
272     return success();
273   }
274 
275   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
276   return success();
277 }
278 
279 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
280   ModuleOp module = getOperation();
281   OpBuilder builder(module.getBody()->getTerminator());
282 
283   if (!module.lookupSymbol(kSetEntryPoint)) {
284     builder.create<LLVM::LLVMFuncOp>(
285         loc, kSetEntryPoint,
286         LLVM::LLVMType::getFunctionTy(getVoidType(),
287                                       {getPointerType(), getPointerType()},
288                                       /*isVarArg=*/false));
289   }
290 
291   if (!module.lookupSymbol(kSetNumWorkGroups)) {
292     builder.create<LLVM::LLVMFuncOp>(
293         loc, kSetNumWorkGroups,
294         LLVM::LLVMType::getFunctionTy(
295             getVoidType(),
296             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
297             /*isVarArg=*/false));
298   }
299 
300   if (!module.lookupSymbol(kSetBinaryShader)) {
301     builder.create<LLVM::LLVMFuncOp>(
302         loc, kSetBinaryShader,
303         LLVM::LLVMType::getFunctionTy(
304             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
305             /*isVarArg=*/false));
306   }
307 
308   if (!module.lookupSymbol(kRunOnVulkan)) {
309     builder.create<LLVM::LLVMFuncOp>(
310         loc, kRunOnVulkan,
311         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
312                                       /*isVarArg=*/false));
313   }
314 
315   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
316     builder.create<LLVM::LLVMFuncOp>(
317         loc, kBindMemRef1DFloat,
318         LLVM::LLVMType::getFunctionTy(getVoidType(),
319                                       {getPointerType(), getInt32Type(),
320                                        getInt32Type(),
321                                        getMemRef1DFloat().getPointerTo()},
322                                       /*isVarArg=*/false));
323   }
324 
325   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
326     builder.create<LLVM::LLVMFuncOp>(
327         loc, kBindMemRef2DFloat,
328         LLVM::LLVMType::getFunctionTy(getVoidType(),
329                                       {getPointerType(), getInt32Type(),
330                                        getInt32Type(),
331                                        getMemRef2DFloat().getPointerTo()},
332                                       /*isVarArg=*/false));
333   }
334 
335   if (!module.lookupSymbol(kBindMemRef3DFloat)) {
336     builder.create<LLVM::LLVMFuncOp>(
337         loc, kBindMemRef3DFloat,
338         LLVM::LLVMType::getFunctionTy(getVoidType(),
339                                       {getPointerType(), getInt32Type(),
340                                        getInt32Type(),
341                                        getMemRef3DFloat().getPointerTo()},
342                                       /*isVarArg=*/false));
343   }
344 
345   if (!module.lookupSymbol(kInitVulkan)) {
346     builder.create<LLVM::LLVMFuncOp>(
347         loc, kInitVulkan,
348         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
349                                       /*isVarArg=*/false));
350   }
351 
352   if (!module.lookupSymbol(kDeinitVulkan)) {
353     builder.create<LLVM::LLVMFuncOp>(
354         loc, kDeinitVulkan,
355         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
356                                       /*isVarArg=*/false));
357   }
358 }
359 
360 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
361     StringRef name, Location loc, OpBuilder &builder) {
362   SmallString<16> shaderName(name.begin(), name.end());
363   // Append `\0` to follow C style string given that LLVM::createGlobalString()
364   // won't handle this directly for us.
365   shaderName.push_back('\0');
366 
367   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
368   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
369                                   shaderName, LLVM::Linkage::Internal,
370                                   getLLVMDialect());
371 }
372 
373 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
374     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
375   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
376   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
377   // Create call to `initVulkan`.
378   auto initVulkanCall = builder.create<LLVM::CallOp>(
379       loc, ArrayRef<Type>{getPointerType()},
380       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
381   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
382   // need to pass that pointer to each Vulkan runtime call.
383   auto vulkanRuntime = initVulkanCall.getResult(0);
384 
385   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
386   // that data to runtime call.
387   Value ptrToSPIRVBinary = LLVM::createGlobalString(
388       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
389       LLVM::Linkage::Internal, getLLVMDialect());
390 
391   // Create LLVM constant for the size of SPIR-V binary shader.
392   Value binarySize = builder.create<LLVM::ConstantOp>(
393       loc, getInt32Type(),
394       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
395 
396   // Create call to `bindMemRef` for each memref operand.
397   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
398 
399   // Create call to `setBinaryShader` runtime function with the given pointer to
400   // SPIR-V binary and binary size.
401   builder.create<LLVM::CallOp>(
402       loc, ArrayRef<Type>{getVoidType()},
403       builder.getSymbolRefAttr(kSetBinaryShader),
404       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
405   // Create LLVM global with entry point name.
406   Value entryPointName = createEntryPointNameConstant(
407       spirvAttributes.second.getValue(), loc, builder);
408   // Create call to `setEntryPoint` runtime function with the given pointer to
409   // entry point name.
410   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
411                                builder.getSymbolRefAttr(kSetEntryPoint),
412                                ArrayRef<Value>{vulkanRuntime, entryPointName});
413 
414   // Create number of local workgroup for each dimension.
415   builder.create<LLVM::CallOp>(
416       loc, ArrayRef<Type>{getVoidType()},
417       builder.getSymbolRefAttr(kSetNumWorkGroups),
418       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
419                       cInterfaceVulkanLaunchCallOp.getOperand(1),
420                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
421 
422   // Create call to `runOnVulkan` runtime function.
423   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
424                                builder.getSymbolRefAttr(kRunOnVulkan),
425                                ArrayRef<Value>{vulkanRuntime});
426 
427   // Create call to 'deinitVulkan' runtime function.
428   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
429                                builder.getSymbolRefAttr(kDeinitVulkan),
430                                ArrayRef<Value>{vulkanRuntime});
431 
432   // Declare runtime functions.
433   declareVulkanFunctions(loc);
434 
435   cInterfaceVulkanLaunchCallOp.erase();
436 }
437 
438 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
439 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
440   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
441 }
442