1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass that converts a vulkan launch call into a
// sequence of Vulkan runtime calls. The Vulkan runtime API surface is huge, so
// currently we don't expose a separate external function in the IR for each
// API entry point; instead, we expose a few external functions implemented by
// a wrapper library which manages the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/Pass/Pass.h"
25 
26 #include "llvm/ADT/SmallString.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 
// Callee names of the launch calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the string attributes read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the external runtime wrapper functions called by the generated IR.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
45 
46 namespace {
47 
48 /// A pass to convert vulkan launch call op into a sequence of Vulkan
49 /// runtime calls in the following order:
50 ///
51 /// * initVulkan           -- initializes vulkan runtime
52 /// * bindMemRef           -- binds memref
53 /// * setBinaryShader      -- sets the binary shader data
54 /// * setEntryPoint        -- sets the entry point name
55 /// * setNumWorkGroups     -- sets the number of a local workgroups
56 /// * runOnVulkan          -- runs vulkan runtime
57 /// * deinitVulkan         -- deinitializes vulkan runtime
58 ///
59 class VulkanLaunchFuncToVulkanCallsPass
60     : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
61 private:
62   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
63 
64   llvm::LLVMContext &getLLVMContext() {
65     return getLLVMDialect()->getLLVMContext();
66   }
67 
68   void initializeCachedTypes() {
69     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
70     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
71     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
72     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
73     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
74     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
75     initializeMemRefTypes();
76   }
77 
78   void initializeMemRefTypes() {
79     // According to the MLIR doc memref argument is converted into a
80     // pointer-to-struct argument of type:
81     // template <typename Elem, size_t Rank>
82     // struct {
83     //   Elem *allocated;
84     //   Elem *aligned;
85     //   int64_t offset;
86     //   int64_t sizes[Rank]; // omitted when rank == 0
87     //   int64_t strides[Rank]; // omitted when rank == 0
88     // };
89     auto llvmPtrToFloatType = getFloatType().getPointerTo();
90     auto llvmArrayOneElementSizeType =
91         LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
92     auto llvmArrayTwoElementSizeType =
93         LLVM::LLVMType::getArrayTy(getInt64Type(), 2);
94 
95     // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
96     llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
97         llvmDialect,
98         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
99          llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
100 
101     // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`.
102     llvmMemRef2DFloat = LLVM::LLVMType::getStructTy(
103         llvmDialect,
104         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
105          llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType});
106   }
107 
108   LLVM::LLVMType getFloatType() { return llvmFloatType; }
109   LLVM::LLVMType getVoidType() { return llvmVoidType; }
110   LLVM::LLVMType getPointerType() { return llvmPointerType; }
111   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
112   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
113   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
114   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
115 
116   /// Creates a LLVM global for the given `name`.
117   Value createEntryPointNameConstant(StringRef name, Location loc,
118                                      OpBuilder &builder);
119 
120   /// Declares all needed runtime functions.
121   void declareVulkanFunctions(Location loc);
122 
123   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
124   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
125     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
126             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
127   }
128 
129   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
130   /// op.
131   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
132     return (callOp.callee() &&
133             callOp.callee().getValue() == kCInterfaceVulkanLaunch &&
134             callOp.getNumOperands() >= gpu::LaunchOp::kNumConfigOperands);
135   }
136 
137   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
138   /// runtime calls.
139   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
140 
141   /// Creates call to `bindMemRef` for each memref operand.
142   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
143                              Value vulkanRuntime);
144 
145   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
146   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
147 
148   /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
149   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
150 
151 public:
152   void runOnModule() override;
153 
154 private:
155   LLVM::LLVMDialect *llvmDialect;
156   LLVM::LLVMType llvmFloatType;
157   LLVM::LLVMType llvmVoidType;
158   LLVM::LLVMType llvmPointerType;
159   LLVM::LLVMType llvmInt32Type;
160   LLVM::LLVMType llvmInt64Type;
161   LLVM::LLVMType llvmMemRef1DFloat;
162   LLVM::LLVMType llvmMemRef2DFloat;
163 
164   // TODO: Use an associative array to support multiple vulkan launch calls.
165   std::pair<StringAttr, StringAttr> spirvAttributes;
166 };
167 
168 } // anonymous namespace
169 
170 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
171   initializeCachedTypes();
172 
173   // Collect SPIR-V attributes such as `spirv_blob` and
174   // `spirv_entry_point_name`.
175   getModule().walk([this](LLVM::CallOp op) {
176     if (isVulkanLaunchCallOp(op))
177       collectSPIRVAttributes(op);
178   });
179 
180   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181   getModule().walk([this](LLVM::CallOp op) {
182     if (isCInterfaceVulkanLaunchCallOp(op))
183       translateVulkanLaunchCall(op);
184   });
185 }
186 
187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188     LLVM::CallOp vulkanLaunchCallOp) {
189   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190   // for the given vulkan launch call.
191   auto spirvBlobAttr =
192       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193   if (!spirvBlobAttr) {
194     vulkanLaunchCallOp.emitError()
195         << "missing " << kSPIRVBlobAttrName << " attribute";
196     return signalPassFailure();
197   }
198 
199   auto spirvEntryPointNameAttr =
200       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201   if (!spirvEntryPointNameAttr) {
202     vulkanLaunchCallOp.emitError()
203         << "missing " << kSPIRVEntryPointAttrName << " attribute";
204     return signalPassFailure();
205   }
206 
207   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208 }
209 
210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213       gpu::LaunchOp::kNumConfigOperands)
214     return;
215   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
217 
218   // Create LLVM constant for the descriptor set index.
219   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
220   // pass does.
221   Value descriptorSet = builder.create<LLVM::ConstantOp>(
222       loc, getInt32Type(), builder.getI32IntegerAttr(0));
223 
224   for (auto en :
225        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226            gpu::LaunchOp::kNumConfigOperands))) {
227     // Create LLVM constant for the descriptor binding index.
228     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
230 
231     auto ptrToMemRefDescriptor = en.value();
232     uint32_t rank = 0;
233     if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
234       cInterfaceVulkanLaunchCallOp.emitError()
235           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
236       return signalPassFailure();
237     }
238 
239     auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
240     // Create call to `bindMemRef`.
241     builder.create<LLVM::CallOp>(
242         loc, ArrayRef<Type>{getVoidType()},
243         builder.getSymbolRefAttr(
244             StringRef(symbolName.data(), symbolName.size())),
245         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
246                         ptrToMemRefDescriptor});
247   }
248 }
249 
250 LogicalResult
251 VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
252                                                     uint32_t &rank) {
253   auto llvmPtrDescriptorTy =
254       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
255   if (!llvmPtrDescriptorTy)
256     return failure();
257 
258   auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
259   // template <typename Elem, size_t Rank>
260   // struct {
261   //   Elem *allocated;
262   //   Elem *aligned;
263   //   int64_t offset;
264   //   int64_t sizes[Rank]; // omitted when rank == 0
265   //   int64_t strides[Rank]; // omitted when rank == 0
266   // };
267   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
268     return failure();
269   if (llvmDescriptorTy.getStructNumElements() == 3) {
270     rank = 0;
271     return success();
272   }
273 
274   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
275   return success();
276 }
277 
278 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
279   ModuleOp module = getModule();
280   OpBuilder builder(module.getBody()->getTerminator());
281 
282   if (!module.lookupSymbol(kSetEntryPoint)) {
283     builder.create<LLVM::LLVMFuncOp>(
284         loc, kSetEntryPoint,
285         LLVM::LLVMType::getFunctionTy(getVoidType(),
286                                       {getPointerType(), getPointerType()},
287                                       /*isVarArg=*/false));
288   }
289 
290   if (!module.lookupSymbol(kSetNumWorkGroups)) {
291     builder.create<LLVM::LLVMFuncOp>(
292         loc, kSetNumWorkGroups,
293         LLVM::LLVMType::getFunctionTy(
294             getVoidType(),
295             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
296             /*isVarArg=*/false));
297   }
298 
299   if (!module.lookupSymbol(kSetBinaryShader)) {
300     builder.create<LLVM::LLVMFuncOp>(
301         loc, kSetBinaryShader,
302         LLVM::LLVMType::getFunctionTy(
303             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
304             /*isVarArg=*/false));
305   }
306 
307   if (!module.lookupSymbol(kRunOnVulkan)) {
308     builder.create<LLVM::LLVMFuncOp>(
309         loc, kRunOnVulkan,
310         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
311                                       /*isVarArg=*/false));
312   }
313 
314   if (!module.lookupSymbol(kBindMemRef1DFloat)) {
315     builder.create<LLVM::LLVMFuncOp>(
316         loc, kBindMemRef1DFloat,
317         LLVM::LLVMType::getFunctionTy(getVoidType(),
318                                       {getPointerType(), getInt32Type(),
319                                        getInt32Type(),
320                                        getMemRef1DFloat().getPointerTo()},
321                                       /*isVarArg=*/false));
322   }
323 
324   if (!module.lookupSymbol(kBindMemRef2DFloat)) {
325     builder.create<LLVM::LLVMFuncOp>(
326         loc, kBindMemRef2DFloat,
327         LLVM::LLVMType::getFunctionTy(getVoidType(),
328                                       {getPointerType(), getInt32Type(),
329                                        getInt32Type(),
330                                        getMemRef2DFloat().getPointerTo()},
331                                       /*isVarArg=*/false));
332   }
333 
334   if (!module.lookupSymbol(kInitVulkan)) {
335     builder.create<LLVM::LLVMFuncOp>(
336         loc, kInitVulkan,
337         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
338                                       /*isVarArg=*/false));
339   }
340 
341   if (!module.lookupSymbol(kDeinitVulkan)) {
342     builder.create<LLVM::LLVMFuncOp>(
343         loc, kDeinitVulkan,
344         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
345                                       /*isVarArg=*/false));
346   }
347 }
348 
349 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
350     StringRef name, Location loc, OpBuilder &builder) {
351   SmallString<16> shaderName(name.begin(), name.end());
352   // Append `\0` to follow C style string given that LLVM::createGlobalString()
353   // won't handle this directly for us.
354   shaderName.push_back('\0');
355 
356   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
357   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
358                                   shaderName, LLVM::Linkage::Internal,
359                                   getLLVMDialect());
360 }
361 
362 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
363     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
364   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
365   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
366   // Create call to `initVulkan`.
367   auto initVulkanCall = builder.create<LLVM::CallOp>(
368       loc, ArrayRef<Type>{getPointerType()},
369       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
370   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
371   // need to pass that pointer to each Vulkan runtime call.
372   auto vulkanRuntime = initVulkanCall.getResult(0);
373 
374   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
375   // that data to runtime call.
376   Value ptrToSPIRVBinary = LLVM::createGlobalString(
377       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
378       LLVM::Linkage::Internal, getLLVMDialect());
379 
380   // Create LLVM constant for the size of SPIR-V binary shader.
381   Value binarySize = builder.create<LLVM::ConstantOp>(
382       loc, getInt32Type(),
383       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
384 
385   // Create call to `bindMemRef` for each memref operand.
386   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
387 
388   // Create call to `setBinaryShader` runtime function with the given pointer to
389   // SPIR-V binary and binary size.
390   builder.create<LLVM::CallOp>(
391       loc, ArrayRef<Type>{getVoidType()},
392       builder.getSymbolRefAttr(kSetBinaryShader),
393       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
394   // Create LLVM global with entry point name.
395   Value entryPointName = createEntryPointNameConstant(
396       spirvAttributes.second.getValue(), loc, builder);
397   // Create call to `setEntryPoint` runtime function with the given pointer to
398   // entry point name.
399   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
400                                builder.getSymbolRefAttr(kSetEntryPoint),
401                                ArrayRef<Value>{vulkanRuntime, entryPointName});
402 
403   // Create number of local workgroup for each dimension.
404   builder.create<LLVM::CallOp>(
405       loc, ArrayRef<Type>{getVoidType()},
406       builder.getSymbolRefAttr(kSetNumWorkGroups),
407       ArrayRef<Value>{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
408                       cInterfaceVulkanLaunchCallOp.getOperand(1),
409                       cInterfaceVulkanLaunchCallOp.getOperand(2)});
410 
411   // Create call to `runOnVulkan` runtime function.
412   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
413                                builder.getSymbolRefAttr(kRunOnVulkan),
414                                ArrayRef<Value>{vulkanRuntime});
415 
416   // Create call to 'deinitVulkan' runtime function.
417   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
418                                builder.getSymbolRefAttr(kDeinitVulkan),
419                                ArrayRef<Value>{vulkanRuntime});
420 
421   // Declare runtime functions.
422   declareVulkanFunctions(loc);
423 
424   cInterfaceVulkanLaunchCallOp.erase();
425 }
426 
427 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
428 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
429   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
430 }
431 
// Static registration object for VulkanLaunchFuncToVulkanCallsPass.
// NOTE(review): the two strings are presumably the pass argument and its
// description -- confirm against mlir::PassRegistration.
static PassRegistration<VulkanLaunchFuncToVulkanCallsPass>
    pass("launch-func-to-vulkan",
         "Convert vulkanLaunch external call to Vulkan runtime external calls");
435