1a062a3edSDenis Khalikov //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
2a062a3edSDenis Khalikov //
3a062a3edSDenis Khalikov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a062a3edSDenis Khalikov // See https://llvm.org/LICENSE.txt for license information.
5a062a3edSDenis Khalikov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a062a3edSDenis Khalikov //
7a062a3edSDenis Khalikov //===----------------------------------------------------------------------===//
8a062a3edSDenis Khalikov //
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so currently
// we don't expose a separate external function in IR for each of its entry
// points; instead we expose a few external functions to wrapper libraries
// which manage the Vulkan runtime.
14a062a3edSDenis Khalikov //
15a062a3edSDenis Khalikov //===----------------------------------------------------------------------===//
16a062a3edSDenis Khalikov
171834ad4aSRiver Riddle #include "../PassDetail.h"
18a062a3edSDenis Khalikov #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19a062a3edSDenis Khalikov #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20a062a3edSDenis Khalikov #include "mlir/IR/Attributes.h"
21a062a3edSDenis Khalikov #include "mlir/IR/Builders.h"
2265fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
23a062a3edSDenis Khalikov
24896ee361SDenis Khalikov #include "llvm/ADT/SmallString.h"
258f4ab8c7SDenis Khalikov #include "llvm/Support/FormatVariadic.h"
26a062a3edSDenis Khalikov
27a062a3edSDenis Khalikov using namespace mlir;
28a062a3edSDenis Khalikov
// Callee names of the launch calls this pass looks for in the input IR.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the attributes read from the `vulkanLaunch` call.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";

