1 //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so instead
// of exposing a separate external function in IR for each runtime entry point,
// we declare a few external functions provided by wrapper libraries that
// manage the Vulkan runtime.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "../PassDetail.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Support/FormatVariadic.h"
26
27 using namespace mlir;
28
// Names of the runtime wrapper functions this pass declares and calls.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Callee names of the launch calls this pass looks for: the plain launch call
// is where the SPIR-V attributes are read from, while the C-interface launch
// call is the one rewritten into runtime calls.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the string attributes read from the launch call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Symbol name of the global string that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
41
42 namespace {
43
44 /// A pass to convert vulkan launch call op into a sequence of Vulkan
45 /// runtime calls in the following order:
46 ///
47 /// * initVulkan -- initializes vulkan runtime
48 /// * bindMemRef -- binds memref
49 /// * setBinaryShader -- sets the binary shader data
50 /// * setEntryPoint -- sets the entry point name
51 /// * setNumWorkGroups -- sets the number of a local workgroups
52 /// * runOnVulkan -- runs vulkan runtime
53 /// * deinitVulkan -- deinitializes vulkan runtime
54 ///
55 class VulkanLaunchFuncToVulkanCallsPass
56 : public ConvertVulkanLaunchFuncToVulkanCallsBase<
57 VulkanLaunchFuncToVulkanCallsPass> {
58 private:
initializeCachedTypes()59 void initializeCachedTypes() {
60 llvmFloatType = Float32Type::get(&getContext());
61 llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
62 llvmPointerType =
63 LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
64 llvmInt32Type = IntegerType::get(&getContext(), 32);
65 llvmInt64Type = IntegerType::get(&getContext(), 64);
66 }
67
getMemRefType(uint32_t rank,Type elemenType)68 Type getMemRefType(uint32_t rank, Type elemenType) {
69 // According to the MLIR doc memref argument is converted into a
70 // pointer-to-struct argument of type:
71 // template <typename Elem, size_t Rank>
72 // struct {
73 // Elem *allocated;
74 // Elem *aligned;
75 // int64_t offset;
76 // int64_t sizes[Rank]; // omitted when rank == 0
77 // int64_t strides[Rank]; // omitted when rank == 0
78 // };
79 auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
80 auto llvmArrayRankElementSizeType =
81 LLVM::LLVMArrayType::get(getInt64Type(), rank);
82
83 // Create a type
84 // `!llvm<"{ `element-type`*, `element-type`*, i64,
85 // [`rank` x i64], [`rank` x i64]}">`.
86 return LLVM::LLVMStructType::getLiteral(
87 &getContext(),
88 {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
89 llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90 }
91
getVoidType()92 Type getVoidType() { return llvmVoidType; }
getPointerType()93 Type getPointerType() { return llvmPointerType; }
getInt32Type()94 Type getInt32Type() { return llvmInt32Type; }
getInt64Type()95 Type getInt64Type() { return llvmInt64Type; }
96
97 /// Creates an LLVM global for the given `name`.
98 Value createEntryPointNameConstant(StringRef name, Location loc,
99 OpBuilder &builder);
100
101 /// Declares all needed runtime functions.
102 void declareVulkanFunctions(Location loc);
103
104 /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)105 bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106 return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch &&
107 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108 }
109
110 /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111 /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)112 bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113 return (callOp.getCallee() &&
114 *callOp.getCallee() == kCInterfaceVulkanLaunch &&
115 callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
116 }
117
118 /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
119 /// runtime calls.
120 void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
121
122 /// Creates call to `bindMemRef` for each memref operand.
123 void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124 Value vulkanRuntime);
125
126 /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127 void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128
129 /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
130 LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131 uint32_t &rank, Type &type);
132
133 /// Returns a string representation from the given `type`.
stringifyType(Type type)134 StringRef stringifyType(Type type) {
135 if (type.isa<Float32Type>())
136 return "Float";
137 if (type.isa<Float16Type>())
138 return "Half";
139 if (auto intType = type.dyn_cast<IntegerType>()) {
140 if (intType.getWidth() == 32)
141 return "Int32";
142 if (intType.getWidth() == 16)
143 return "Int16";
144 if (intType.getWidth() == 8)
145 return "Int8";
146 }
147
148 llvm_unreachable("unsupported type");
149 }
150
151 public:
152 void runOnOperation() override;
153
154 private:
155 Type llvmFloatType;
156 Type llvmVoidType;
157 Type llvmPointerType;
158 Type llvmInt32Type;
159 Type llvmInt64Type;
160
161 // TODO: Use an associative array to support multiple vulkan launch calls.
162 std::pair<StringAttr, StringAttr> spirvAttributes;
163 /// The number of vulkan launch configuration operands, placed at the leading
164 /// positions of the operand list.
165 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
166 };
167
168 } // namespace
169
runOnOperation()170 void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
171 initializeCachedTypes();
172
173 // Collect SPIR-V attributes such as `spirv_blob` and
174 // `spirv_entry_point_name`.
175 getOperation().walk([this](LLVM::CallOp op) {
176 if (isVulkanLaunchCallOp(op))
177 collectSPIRVAttributes(op);
178 });
179
180 // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181 getOperation().walk([this](LLVM::CallOp op) {
182 if (isCInterfaceVulkanLaunchCallOp(op))
183 translateVulkanLaunchCall(op);
184 });
185 }
186
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)187 void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188 LLVM::CallOp vulkanLaunchCallOp) {
189 // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190 // for the given vulkan launch call.
191 auto spirvBlobAttr =
192 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193 if (!spirvBlobAttr) {
194 vulkanLaunchCallOp.emitError()
195 << "missing " << kSPIRVBlobAttrName << " attribute";
196 return signalPassFailure();
197 }
198
199 auto spirvEntryPointNameAttr =
200 vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201 if (!spirvEntryPointNameAttr) {
202 vulkanLaunchCallOp.emitError()
203 << "missing " << kSPIRVEntryPointAttrName << " attribute";
204 return signalPassFailure();
205 }
206
207 spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208 }
209
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)210 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211 LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212 if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213 kVulkanLaunchNumConfigOperands)
214 return;
215 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
217
218 // Create LLVM constant for the descriptor set index.
219 // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
220 // pass does.
221 Value descriptorSet = builder.create<LLVM::ConstantOp>(
222 loc, getInt32Type(), builder.getI32IntegerAttr(0));
223
224 for (const auto &en :
225 llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226 kVulkanLaunchNumConfigOperands))) {
227 // Create LLVM constant for the descriptor binding index.
228 Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229 loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
230
231 auto ptrToMemRefDescriptor = en.value();
232 uint32_t rank = 0;
233 Type type;
234 if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
235 cInterfaceVulkanLaunchCallOp.emitError()
236 << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
237 return signalPassFailure();
238 }
239
240 auto symbolName =
241 llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
242 // Special case for fp16 type. Since it is not a supported type in C we use
243 // int16_t and bitcast the descriptor.
244 if (type.isa<Float16Type>()) {
245 auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
246 ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
247 loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
248 }
249 // Create call to `bindMemRef`.
250 builder.create<LLVM::CallOp>(
251 loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
252 ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
253 ptrToMemRefDescriptor});
254 }
255 }
256
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,Type & type)257 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
258 Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
259 auto llvmPtrDescriptorTy =
260 ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
261 if (!llvmPtrDescriptorTy)
262 return failure();
263
264 auto llvmDescriptorTy =
265 llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
266 // template <typename Elem, size_t Rank>
267 // struct {
268 // Elem *allocated;
269 // Elem *aligned;
270 // int64_t offset;
271 // int64_t sizes[Rank]; // omitted when rank == 0
272 // int64_t strides[Rank]; // omitted when rank == 0
273 // };
274 if (!llvmDescriptorTy)
275 return failure();
276
277 type = llvmDescriptorTy.getBody()[0]
278 .cast<LLVM::LLVMPointerType>()
279 .getElementType();
280 if (llvmDescriptorTy.getBody().size() == 3) {
281 rank = 0;
282 return success();
283 }
284 rank = llvmDescriptorTy.getBody()[3]
285 .cast<LLVM::LLVMArrayType>()
286 .getNumElements();
287 return success();
288 }
289
declareVulkanFunctions(Location loc)290 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
291 ModuleOp module = getOperation();
292 auto builder = OpBuilder::atBlockEnd(module.getBody());
293
294 if (!module.lookupSymbol(kSetEntryPoint)) {
295 builder.create<LLVM::LLVMFuncOp>(
296 loc, kSetEntryPoint,
297 LLVM::LLVMFunctionType::get(getVoidType(),
298 {getPointerType(), getPointerType()}));
299 }
300
301 if (!module.lookupSymbol(kSetNumWorkGroups)) {
302 builder.create<LLVM::LLVMFuncOp>(
303 loc, kSetNumWorkGroups,
304 LLVM::LLVMFunctionType::get(getVoidType(),
305 {getPointerType(), getInt64Type(),
306 getInt64Type(), getInt64Type()}));
307 }
308
309 if (!module.lookupSymbol(kSetBinaryShader)) {
310 builder.create<LLVM::LLVMFuncOp>(
311 loc, kSetBinaryShader,
312 LLVM::LLVMFunctionType::get(
313 getVoidType(),
314 {getPointerType(), getPointerType(), getInt32Type()}));
315 }
316
317 if (!module.lookupSymbol(kRunOnVulkan)) {
318 builder.create<LLVM::LLVMFuncOp>(
319 loc, kRunOnVulkan,
320 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
321 }
322
323 for (unsigned i = 1; i <= 3; i++) {
324 SmallVector<Type, 5> types{
325 Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
326 IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
327 Float16Type::get(&getContext())};
328 for (auto type : types) {
329 std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
330 std::string(stringifyType(type));
331 if (type.isa<Float16Type>())
332 type = IntegerType::get(&getContext(), 16);
333 if (!module.lookupSymbol(fnName)) {
334 auto fnType = LLVM::LLVMFunctionType::get(
335 getVoidType(),
336 {getPointerType(), getInt32Type(), getInt32Type(),
337 LLVM::LLVMPointerType::get(getMemRefType(i, type))},
338 /*isVarArg=*/false);
339 builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
340 }
341 }
342 }
343
344 if (!module.lookupSymbol(kInitVulkan)) {
345 builder.create<LLVM::LLVMFuncOp>(
346 loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
347 }
348
349 if (!module.lookupSymbol(kDeinitVulkan)) {
350 builder.create<LLVM::LLVMFuncOp>(
351 loc, kDeinitVulkan,
352 LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
353 }
354 }
355
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)356 Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
357 StringRef name, Location loc, OpBuilder &builder) {
358 SmallString<16> shaderName(name.begin(), name.end());
359 // Append `\0` to follow C style string given that LLVM::createGlobalString()
360 // won't handle this directly for us.
361 shaderName.push_back('\0');
362
363 std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
364 return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
365 shaderName, LLVM::Linkage::Internal);
366 }
367
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)368 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
369 LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
370 OpBuilder builder(cInterfaceVulkanLaunchCallOp);
371 Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
372 // Create call to `initVulkan`.
373 auto initVulkanCall = builder.create<LLVM::CallOp>(
374 loc, TypeRange{getPointerType()}, kInitVulkan);
375 // The result of `initVulkan` function is a pointer to Vulkan runtime, we
376 // need to pass that pointer to each Vulkan runtime call.
377 auto vulkanRuntime = initVulkanCall.getResult(0);
378
379 // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
380 // that data to runtime call.
381 Value ptrToSPIRVBinary = LLVM::createGlobalString(
382 loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
383 LLVM::Linkage::Internal);
384
385 // Create LLVM constant for the size of SPIR-V binary shader.
386 Value binarySize = builder.create<LLVM::ConstantOp>(
387 loc, getInt32Type(),
388 builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
389
390 // Create call to `bindMemRef` for each memref operand.
391 createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
392
393 // Create call to `setBinaryShader` runtime function with the given pointer to
394 // SPIR-V binary and binary size.
395 builder.create<LLVM::CallOp>(
396 loc, TypeRange(), kSetBinaryShader,
397 ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
398 // Create LLVM global with entry point name.
399 Value entryPointName = createEntryPointNameConstant(
400 spirvAttributes.second.getValue(), loc, builder);
401 // Create call to `setEntryPoint` runtime function with the given pointer to
402 // entry point name.
403 builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
404 ValueRange{vulkanRuntime, entryPointName});
405
406 // Create number of local workgroup for each dimension.
407 builder.create<LLVM::CallOp>(
408 loc, TypeRange(), kSetNumWorkGroups,
409 ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
410 cInterfaceVulkanLaunchCallOp.getOperand(1),
411 cInterfaceVulkanLaunchCallOp.getOperand(2)});
412
413 // Create call to `runOnVulkan` runtime function.
414 builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
415 ValueRange{vulkanRuntime});
416
417 // Create call to 'deinitVulkan' runtime function.
418 builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
419 ValueRange{vulkanRuntime});
420
421 // Declare runtime functions.
422 declareVulkanFunctions(loc);
423
424 cInterfaceVulkanLaunchCallOp.erase();
425 }
426
427 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()428 mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
429 return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
430 }
431