1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose a separate external function in IR for each of them; instead
// we expose a few external functions to wrapper libraries which manage the
// Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/Pass/Pass.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 
// Names of the Vulkan runtime wrapper functions that this pass declares in the
// module and emits calls to.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";

// Callee names of the launch calls the pass recognizes: the plain call that
// carries the SPIR-V attributes, and its C-interface counterpart that is
// rewritten into runtime calls.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Attribute names read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
46 
47 namespace {
48 
49 /// A pass to convert vulkan launch call op into a sequence of Vulkan
50 /// runtime calls in the following order:
51 ///
52 /// * initVulkan           -- initializes vulkan runtime
53 /// * bindMemRef           -- binds memref
54 /// * setBinaryShader      -- sets the binary shader data
55 /// * setEntryPoint        -- sets the entry point name
56 /// * setNumWorkGroups     -- sets the number of a local workgroups
57 /// * runOnVulkan          -- runs vulkan runtime
58 /// * deinitVulkan         -- deinitializes vulkan runtime
59 ///
60 class VulkanLaunchFuncToVulkanCallsPass
61     : public PassWrapper<VulkanLaunchFuncToVulkanCallsPass,
62                          OperationPass<ModuleOp>> {
63 private:
64 /// Include the generated pass utilities.
65 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
66 #include "mlir/Conversion/Passes.h.inc"
67 
68   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
69 
70   llvm::LLVMContext &getLLVMContext() {
71     return getLLVMDialect()->getLLVMContext();
72   }
73 
74   void initializeCachedTypes() {
75     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
76     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
77     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
78     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
79     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
80     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
81     llvmMemRef1DFloat = getMemRefType(1);
82     llvmMemRef2DFloat = getMemRefType(2);
83     llvmMemRef3DFloat = getMemRefType(3);
84   }
85 
86   LLVM::LLVMType getMemRefType(uint32_t rank) {
87     // According to the MLIR doc memref argument is converted into a
88     // pointer-to-struct argument of type:
89     // template <typename Elem, size_t Rank>
90     // struct {
91     //   Elem *allocated;
92     //   Elem *aligned;
93     //   int64_t offset;
94     //   int64_t sizes[Rank]; // omitted when rank == 0
95     //   int64_t strides[Rank]; // omitted when rank == 0
96     // };
97     auto llvmPtrToFloatType = getFloatType().getPointerTo();
98     auto llvmArrayRankElementSizeType =
99         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
100 
101     // Create a type
102     // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
103     return LLVM::LLVMType::getStructTy(
104         llvmDialect,
105         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
106          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
107   }
108 
109   LLVM::LLVMType getFloatType() { return llvmFloatType; }
110   LLVM::LLVMType getVoidType() { return llvmVoidType; }
111   LLVM::LLVMType getPointerType() { return llvmPointerType; }
112   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
113   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
114   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
115   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
116   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
117 
118   /// Creates a LLVM global for the given `name`.
119   Value createEntryPointNameConstant(StringRef name, Location loc,
120                                      OpBuilder &builder);
121 
122   /// Declares all needed runtime functions.
123   void declareVulkanFunctions(Location loc);
124 
125   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
126   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
127     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
128             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
129   }
130 
131   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
132   /// op.
133   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
134     return (callOp.callee() &&
135             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
136             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
137   }
138 
139   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
140   /// runtime calls.
141   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
142 
143   /// Creates call to `bindMemRef` for each memref operand.
144   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
145                              Value vulkanRuntime);
146 
147   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
148   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
149 
150   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
151   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
152 
153 public:
154   void runOnOperation() override;
155 
156 private:
157   LLVM::LLVMDialect *llvmDialect;
158   LLVM::LLVMType llvmFloatType;
159   LLVM::LLVMType llvmVoidType;
160   LLVM::LLVMType llvmPointerType;
161   LLVM::LLVMType llvmInt32Type;
162   LLVM::LLVMType llvmInt64Type;
163   LLVM::LLVMType llvmMemRef1DFloat;
164   LLVM::LLVMType llvmMemRef2DFloat;
165   LLVM::LLVMType llvmMemRef3DFloat;
166 
167   // TODO: Use an associative array to support multiple vulkan launch calls.
168   std::pair<StringAttr, StringAttr> spirvAttributes;
169 };
170 
171 } // anonymous namespace
172 
173 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
174   initializeCachedTypes();
175 
176   // Collect SPIR-V attributes such as `spirv_blob` and
177   // `spirv_entry_point_name`.
178   getOperation().walk([this](LLVM::CallOp op) {
179     if (isVulkanLaunchCallOp(op))
180       collectSPIRVAttributes(op);
181   });
182 
183   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
184   getOperation().walk([this](LLVM::CallOp op) {
185     if (isCInterfaceVulkanLaunchCallOp(op))
186       translateVulkanLaunchCall(op);
187   });
188 }
189 
190 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
191     LLVM::CallOp vulkanLaunchCallOp) {
192   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
193   // for the given vulkan launch call.
194   auto spirvBlobAttr =
195       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
196   if (!spirvBlobAttr) {
197     vulkanLaunchCallOp.emitError()
198         << "missing " << kSPIRVBlobAttrName << " attribute";
199     return signalPassFailure();
200   }
201 
202   auto spirvEntryPointNameAttr =
203       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
204   if (!spirvEntryPointNameAttr) {
205     vulkanLaunchCallOp.emitError()
206         << "missing " << kSPIRVEntryPointAttrName << " attribute";
207     return signalPassFailure();
208   }
209 
210   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
211 }
212 
213 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
214     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
215   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
216       gpu::LaunchOp::kNumConfigOperands)
217     return;
218   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
219   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
220 
221   // Create LLVM constant for the descriptor set index.
222   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
223   // pass does.
224   Value descriptorSet = builder.create<LLVM::ConstantOp>(
225       loc, getInt32Type(), builder.getI32IntegerAttr(0));
226 
227   for (auto en :
228        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
229            gpu::LaunchOp::kNumConfigOperands))) {
230     // Create LLVM constant for the descriptor binding index.
231     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
232         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
233 
234     auto ptrToMemRefDescriptor = en.value();
235     uint32_t rank = 0;
236     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
237       cInterfaceVulkanLaunchCallOp.emitError()
238           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
239       return signalPassFailure();
240     }
241 
242     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
243     // Create call to `bindMemRef`.
244     builder.create<LLVM::CallOp>(
245         loc, ArrayRef<Type>{getVoidType()},
246         builder.getSymbolRefAttr(
247             StringRef(symbolName.data(), symbolName.size())),
248         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
249                         ptrToMemRefDescriptor});
250   }
251 }
252 
253 LogicalResult
254 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
255                                                     uint32_t &rank) {
256   auto llvmPtrDescriptorTy =
257       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
258   if (!llvmPtrDescriptorTy)
259     return failure();
260 
261   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
262   // template <typename Elem, size_t Rank>
263   // struct {
264   //   Elem *allocated;
265   //   Elem *aligned;
266   //   int64_t offset;
267   //   int64_t sizes[Rank]; // omitted when rank == 0
268   //   int64_t strides[Rank]; // omitted when rank == 0
269   // };
270   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
271     return failure();
272   if (llvmDescriptorTy.getStructNumElements() == 3) {
273     rank = 0;
274     return success();
275   }
276 
277   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
278   return success();
279 }
280 
281 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
282   ModuleOp module = getOperation();
283   OpBuilder builder(module.getBody()->getTerminator());
284 
285   if (!module.lookupSymbol(kSetEntryPoint)) {
286     builder.create<LLVM::LLVMFuncOp>(
287         loc, kSetEntryPoint,
288         LLVM::LLVMType::getFunctionTy(getVoidType(),
289                                       {getPointerType(), getPointerType()},
290                                       /*isVarArg=*/false));
291   }
292 
293   if (!module.lookupSymbol(kSetNumWorkGroups)) {
294     builder.create<LLVM::LLVMFuncOp>(
295         loc, kSetNumWorkGroups,
296         LLVM::LLVMType::getFunctionTy(
297             getVoidType(),
298             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
299             /*isVarArg=*/false));
300   }
301 
302   if (!module.lookupSymbol(kSetBinaryShader)) {
303     builder.create<LLVM::LLVMFuncOp>(
304         loc, kSetBinaryShader,
305         LLVM::LLVMType::getFunctionTy(
306             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
307             /*isVarArg=*/false));
308   }
309 
310   if (!module.lookupSymbol(kRunOnVulkan)) {
311     builder.create<LLVM::LLVMFuncOp>(
312         loc, kRunOnVulkan,
313         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
314                                       /*isVarArg=*/false));
315   }
316 
317   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
318     builder.create<LLVM::LLVMFuncOp>(
319         loc, kBindMemRef1DFloat,
320         LLVM::LLVMType::getFunctionTy(getVoidType(),
321                                       {getPointerType(), getInt32Type(),
322                                        getInt32Type(),
323                                        getMemRef1DFloat().getPointerTo()},
324                                       /*isVarArg=*/false));
325   }
326 
327   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
328     builder.create<LLVM::LLVMFuncOp>(
329         loc, kBindMemRef2DFloat,
330         LLVM::LLVMType::getFunctionTy(getVoidType(),
331                                       {getPointerType(), getInt32Type(),
332                                        getInt32Type(),
333                                        getMemRef2DFloat().getPointerTo()},
334                                       /*isVarArg=*/false));
335   }
336 
337   if (!module.lookupSymbol(kBindMemRef3DFloat)) {
338     builder.create<LLVM::LLVMFuncOp>(
339         loc, kBindMemRef3DFloat,
340         LLVM::LLVMType::getFunctionTy(getVoidType(),
341                                       {getPointerType(), getInt32Type(),
342                                        getInt32Type(),
343                                        getMemRef3DFloat().getPointerTo()},
344                                       /*isVarArg=*/false));
345   }
346 
347   if (!module.lookupSymbol(kInitVulkan)) {
348     builder.create<LLVM::LLVMFuncOp>(
349         loc, kInitVulkan,
350         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
351                                       /*isVarArg=*/false));
352   }
353 
354   if (!module.lookupSymbol(kDeinitVulkan)) {
355     builder.create<LLVM::LLVMFuncOp>(
356         loc, kDeinitVulkan,
357         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
358                                       /*isVarArg=*/false));
359   }
360 }
361 
362 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
363     StringRef name, Location loc, OpBuilder &builder) {
364   SmallString<16> shaderName(name.begin(), name.end());
365   // Append `\0` to follow C style string given that LLVM::createGlobalString()
366   // won't handle this directly for us.
367   shaderName.push_back('\0');
368 
369   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
370   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
371                                   shaderName, LLVM::Linkage::Internal,
372                                   getLLVMDialect());
373 }
374 
375 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
376     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
377   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
378   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
379   // Create call to `initVulkan`.
380   auto initVulkanCall = builder.create<LLVM::CallOp>(
381       loc, ArrayRef<Type>{getPointerType()},
382       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
383   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
384   // need to pass that pointer to each Vulkan runtime call.
385   auto vulkanRuntime = initVulkanCall.getResult(0);
386 
387   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
388   // that data to runtime call.
389   Value ptrToSPIRVBinary = LLVM::createGlobalString(
390       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
391       LLVM::Linkage::Internal, getLLVMDialect());
392 
393   // Create LLVM constant for the size of SPIR-V binary shader.
394   Value binarySize = builder.create<LLVM::ConstantOp>(
395       loc, getInt32Type(),
396       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
397 
398   // Create call to `bindMemRef` for each memref operand.
399   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
400 
401   // Create call to `setBinaryShader` runtime function with the given pointer to
402   // SPIR-V binary and binary size.
403   builder.create<LLVM::CallOp>(
404       loc, ArrayRef<Type>{getVoidType()},
405       builder.getSymbolRefAttr(kSetBinaryShader),
406       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
407   // Create LLVM global with entry point name.
408   Value entryPointName = createEntryPointNameConstant(
409       spirvAttributes.second.getValue(), loc, builder);
410   // Create call to `setEntryPoint` runtime function with the given pointer to
411   // entry point name.
412   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
413                                builder.getSymbolRefAttr(kSetEntryPoint),
414                                ArrayRef<Value>{vulkanRuntime, entryPointName});
415 
416   // Create number of local workgroup for each dimension.
417   builder.create<LLVM::CallOp>(
418       loc, ArrayRef<Type>{getVoidType()},
419       builder.getSymbolRefAttr(kSetNumWorkGroups),
420       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
421                       cInterfaceVulkanLaunchCallOp.getOperand(1),
422                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
423 
424   // Create call to `runOnVulkan` runtime function.
425   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
426                                builder.getSymbolRefAttr(kRunOnVulkan),
427                                ArrayRef<Value>{vulkanRuntime});
428 
429   // Create call to 'deinitVulkan' runtime function.
430   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
431                                builder.getSymbolRefAttr(kDeinitVulkan),
432                                ArrayRef<Value>{vulkanRuntime});
433 
434   // Declare runtime functions.
435   declareVulkanFunctions(loc);
436 
437   cInterfaceVulkanLaunchCallOp.erase();
438 }
439 
440 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
441 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
442   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
443 }
444