1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each Vulkan API, we
// currently expose a few external functions provided by wrapper libraries that
// manage the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Function.h"
22 #include "mlir/IR/Module.h"
23 #include "mlir/Pass/Pass.h"
24 
25 #include "llvm/ADT/SmallString.h"
26 
27 using namespace mlir;
28 
// Names of the external runtime wrapper functions this pass declares and calls.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kBindResource = "bindResource";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Callee name identifying the calls this pass rewrites.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";

// Attributes read from a `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the LLVM global holding the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
40 
41 namespace {
42 
43 /// A pass to convert vulkan launch func into a sequence of Vulkan
44 /// runtime calls in the following order:
45 ///
46 /// * initVulkan           -- initializes vulkan runtime
47 /// * bindResource         -- binds resource
48 /// * setBinaryShader      -- sets the binary shader data
49 /// * setEntryPoint        -- sets the entry point name
50 /// * setNumWorkGroups     -- sets the number of a local workgroups
51 /// * runOnVulkan          -- runs vulkan runtime
52 /// * deinitVulkan         -- deinitializes vulkan runtime
53 ///
54 class VulkanLaunchFuncToVulkanCallsPass
55     : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
56 private:
57   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
58 
59   llvm::LLVMContext &getLLVMContext() {
60     return getLLVMDialect()->getLLVMContext();
61   }
62 
63   void initializeCachedTypes() {
64     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
65     llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
66     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
67     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
68     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
69     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
70   }
71 
72   LLVM::LLVMType getFloatType() { return llvmFloatType; }
73   LLVM::LLVMType getVoidType() { return llvmVoidType; }
74   LLVM::LLVMType getPointerType() { return llvmPointerType; }
75   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
76   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
77 
78   /// Creates a LLVM global for the given `name`.
79   Value createEntryPointNameConstant(StringRef name, Location loc,
80                                      OpBuilder &builder);
81 
82   /// Declares all needed runtime functions.
83   void declareVulkanFunctions(Location loc);
84 
85   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
86   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
87     return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
88             callOp.getNumOperands() >= 6);
89   }
90 
91   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
92   /// runtime calls.
93   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
94 
95   /// Creates call to `bindResource` for each resource operand.
96   void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp,
97                                Value vulkanRuntiem);
98 
99 public:
100   void runOnModule() override;
101 
102 private:
103   LLVM::LLVMDialect *llvmDialect;
104   LLVM::LLVMType llvmFloatType;
105   LLVM::LLVMType llvmVoidType;
106   LLVM::LLVMType llvmPointerType;
107   LLVM::LLVMType llvmInt32Type;
108   LLVM::LLVMType llvmInt64Type;
109 };
110 
111 /// Represents operand adaptor for vulkan launch call operation, to simplify an
112 /// access to the lowered memref.
113 // TODO: We should use 'emit-c-wrappers' option to lower memref type:
114 // https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission.
115 struct VulkanLaunchOpOperandAdaptor {
116   VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; }
117   VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete;
118   VulkanLaunchOpOperandAdaptor
119   operator=(const VulkanLaunchOpOperandAdaptor &) = delete;
120 
121   /// Returns a tuple with a pointer to the memory and the size for the index-th
122   /// resource.
123   std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) {
124     assert(index < getResourceCount1D());
125     // 1D memref calling convention according to "ConversionToLLVMDialect.md":
126     // 0. Allocated pointer.
127     // 1. Aligned pointer.
128     // 2. Offset.
129     // 3. Size in dim 0.
130     // 4. Stride in dim 0.
131     auto offset = numConfigOps + index * loweredMemRefNumOps1D;
132     return std::make_tuple(operands[offset], operands[offset + 3]);
133   }
134 
135   /// Returns the number of resources assuming all operands lowered from
136   /// 1D memref.
137   uint32_t getResourceCount1D() {
138     return (operands.size() - numConfigOps) / loweredMemRefNumOps1D;
139   }
140 
141 private:
142   /// The number of operands of lowered 1D memref.
143   static constexpr const uint32_t loweredMemRefNumOps1D = 5;
144   /// The number of the first config operands.
145   static constexpr const uint32_t numConfigOps = 6;
146   ArrayRef<Value> operands;
147 };
148 
149 } // anonymous namespace
150 
151 void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
152   initializeCachedTypes();
153   getModule().walk([this](LLVM::CallOp op) {
154     if (isVulkanLaunchCallOp(op))
155       translateVulkanLaunchCall(op);
156   });
157 }
158 
159 void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls(
160     LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) {
161   if (vulkanLaunchCallOp.getNumOperands() == 6)
162     return;
163   OpBuilder builder(vulkanLaunchCallOp);
164   Location loc = vulkanLaunchCallOp.getLoc();
165 
166   // Create LLVM constant for the descriptor set index.
167   // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV`
168   // pass does.
169   Value descriptorSet = builder.create<LLVM::ConstantOp>(
170       loc, getInt32Type(), builder.getI32IntegerAttr(0));
171 
172   auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()};
173   VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands);
174 
175   for (auto resourceIdx :
176        llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) {
177     // Create LLVM constant for the descriptor binding index.
178     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
179         loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx));
180     // Get a pointer to the memory and size of that memory.
181     auto resourceDescriptor =
182         vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx);
183     // Create call to `bindResource`.
184     builder.create<LLVM::CallOp>(
185         loc, ArrayRef<Type>{getVoidType()},
186         builder.getSymbolRefAttr(kBindResource),
187         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
188                         // Pointer to the memory.
189                         std::get<0>(resourceDescriptor),
190                         // Size of the memory.
191                         std::get<1>(resourceDescriptor)});
192   }
193 }
194 
195 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
196   ModuleOp module = getModule();
197   OpBuilder builder(module.getBody()->getTerminator());
198 
199   if (!module.lookupSymbol(kSetEntryPoint)) {
200     builder.create<LLVM::LLVMFuncOp>(
201         loc, kSetEntryPoint,
202         LLVM::LLVMType::getFunctionTy(getVoidType(),
203                                       {getPointerType(), getPointerType()},
204                                       /*isVarArg=*/false));
205   }
206 
207   if (!module.lookupSymbol(kSetNumWorkGroups)) {
208     builder.create<LLVM::LLVMFuncOp>(
209         loc, kSetNumWorkGroups,
210         LLVM::LLVMType::getFunctionTy(
211             getVoidType(),
212             {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
213             /*isVarArg=*/false));
214   }
215 
216   if (!module.lookupSymbol(kSetBinaryShader)) {
217     builder.create<LLVM::LLVMFuncOp>(
218         loc, kSetBinaryShader,
219         LLVM::LLVMType::getFunctionTy(
220             getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
221             /*isVarArg=*/false));
222   }
223 
224   if (!module.lookupSymbol(kRunOnVulkan)) {
225     builder.create<LLVM::LLVMFuncOp>(
226         loc, kRunOnVulkan,
227         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
228                                       /*isVarArg=*/false));
229   }
230 
231   if (!module.lookupSymbol(kBindResource)) {
232     builder.create<LLVM::LLVMFuncOp>(
233         loc, kBindResource,
234         LLVM::LLVMType::getFunctionTy(
235             getVoidType(),
236             {getPointerType(), getInt32Type(), getInt32Type(),
237              getFloatType().getPointerTo(), getInt64Type()},
238             /*isVarArg=*/false));
239   }
240 
241   if (!module.lookupSymbol(kInitVulkan)) {
242     builder.create<LLVM::LLVMFuncOp>(
243         loc, kInitVulkan,
244         LLVM::LLVMType::getFunctionTy(getPointerType(), {},
245                                       /*isVarArg=*/false));
246   }
247 
248   if (!module.lookupSymbol(kDeinitVulkan)) {
249     builder.create<LLVM::LLVMFuncOp>(
250         loc, kDeinitVulkan,
251         LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
252                                       /*isVarArg=*/false));
253   }
254 }
255 
256 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
257     StringRef name, Location loc, OpBuilder &builder) {
258   SmallString<16> shaderName(name.begin(), name.end());
259   // Append `\0` to follow C style string given that LLVM::createGlobalString()
260   // won't handle this directly for us.
261   shaderName.push_back('\0');
262 
263   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
264   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
265                                   shaderName, LLVM::Linkage::Internal,
266                                   getLLVMDialect());
267 }
268 
269 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
270     LLVM::CallOp vulkanLaunchCallOp) {
271   OpBuilder builder(vulkanLaunchCallOp);
272   Location loc = vulkanLaunchCallOp.getLoc();
273 
274   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
275   // for the given vulkan launch call.
276   auto spirvBlobAttr =
277       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
278   if (!spirvBlobAttr) {
279     vulkanLaunchCallOp.emitError()
280         << "missing " << kSPIRVBlobAttrName << " attribute";
281     return signalPassFailure();
282   }
283 
284   auto entryPointNameAttr =
285       vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
286   if (!entryPointNameAttr) {
287     vulkanLaunchCallOp.emitError()
288         << "missing " << kSPIRVEntryPointAttrName << " attribute";
289     return signalPassFailure();
290   }
291 
292   // Create call to `initVulkan`.
293   auto initVulkanCall = builder.create<LLVM::CallOp>(
294       loc, ArrayRef<Type>{getPointerType()},
295       builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
296   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
297   // need to pass that pointer to each Vulkan runtime call.
298   auto vulkanRuntime = initVulkanCall.getResult(0);
299 
300   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
301   // that data to runtime call.
302   Value ptrToSPIRVBinary = LLVM::createGlobalString(
303       loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(),
304       LLVM::Linkage::Internal, getLLVMDialect());
305 
306   // Create LLVM constant for the size of SPIR-V binary shader.
307   Value binarySize = builder.create<LLVM::ConstantOp>(
308       loc, getInt32Type(),
309       builder.getI32IntegerAttr(spirvBlobAttr.getValue().size()));
310 
311   // Create call to `bindResource` for each resource operand.
312   createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime);
313 
314   // Create call to `setBinaryShader` runtime function with the given pointer to
315   // SPIR-V binary and binary size.
316   builder.create<LLVM::CallOp>(
317       loc, ArrayRef<Type>{getVoidType()},
318       builder.getSymbolRefAttr(kSetBinaryShader),
319       ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
320   // Create LLVM global with entry point name.
321   Value entryPointName =
322       createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder);
323   // Create call to `setEntryPoint` runtime function with the given pointer to
324   // entry point name.
325   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
326                                builder.getSymbolRefAttr(kSetEntryPoint),
327                                ArrayRef<Value>{vulkanRuntime, entryPointName});
328 
329   // Create number of local workgroup for each dimension.
330   builder.create<LLVM::CallOp>(
331       loc, ArrayRef<Type>{getVoidType()},
332       builder.getSymbolRefAttr(kSetNumWorkGroups),
333       ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0),
334                       vulkanLaunchCallOp.getOperand(1),
335                       vulkanLaunchCallOp.getOperand(2)});
336 
337   // Create call to `runOnVulkan` runtime function.
338   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
339                                builder.getSymbolRefAttr(kRunOnVulkan),
340                                ArrayRef<Value>{vulkanRuntime});
341 
342   // Create call to 'deinitVulkan' runtime function.
343   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
344                                builder.getSymbolRefAttr(kDeinitVulkan),
345                                ArrayRef<Value>{vulkanRuntime});
346 
347   // Declare runtime functions.
348   declareVulkanFunctions(loc);
349 
350   vulkanLaunchCallOp.erase();
351 }
352 
353 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
354 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
355   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
356 }
357 
358 static PassRegistration<VulkanLaunchFuncToVulkanCallsPass>
359     pass("launch-func-to-vulkan",
360          "Convert vulkanLaunch external call to Vulkan runtime external calls");
361