// Name of the LLVM global that holds the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the runtime functions that the generated calls refer to.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
41a062a3edSDenis Khalikov
42a062a3edSDenis Khalikov namespace {
43a062a3edSDenis Khalikov
44bfb2ce02SDenis Khalikov /// A pass to convert vulkan launch call op into a sequence of Vulkan
451090a830SDenis Khalikov /// runtime calls in the following order:
46a062a3edSDenis Khalikov ///
471090a830SDenis Khalikov /// * initVulkan -- initializes vulkan runtime
48bfb2ce02SDenis Khalikov /// * bindMemRef -- binds memref
49a062a3edSDenis Khalikov /// * setBinaryShader -- sets the binary shader data
50a062a3edSDenis Khalikov /// * setEntryPoint -- sets the entry point name
51a062a3edSDenis Khalikov /// * setNumWorkGroups -- sets the number of a local workgroups
52a062a3edSDenis Khalikov /// * runOnVulkan -- runs vulkan runtime
531090a830SDenis Khalikov /// * deinitVulkan -- deinitializes vulkan runtime
54a062a3edSDenis Khalikov ///
551090a830SDenis Khalikov class VulkanLaunchFuncToVulkanCallsPass
561834ad4aSRiver Riddle : public ConvertVulkanLaunchFuncToVulkanCallsBase<
571834ad4aSRiver Riddle VulkanLaunchFuncToVulkanCallsPass> {
58a062a3edSDenis Khalikov private:
initializeCachedTypes()59a062a3edSDenis Khalikov void initializeCachedTypes() {
60dd5165a9SAlex Zinenko llvmFloatType = Float32Type::get(&getContext());
617ed9cfc7SAlex Zinenko llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
622230bf99SAlex Zinenko llvmPointerType =
632230bf99SAlex Zinenko LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
642230bf99SAlex Zinenko llvmInt32Type = IntegerType::get(&getContext(), 32);
652230bf99SAlex Zinenko llvmInt64Type = IntegerType::get(&getContext(), 64);
66bfb2ce02SDenis Khalikov }
67bfb2ce02SDenis Khalikov
getMemRefType(uint32_t rank,Type elemenType)68c69c9e0fSAlex Zinenko Type getMemRefType(uint32_t rank, Type elemenType) {
69bfb2ce02SDenis Khalikov // According to the MLIR doc memref argument is converted into a
70bfb2ce02SDenis Khalikov // pointer-to-struct argument of type:
71bfb2ce02SDenis Khalikov // template <typename Elem, size_t Rank>
72bfb2ce02SDenis Khalikov // struct {
73bfb2ce02SDenis Khalikov // Elem *allocated;
74bfb2ce02SDenis Khalikov // Elem *aligned;
75bfb2ce02SDenis Khalikov // int64_t offset;
76bfb2ce02SDenis Khalikov // int64_t sizes[Rank]; // omitted when rank == 0
77bfb2ce02SDenis Khalikov // int64_t strides[Rank]; // omitted when rank == 0
78bfb2ce02SDenis Khalikov // };
798de43b92SAlex Zinenko auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
800718e3aeSDenis Khalikov auto llvmArrayRankElementSizeType =
817ed9cfc7SAlex Zinenko LLVM::LLVMArrayType::get(getInt64Type(), rank);
82bfb2ce02SDenis Khalikov
830718e3aeSDenis Khalikov // Create a type
841009177dSDenis Khalikov // `!llvm<"{ `element-type`*, `element-type`*, i64,
851009177dSDenis Khalikov // [`rank` x i64], [`rank` x i64]}">`.
867ed9cfc7SAlex Zinenko return LLVM::LLVMStructType::getLiteral(
875446ec85SAlex Zinenko &getContext(),
881009177dSDenis Khalikov {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
890718e3aeSDenis Khalikov llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90a062a3edSDenis Khalikov }
91a062a3edSDenis Khalikov
getVoidType()92c69c9e0fSAlex Zinenko Type getVoidType() { return llvmVoidType; }
getPointerType()93c69c9e0fSAlex Zinenko Type getPointerType() { return llvmPointerType; }
getInt32Type()94c69c9e0fSAlex Zinenko Type getInt32Type() { return llvmInt32Type; }
getInt64Type()95c69c9e0fSAlex Zinenko Type getInt64Type() { return llvmInt64Type; }
96a062a3edSDenis Khalikov
9706b90586SKazuaki Ishizaki /// Creates an LLVM global for the given `name`.
98a062a3edSDenis Khalikov Value createEntryPointNameConstant(StringRef name, Location loc,
99a062a3edSDenis Khalikov OpBuilder &builder);
100a062a3edSDenis Khalikov
101a062a3edSDenis Khalikov /// Declares all needed runtime functions.
102a062a3edSDenis Khalikov void declareVulkanFunctions(Location loc);
103a062a3edSDenis Khalikov
1041090a830SDenis Khalikov /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)1051090a830SDenis Khalikov bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106*6d5fc1e3SKazu Hirata return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch &&
107a48f0a3cSDenis Khalikov callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108bfb2ce02SDenis Khalikov }
109bfb2ce02SDenis Khalikov
110bfb2ce02SDenis Khalikov /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111bfb2ce02SDenis Khalikov /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)112bfb2ce02SDenis Khalikov bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113cfb72fd3SJacques Pienaar return (callOp.getCallee() &&
114*6d5fc1e3SKazu Hirata *callOp.getCallee() == kCInterfaceVulkanLaunch &&
115a48f0a3cSDenis Khalikov callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
1161090a830SDenis Khalikov }
1171090a830SDenis Khalikov
1181090a830SDenis Khalikov /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
1191090a830SDenis Khalikov /// runtime calls.
1201090a830SDenis Khalikov void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
1211090a830SDenis Khalikov
122bfb2ce02SDenis Khalikov /// Creates call to `bindMemRef` for each memref operand.
123bfb2ce02SDenis Khalikov void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124bfb2ce02SDenis Khalikov Value vulkanRuntime);
125bfb2ce02SDenis Khalikov
126bfb2ce02SDenis Khalikov /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127bfb2ce02SDenis Khalikov void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128a062a3edSDenis Khalikov
1291009177dSDenis Khalikov /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
1301009177dSDenis Khalikov LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131c69c9e0fSAlex Zinenko uint32_t &rank, Type &type);
1321009177dSDenis Khalikov
1331009177dSDenis Khalikov /// Returns a string representation from the given `type`.
stringifyType(Type type)134c69c9e0fSAlex Zinenko StringRef stringifyType(Type type) {
135dd5165a9SAlex Zinenko if (type.isa<Float32Type>())
1361009177dSDenis Khalikov return "Float";
137dd5165a9SAlex Zinenko if (type.isa<Float16Type>())
138afd43a7aSThomas Raoux return "Half";
1392230bf99SAlex Zinenko if (auto intType = type.dyn_cast<IntegerType>()) {
1402230bf99SAlex Zinenko if (intType.getWidth() == 32)
141afd43a7aSThomas Raoux return "Int32";
1422230bf99SAlex Zinenko if (intType.getWidth() == 16)
143afd43a7aSThomas Raoux return "Int16";
1442230bf99SAlex Zinenko if (intType.getWidth() == 8)
145afd43a7aSThomas Raoux return "Int8";
1468de43b92SAlex Zinenko }
1471009177dSDenis Khalikov
1481009177dSDenis Khalikov llvm_unreachable("unsupported type");
1491009177dSDenis Khalikov }
1508f4ab8c7SDenis Khalikov
151a062a3edSDenis Khalikov public:
152722f909fSRiver Riddle void runOnOperation() override;
153a062a3edSDenis Khalikov
154a062a3edSDenis Khalikov private:
155c69c9e0fSAlex Zinenko Type llvmFloatType;
156c69c9e0fSAlex Zinenko Type llvmVoidType;
157c69c9e0fSAlex Zinenko Type llvmPointerType;
158c69c9e0fSAlex Zinenko Type llvmInt32Type;
159c69c9e0fSAlex Zinenko Type llvmInt64Type;
1601090a830SDenis Khalikov
161bfb2ce02SDenis Khalikov // TODO: Use an associative array to support multiple vulkan launch calls.
162bfb2ce02SDenis Khalikov std::pair<StringAttr, StringAttr> spirvAttributes;
163a48f0a3cSDenis Khalikov /// The number of vulkan launch configuration operands, placed at the leading
164a48f0a3cSDenis Khalikov /// positions of the operand list.
165a48f0a3cSDenis Khalikov static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
166a062a3edSDenis Khalikov };
167a062a3edSDenis Khalikov
168be0a7e9fSMehdi Amini } // namespace
169a062a3edSDenis Khalikov
runOnOperation()170722f909fSRiver Riddle void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
171a062a3edSDenis Khalikov initializeCachedTypes();
172bfb2ce02SDenis Khalikov
173bfb2ce02SDenis Khalikov // Collect SPIR-V attributes such as `spirv_blob` and
174bfb2ce02SDenis Khalikov // `spirv_entry_point_name`.
175722f909fSRiver Riddle getOperation().walk([this](LLVM::CallOp op) {
1761090a830SDenis Khalikov if (isVulkanLaunchCallOp(op))
177bfb2ce02SDenis Khalikov collectSPIRVAttributes(op);
178bfb2ce02SDenis Khalikov });
179bfb2ce02SDenis Khalikov
180bfb2ce02SDenis Khalikov // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181722f909fSRiver Riddle getOperation().walk([this](LLVM::CallOp op) {
182bfb2ce02SDenis Khalikov if (isCInterfaceVulkanLaunchCallOp(op))
1831090a830SDenis Khalikov translateVulkanLaunchCall(op);
1841090a830SDenis Khalikov });
185a062a3edSDenis Khalikov }
186a062a3edSDenis Khalikov
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)187bfb2ce02SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188bfb2ce02SDenis Khalikov LLVM::CallOp vulkanLaunchCallOp) {
189bfb2ce02SDenis Khalikov // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190bfb2ce02SDenis Khalikov // for the given vulkan launch call.
191bfb2ce02SDenis Khalikov auto spirvBlobAttr =
1920bf4a82aSChristian Sigg vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193bfb2ce02SDenis Khalikov if (!spirvBlobAttr) {
194bfb2ce02SDenis Khalikov vulkanLaunchCallOp.emitError()
195bfb2ce02SDenis Khalikov << "missing " << kSPIRVBlobAttrName << " attribute";
196bfb2ce02SDenis Khalikov return signalPassFailure();
197bfb2ce02SDenis Khalikov }
198bfb2ce02SDenis Khalikov
199bfb2ce02SDenis Khalikov auto spirvEntryPointNameAttr =
2000bf4a82aSChristian Sigg vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201bfb2ce02SDenis Khalikov if (!spirvEntryPointNameAttr) {
202bfb2ce02SDenis Khalikov vulkanLaunchCallOp.emitError()
203bfb2ce02SDenis Khalikov << "missing " << kSPIRVEntryPointAttrName << " attribute";
204bfb2ce02SDenis Khalikov return signalPassFailure();
205bfb2ce02SDenis Khalikov }
206bfb2ce02SDenis Khalikov
207bfb2ce02SDenis Khalikov spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208bfb2ce02SDenis Khalikov }
209bfb2ce02SDenis Khalikov
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)210bfb2ce02SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211bfb2ce02SDenis Khalikov LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212bfb2ce02SDenis Khalikov if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213a48f0a3cSDenis Khalikov kVulkanLaunchNumConfigOperands)
2141090a830SDenis Khalikov return;
215bfb2ce02SDenis Khalikov OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216bfb2ce02SDenis Khalikov Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
2171090a830SDenis Khalikov
2181090a830SDenis Khalikov // Create LLVM constant for the descriptor set index.
219bfb2ce02SDenis Khalikov // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
2201090a830SDenis Khalikov // pass does.
2211090a830SDenis Khalikov Value descriptorSet = builder.create<LLVM::ConstantOp>(
2221090a830SDenis Khalikov loc, getInt32Type(), builder.getI32IntegerAttr(0));
2231090a830SDenis Khalikov
224e4853be2SMehdi Amini for (const auto &en :
225bfb2ce02SDenis Khalikov llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226a48f0a3cSDenis Khalikov kVulkanLaunchNumConfigOperands))) {
2271090a830SDenis Khalikov // Create LLVM constant for the descriptor binding index.
2281090a830SDenis Khalikov Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229bfb2ce02SDenis Khalikov loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
2308f4ab8c7SDenis Khalikov
2318f4ab8c7SDenis Khalikov auto ptrToMemRefDescriptor = en.value();
2328f4ab8c7SDenis Khalikov uint32_t rank = 0;
233c69c9e0fSAlex Zinenko Type type;
2341009177dSDenis Khalikov if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
2358f4ab8c7SDenis Khalikov cInterfaceVulkanLaunchCallOp.emitError()
2368f4ab8c7SDenis Khalikov << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
2378f4ab8c7SDenis Khalikov return signalPassFailure();
2388f4ab8c7SDenis Khalikov }
2398f4ab8c7SDenis Khalikov
2401009177dSDenis Khalikov auto symbolName =
2411009177dSDenis Khalikov llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
242afd43a7aSThomas Raoux // Special case for fp16 type. Since it is not a supported type in C we use
243afd43a7aSThomas Raoux // int16_t and bitcast the descriptor.
244dd5165a9SAlex Zinenko if (type.isa<Float16Type>()) {
2452230bf99SAlex Zinenko auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
246afd43a7aSThomas Raoux ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
2478de43b92SAlex Zinenko loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
248afd43a7aSThomas Raoux }
249bfb2ce02SDenis Khalikov // Create call to `bindMemRef`.
2501090a830SDenis Khalikov builder.create<LLVM::CallOp>(
251faf1c224SChris Lattner loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
25208e4f078SRahul Joshi ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
2538f4ab8c7SDenis Khalikov ptrToMemRefDescriptor});
2541090a830SDenis Khalikov }
2551090a830SDenis Khalikov }
2561090a830SDenis Khalikov
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,Type & type)2571009177dSDenis Khalikov LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
258c69c9e0fSAlex Zinenko Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
2598f4ab8c7SDenis Khalikov auto llvmPtrDescriptorTy =
2608de43b92SAlex Zinenko ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
2618f4ab8c7SDenis Khalikov if (!llvmPtrDescriptorTy)
2628f4ab8c7SDenis Khalikov return failure();
2638f4ab8c7SDenis Khalikov
2648de43b92SAlex Zinenko auto llvmDescriptorTy =
2658de43b92SAlex Zinenko llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
2668f4ab8c7SDenis Khalikov // template <typename Elem, size_t Rank>
2678f4ab8c7SDenis Khalikov // struct {
2688f4ab8c7SDenis Khalikov // Elem *allocated;
2698f4ab8c7SDenis Khalikov // Elem *aligned;
2708f4ab8c7SDenis Khalikov // int64_t offset;
2718f4ab8c7SDenis Khalikov // int64_t sizes[Rank]; // omitted when rank == 0
2728f4ab8c7SDenis Khalikov // int64_t strides[Rank]; // omitted when rank == 0
2738f4ab8c7SDenis Khalikov // };
2748de43b92SAlex Zinenko if (!llvmDescriptorTy)
2758f4ab8c7SDenis Khalikov return failure();
2761009177dSDenis Khalikov
2778de43b92SAlex Zinenko type = llvmDescriptorTy.getBody()[0]
2788de43b92SAlex Zinenko .cast<LLVM::LLVMPointerType>()
2798de43b92SAlex Zinenko .getElementType();
2808de43b92SAlex Zinenko if (llvmDescriptorTy.getBody().size() == 3) {
2818f4ab8c7SDenis Khalikov rank = 0;
2828f4ab8c7SDenis Khalikov return success();
2838f4ab8c7SDenis Khalikov }
2848de43b92SAlex Zinenko rank = llvmDescriptorTy.getBody()[3]
2858de43b92SAlex Zinenko .cast<LLVM::LLVMArrayType>()
2868de43b92SAlex Zinenko .getNumElements();
2878f4ab8c7SDenis Khalikov return success();
2888f4ab8c7SDenis Khalikov }
2898f4ab8c7SDenis Khalikov
declareVulkanFunctions(Location loc)2901090a830SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
291722f909fSRiver Riddle ModuleOp module = getOperation();
292973ddb7dSMehdi Amini auto builder = OpBuilder::atBlockEnd(module.getBody());
293a062a3edSDenis Khalikov
294a062a3edSDenis Khalikov if (!module.lookupSymbol(kSetEntryPoint)) {
295a062a3edSDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
296a062a3edSDenis Khalikov loc, kSetEntryPoint,
2977ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(getVoidType(),
2987ed9cfc7SAlex Zinenko {getPointerType(), getPointerType()}));
299a062a3edSDenis Khalikov }
300a062a3edSDenis Khalikov
301a062a3edSDenis Khalikov if (!module.lookupSymbol(kSetNumWorkGroups)) {
302a062a3edSDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
303a062a3edSDenis Khalikov loc, kSetNumWorkGroups,
3047ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(getVoidType(),
3057ed9cfc7SAlex Zinenko {getPointerType(), getInt64Type(),
3067ed9cfc7SAlex Zinenko getInt64Type(), getInt64Type()}));
307a062a3edSDenis Khalikov }
308a062a3edSDenis Khalikov
309a062a3edSDenis Khalikov if (!module.lookupSymbol(kSetBinaryShader)) {
310a062a3edSDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
311a062a3edSDenis Khalikov loc, kSetBinaryShader,
3127ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(
3137ed9cfc7SAlex Zinenko getVoidType(),
3147ed9cfc7SAlex Zinenko {getPointerType(), getPointerType(), getInt32Type()}));
315a062a3edSDenis Khalikov }
316a062a3edSDenis Khalikov
317a062a3edSDenis Khalikov if (!module.lookupSymbol(kRunOnVulkan)) {
318a062a3edSDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
319a062a3edSDenis Khalikov loc, kRunOnVulkan,
3207ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
3211090a830SDenis Khalikov }
3221090a830SDenis Khalikov
323afd43a7aSThomas Raoux for (unsigned i = 1; i <= 3; i++) {
324dd5165a9SAlex Zinenko SmallVector<Type, 5> types{
325dd5165a9SAlex Zinenko Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
326dd5165a9SAlex Zinenko IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
327dd5165a9SAlex Zinenko Float16Type::get(&getContext())};
3287ed9cfc7SAlex Zinenko for (auto type : types) {
329afd43a7aSThomas Raoux std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
330afd43a7aSThomas Raoux std::string(stringifyType(type));
331dd5165a9SAlex Zinenko if (type.isa<Float16Type>())
3322230bf99SAlex Zinenko type = IntegerType::get(&getContext(), 16);
333afd43a7aSThomas Raoux if (!module.lookupSymbol(fnName)) {
3347ed9cfc7SAlex Zinenko auto fnType = LLVM::LLVMFunctionType::get(
335afd43a7aSThomas Raoux getVoidType(),
336afd43a7aSThomas Raoux {getPointerType(), getInt32Type(), getInt32Type(),
3378de43b92SAlex Zinenko LLVM::LLVMPointerType::get(getMemRefType(i, type))},
338afd43a7aSThomas Raoux /*isVarArg=*/false);
339afd43a7aSThomas Raoux builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
3401090a830SDenis Khalikov }
341afd43a7aSThomas Raoux }
342afd43a7aSThomas Raoux }
3430718e3aeSDenis Khalikov
3441090a830SDenis Khalikov if (!module.lookupSymbol(kInitVulkan)) {
3451090a830SDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
3467ed9cfc7SAlex Zinenko loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
3471090a830SDenis Khalikov }
3481090a830SDenis Khalikov
3491090a830SDenis Khalikov if (!module.lookupSymbol(kDeinitVulkan)) {
3501090a830SDenis Khalikov builder.create<LLVM::LLVMFuncOp>(
3511090a830SDenis Khalikov loc, kDeinitVulkan,
3527ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
353a062a3edSDenis Khalikov }
354a062a3edSDenis Khalikov }
355a062a3edSDenis Khalikov
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)3561090a830SDenis Khalikov Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
357a062a3edSDenis Khalikov StringRef name, Location loc, OpBuilder &builder) {
358896ee361SDenis Khalikov SmallString<16> shaderName(name.begin(), name.end());
359a062a3edSDenis Khalikov // Append `\0` to follow C style string given that LLVM::createGlobalString()
360a062a3edSDenis Khalikov // won't handle this directly for us.
361a062a3edSDenis Khalikov shaderName.push_back('\0');
362a062a3edSDenis Khalikov
363896ee361SDenis Khalikov std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
364896ee361SDenis Khalikov return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
3655446ec85SAlex Zinenko shaderName, LLVM::Linkage::Internal);
366a062a3edSDenis Khalikov }
367a062a3edSDenis Khalikov
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)3681090a830SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
369bfb2ce02SDenis Khalikov LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
370bfb2ce02SDenis Khalikov OpBuilder builder(cInterfaceVulkanLaunchCallOp);
371bfb2ce02SDenis Khalikov Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
3721090a830SDenis Khalikov // Create call to `initVulkan`.
3731090a830SDenis Khalikov auto initVulkanCall = builder.create<LLVM::CallOp>(
374faf1c224SChris Lattner loc, TypeRange{getPointerType()}, kInitVulkan);
3751090a830SDenis Khalikov // The result of `initVulkan` function is a pointer to Vulkan runtime, we
3761090a830SDenis Khalikov // need to pass that pointer to each Vulkan runtime call.
3771090a830SDenis Khalikov auto vulkanRuntime = initVulkanCall.getResult(0);
378a062a3edSDenis Khalikov
379a062a3edSDenis Khalikov // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
380a062a3edSDenis Khalikov // that data to runtime call.
381a062a3edSDenis Khalikov Value ptrToSPIRVBinary = LLVM::createGlobalString(
382bfb2ce02SDenis Khalikov loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
3835446ec85SAlex Zinenko LLVM::Linkage::Internal);
3841090a830SDenis Khalikov
385a062a3edSDenis Khalikov // Create LLVM constant for the size of SPIR-V binary shader.
386a062a3edSDenis Khalikov Value binarySize = builder.create<LLVM::ConstantOp>(
3871090a830SDenis Khalikov loc, getInt32Type(),
388bfb2ce02SDenis Khalikov builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
3891090a830SDenis Khalikov
390bfb2ce02SDenis Khalikov // Create call to `bindMemRef` for each memref operand.
391bfb2ce02SDenis Khalikov createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
3921090a830SDenis Khalikov
393a062a3edSDenis Khalikov // Create call to `setBinaryShader` runtime function with the given pointer to
394a062a3edSDenis Khalikov // SPIR-V binary and binary size.
3951090a830SDenis Khalikov builder.create<LLVM::CallOp>(
396faf1c224SChris Lattner loc, TypeRange(), kSetBinaryShader,
39708e4f078SRahul Joshi ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
398a062a3edSDenis Khalikov // Create LLVM global with entry point name.
399bfb2ce02SDenis Khalikov Value entryPointName = createEntryPointNameConstant(
400bfb2ce02SDenis Khalikov spirvAttributes.second.getValue(), loc, builder);
401a062a3edSDenis Khalikov // Create call to `setEntryPoint` runtime function with the given pointer to
402a062a3edSDenis Khalikov // entry point name.
403faf1c224SChris Lattner builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
40408e4f078SRahul Joshi ValueRange{vulkanRuntime, entryPointName});
405a062a3edSDenis Khalikov
406a062a3edSDenis Khalikov // Create number of local workgroup for each dimension.
407a062a3edSDenis Khalikov builder.create<LLVM::CallOp>(
408faf1c224SChris Lattner loc, TypeRange(), kSetNumWorkGroups,
40908e4f078SRahul Joshi ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
410bfb2ce02SDenis Khalikov cInterfaceVulkanLaunchCallOp.getOperand(1),
411bfb2ce02SDenis Khalikov cInterfaceVulkanLaunchCallOp.getOperand(2)});
412a062a3edSDenis Khalikov
413a062a3edSDenis Khalikov // Create call to `runOnVulkan` runtime function.
414faf1c224SChris Lattner builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
41508e4f078SRahul Joshi ValueRange{vulkanRuntime});
4161090a830SDenis Khalikov
4171090a830SDenis Khalikov // Create call to 'deinitVulkan' runtime function.
418faf1c224SChris Lattner builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
41908e4f078SRahul Joshi ValueRange{vulkanRuntime});
420a062a3edSDenis Khalikov
421a062a3edSDenis Khalikov // Declare runtime functions.
422a062a3edSDenis Khalikov declareVulkanFunctions(loc);
423a062a3edSDenis Khalikov
424bfb2ce02SDenis Khalikov cInterfaceVulkanLaunchCallOp.erase();
425a062a3edSDenis Khalikov }
426a062a3edSDenis Khalikov
42780aca1eaSRiver Riddle std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()4281090a830SDenis Khalikov mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
4291090a830SDenis Khalikov return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
430a062a3edSDenis Khalikov }
431