//===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a Vulkan launch call into a sequence
// of Vulkan runtime calls. The Vulkan runtime API surface is huge, so
// currently we don't expose separate external functions in IR for each of
// them; instead we expose a few external functions to wrapper libraries which
// manage the Vulkan runtime.
//
//===----------------------------------------------------------------------===//
16a062a3edSDenis Khalikov 
171834ad4aSRiver Riddle #include "../PassDetail.h"
18a062a3edSDenis Khalikov #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19a062a3edSDenis Khalikov #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20a062a3edSDenis Khalikov #include "mlir/IR/Attributes.h"
21a062a3edSDenis Khalikov #include "mlir/IR/Builders.h"
2265fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
23a062a3edSDenis Khalikov 
24896ee361SDenis Khalikov #include "llvm/ADT/SmallString.h"
258f4ab8c7SDenis Khalikov #include "llvm/Support/FormatVariadic.h"
26a062a3edSDenis Khalikov 
27a062a3edSDenis Khalikov using namespace mlir;
28a062a3edSDenis Khalikov 
// Callee names of the launch calls this pass looks for.
static constexpr const char *kVulkanLaunch = "vulkanLaunch";
static constexpr const char *kCInterfaceVulkanLaunch =
    "_mlir_ciface_vulkanLaunch";

// Names of the runtime wrapper functions that are declared and called by the
// generated code.
static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kDeinitVulkan = "deinitVulkan";

// Name of the LLVM global created to hold the SPIR-V binary.
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";

// Names of the attributes read from the vulkan launch call op.
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
41a062a3edSDenis Khalikov 
42a062a3edSDenis Khalikov namespace {
43a062a3edSDenis Khalikov 
44bfb2ce02SDenis Khalikov /// A pass to convert vulkan launch call op into a sequence of Vulkan
451090a830SDenis Khalikov /// runtime calls in the following order:
46a062a3edSDenis Khalikov ///
471090a830SDenis Khalikov /// * initVulkan           -- initializes vulkan runtime
48bfb2ce02SDenis Khalikov /// * bindMemRef           -- binds memref
49a062a3edSDenis Khalikov /// * setBinaryShader      -- sets the binary shader data
50a062a3edSDenis Khalikov /// * setEntryPoint        -- sets the entry point name
51a062a3edSDenis Khalikov /// * setNumWorkGroups     -- sets the number of a local workgroups
52a062a3edSDenis Khalikov /// * runOnVulkan          -- runs vulkan runtime
531090a830SDenis Khalikov /// * deinitVulkan         -- deinitializes vulkan runtime
54a062a3edSDenis Khalikov ///
551090a830SDenis Khalikov class VulkanLaunchFuncToVulkanCallsPass
561834ad4aSRiver Riddle     : public ConvertVulkanLaunchFuncToVulkanCallsBase<
571834ad4aSRiver Riddle           VulkanLaunchFuncToVulkanCallsPass> {
58a062a3edSDenis Khalikov private:
initializeCachedTypes()59a062a3edSDenis Khalikov   void initializeCachedTypes() {
60dd5165a9SAlex Zinenko     llvmFloatType = Float32Type::get(&getContext());
617ed9cfc7SAlex Zinenko     llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
622230bf99SAlex Zinenko     llvmPointerType =
632230bf99SAlex Zinenko         LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
642230bf99SAlex Zinenko     llvmInt32Type = IntegerType::get(&getContext(), 32);
652230bf99SAlex Zinenko     llvmInt64Type = IntegerType::get(&getContext(), 64);
66bfb2ce02SDenis Khalikov   }
67bfb2ce02SDenis Khalikov 
getMemRefType(uint32_t rank,Type elemenType)68c69c9e0fSAlex Zinenko   Type getMemRefType(uint32_t rank, Type elemenType) {
69bfb2ce02SDenis Khalikov     // According to the MLIR doc memref argument is converted into a
70bfb2ce02SDenis Khalikov     // pointer-to-struct argument of type:
71bfb2ce02SDenis Khalikov     // template <typename Elem, size_t Rank>
72bfb2ce02SDenis Khalikov     // struct {
73bfb2ce02SDenis Khalikov     //   Elem *allocated;
74bfb2ce02SDenis Khalikov     //   Elem *aligned;
75bfb2ce02SDenis Khalikov     //   int64_t offset;
76bfb2ce02SDenis Khalikov     //   int64_t sizes[Rank]; // omitted when rank == 0
77bfb2ce02SDenis Khalikov     //   int64_t strides[Rank]; // omitted when rank == 0
78bfb2ce02SDenis Khalikov     // };
798de43b92SAlex Zinenko     auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
800718e3aeSDenis Khalikov     auto llvmArrayRankElementSizeType =
817ed9cfc7SAlex Zinenko         LLVM::LLVMArrayType::get(getInt64Type(), rank);
82bfb2ce02SDenis Khalikov 
830718e3aeSDenis Khalikov     // Create a type
841009177dSDenis Khalikov     // `!llvm<"{ `element-type`*, `element-type`*, i64,
851009177dSDenis Khalikov     // [`rank` x i64], [`rank` x i64]}">`.
867ed9cfc7SAlex Zinenko     return LLVM::LLVMStructType::getLiteral(
875446ec85SAlex Zinenko         &getContext(),
881009177dSDenis Khalikov         {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
890718e3aeSDenis Khalikov          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
90a062a3edSDenis Khalikov   }
91a062a3edSDenis Khalikov 
getVoidType()92c69c9e0fSAlex Zinenko   Type getVoidType() { return llvmVoidType; }
getPointerType()93c69c9e0fSAlex Zinenko   Type getPointerType() { return llvmPointerType; }
getInt32Type()94c69c9e0fSAlex Zinenko   Type getInt32Type() { return llvmInt32Type; }
getInt64Type()95c69c9e0fSAlex Zinenko   Type getInt64Type() { return llvmInt64Type; }
96a062a3edSDenis Khalikov 
9706b90586SKazuaki Ishizaki   /// Creates an LLVM global for the given `name`.
98a062a3edSDenis Khalikov   Value createEntryPointNameConstant(StringRef name, Location loc,
99a062a3edSDenis Khalikov                                      OpBuilder &builder);
100a062a3edSDenis Khalikov 
101a062a3edSDenis Khalikov   /// Declares all needed runtime functions.
102a062a3edSDenis Khalikov   void declareVulkanFunctions(Location loc);
103a062a3edSDenis Khalikov 
1041090a830SDenis Khalikov   /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
isVulkanLaunchCallOp(LLVM::CallOp callOp)1051090a830SDenis Khalikov   bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
106*6d5fc1e3SKazu Hirata     return (callOp.getCallee() && *callOp.getCallee() == kVulkanLaunch &&
107a48f0a3cSDenis Khalikov             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
108bfb2ce02SDenis Khalikov   }
109bfb2ce02SDenis Khalikov 
110bfb2ce02SDenis Khalikov   /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call
111bfb2ce02SDenis Khalikov   /// op.
isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp)112bfb2ce02SDenis Khalikov   bool isCInterfaceVulkanLaunchCallOp(LLVM::CallOp callOp) {
113cfb72fd3SJacques Pienaar     return (callOp.getCallee() &&
114*6d5fc1e3SKazu Hirata             *callOp.getCallee() == kCInterfaceVulkanLaunch &&
115a48f0a3cSDenis Khalikov             callOp.getNumOperands() >= kVulkanLaunchNumConfigOperands);
1161090a830SDenis Khalikov   }
1171090a830SDenis Khalikov 
1181090a830SDenis Khalikov   /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
1191090a830SDenis Khalikov   /// runtime calls.
1201090a830SDenis Khalikov   void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
1211090a830SDenis Khalikov 
122bfb2ce02SDenis Khalikov   /// Creates call to `bindMemRef` for each memref operand.
123bfb2ce02SDenis Khalikov   void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp,
124bfb2ce02SDenis Khalikov                              Value vulkanRuntime);
125bfb2ce02SDenis Khalikov 
126bfb2ce02SDenis Khalikov   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
127bfb2ce02SDenis Khalikov   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
128a062a3edSDenis Khalikov 
1291009177dSDenis Khalikov   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
1301009177dSDenis Khalikov   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
131c69c9e0fSAlex Zinenko                                         uint32_t &rank, Type &type);
1321009177dSDenis Khalikov 
1331009177dSDenis Khalikov   /// Returns a string representation from the given `type`.
stringifyType(Type type)134c69c9e0fSAlex Zinenko   StringRef stringifyType(Type type) {
135dd5165a9SAlex Zinenko     if (type.isa<Float32Type>())
1361009177dSDenis Khalikov       return "Float";
137dd5165a9SAlex Zinenko     if (type.isa<Float16Type>())
138afd43a7aSThomas Raoux       return "Half";
1392230bf99SAlex Zinenko     if (auto intType = type.dyn_cast<IntegerType>()) {
1402230bf99SAlex Zinenko       if (intType.getWidth() == 32)
141afd43a7aSThomas Raoux         return "Int32";
1422230bf99SAlex Zinenko       if (intType.getWidth() == 16)
143afd43a7aSThomas Raoux         return "Int16";
1442230bf99SAlex Zinenko       if (intType.getWidth() == 8)
145afd43a7aSThomas Raoux         return "Int8";
1468de43b92SAlex Zinenko     }
1471009177dSDenis Khalikov 
1481009177dSDenis Khalikov     llvm_unreachable("unsupported type");
1491009177dSDenis Khalikov   }
1508f4ab8c7SDenis Khalikov 
151a062a3edSDenis Khalikov public:
152722f909fSRiver Riddle   void runOnOperation() override;
153a062a3edSDenis Khalikov 
154a062a3edSDenis Khalikov private:
155c69c9e0fSAlex Zinenko   Type llvmFloatType;
156c69c9e0fSAlex Zinenko   Type llvmVoidType;
157c69c9e0fSAlex Zinenko   Type llvmPointerType;
158c69c9e0fSAlex Zinenko   Type llvmInt32Type;
159c69c9e0fSAlex Zinenko   Type llvmInt64Type;
1601090a830SDenis Khalikov 
161bfb2ce02SDenis Khalikov   // TODO: Use an associative array to support multiple vulkan launch calls.
162bfb2ce02SDenis Khalikov   std::pair<StringAttr, StringAttr> spirvAttributes;
163a48f0a3cSDenis Khalikov   /// The number of vulkan launch configuration operands, placed at the leading
164a48f0a3cSDenis Khalikov   /// positions of the operand list.
165a48f0a3cSDenis Khalikov   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
166a062a3edSDenis Khalikov };
167a062a3edSDenis Khalikov 
168be0a7e9fSMehdi Amini } // namespace
169a062a3edSDenis Khalikov 
runOnOperation()170722f909fSRiver Riddle void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
171a062a3edSDenis Khalikov   initializeCachedTypes();
172bfb2ce02SDenis Khalikov 
173bfb2ce02SDenis Khalikov   // Collect SPIR-V attributes such as `spirv_blob` and
174bfb2ce02SDenis Khalikov   // `spirv_entry_point_name`.
175722f909fSRiver Riddle   getOperation().walk([this](LLVM::CallOp op) {
1761090a830SDenis Khalikov     if (isVulkanLaunchCallOp(op))
177bfb2ce02SDenis Khalikov       collectSPIRVAttributes(op);
178bfb2ce02SDenis Khalikov   });
179bfb2ce02SDenis Khalikov 
180bfb2ce02SDenis Khalikov   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
181722f909fSRiver Riddle   getOperation().walk([this](LLVM::CallOp op) {
182bfb2ce02SDenis Khalikov     if (isCInterfaceVulkanLaunchCallOp(op))
1831090a830SDenis Khalikov       translateVulkanLaunchCall(op);
1841090a830SDenis Khalikov   });
185a062a3edSDenis Khalikov }
186a062a3edSDenis Khalikov 
collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp)187bfb2ce02SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
188bfb2ce02SDenis Khalikov     LLVM::CallOp vulkanLaunchCallOp) {
189bfb2ce02SDenis Khalikov   // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes
190bfb2ce02SDenis Khalikov   // for the given vulkan launch call.
191bfb2ce02SDenis Khalikov   auto spirvBlobAttr =
1920bf4a82aSChristian Sigg       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
193bfb2ce02SDenis Khalikov   if (!spirvBlobAttr) {
194bfb2ce02SDenis Khalikov     vulkanLaunchCallOp.emitError()
195bfb2ce02SDenis Khalikov         << "missing " << kSPIRVBlobAttrName << " attribute";
196bfb2ce02SDenis Khalikov     return signalPassFailure();
197bfb2ce02SDenis Khalikov   }
198bfb2ce02SDenis Khalikov 
199bfb2ce02SDenis Khalikov   auto spirvEntryPointNameAttr =
2000bf4a82aSChristian Sigg       vulkanLaunchCallOp->getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
201bfb2ce02SDenis Khalikov   if (!spirvEntryPointNameAttr) {
202bfb2ce02SDenis Khalikov     vulkanLaunchCallOp.emitError()
203bfb2ce02SDenis Khalikov         << "missing " << kSPIRVEntryPointAttrName << " attribute";
204bfb2ce02SDenis Khalikov     return signalPassFailure();
205bfb2ce02SDenis Khalikov   }
206bfb2ce02SDenis Khalikov 
207bfb2ce02SDenis Khalikov   spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
208bfb2ce02SDenis Khalikov }
209bfb2ce02SDenis Khalikov 
createBindMemRefCalls(LLVM::CallOp cInterfaceVulkanLaunchCallOp,Value vulkanRuntime)210bfb2ce02SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
211bfb2ce02SDenis Khalikov     LLVM::CallOp cInterfaceVulkanLaunchCallOp, Value vulkanRuntime) {
212bfb2ce02SDenis Khalikov   if (cInterfaceVulkanLaunchCallOp.getNumOperands() ==
213a48f0a3cSDenis Khalikov       kVulkanLaunchNumConfigOperands)
2141090a830SDenis Khalikov     return;
215bfb2ce02SDenis Khalikov   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
216bfb2ce02SDenis Khalikov   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
2171090a830SDenis Khalikov 
2181090a830SDenis Khalikov   // Create LLVM constant for the descriptor set index.
219bfb2ce02SDenis Khalikov   // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
2201090a830SDenis Khalikov   // pass does.
2211090a830SDenis Khalikov   Value descriptorSet = builder.create<LLVM::ConstantOp>(
2221090a830SDenis Khalikov       loc, getInt32Type(), builder.getI32IntegerAttr(0));
2231090a830SDenis Khalikov 
224e4853be2SMehdi Amini   for (const auto &en :
225bfb2ce02SDenis Khalikov        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
226a48f0a3cSDenis Khalikov            kVulkanLaunchNumConfigOperands))) {
2271090a830SDenis Khalikov     // Create LLVM constant for the descriptor binding index.
2281090a830SDenis Khalikov     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
229bfb2ce02SDenis Khalikov         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
2308f4ab8c7SDenis Khalikov 
2318f4ab8c7SDenis Khalikov     auto ptrToMemRefDescriptor = en.value();
2328f4ab8c7SDenis Khalikov     uint32_t rank = 0;
233c69c9e0fSAlex Zinenko     Type type;
2341009177dSDenis Khalikov     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
2358f4ab8c7SDenis Khalikov       cInterfaceVulkanLaunchCallOp.emitError()
2368f4ab8c7SDenis Khalikov           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
2378f4ab8c7SDenis Khalikov       return signalPassFailure();
2388f4ab8c7SDenis Khalikov     }
2398f4ab8c7SDenis Khalikov 
2401009177dSDenis Khalikov     auto symbolName =
2411009177dSDenis Khalikov         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
242afd43a7aSThomas Raoux     // Special case for fp16 type. Since it is not a supported type in C we use
243afd43a7aSThomas Raoux     // int16_t and bitcast the descriptor.
244dd5165a9SAlex Zinenko     if (type.isa<Float16Type>()) {
2452230bf99SAlex Zinenko       auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
246afd43a7aSThomas Raoux       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
2478de43b92SAlex Zinenko           loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
248afd43a7aSThomas Raoux     }
249bfb2ce02SDenis Khalikov     // Create call to `bindMemRef`.
2501090a830SDenis Khalikov     builder.create<LLVM::CallOp>(
251faf1c224SChris Lattner         loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
25208e4f078SRahul Joshi         ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
2538f4ab8c7SDenis Khalikov                    ptrToMemRefDescriptor});
2541090a830SDenis Khalikov   }
2551090a830SDenis Khalikov }
2561090a830SDenis Khalikov 
deduceMemRefRankAndType(Value ptrToMemRefDescriptor,uint32_t & rank,Type & type)2571009177dSDenis Khalikov LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
258c69c9e0fSAlex Zinenko     Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
2598f4ab8c7SDenis Khalikov   auto llvmPtrDescriptorTy =
2608de43b92SAlex Zinenko       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
2618f4ab8c7SDenis Khalikov   if (!llvmPtrDescriptorTy)
2628f4ab8c7SDenis Khalikov     return failure();
2638f4ab8c7SDenis Khalikov 
2648de43b92SAlex Zinenko   auto llvmDescriptorTy =
2658de43b92SAlex Zinenko       llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
2668f4ab8c7SDenis Khalikov   // template <typename Elem, size_t Rank>
2678f4ab8c7SDenis Khalikov   // struct {
2688f4ab8c7SDenis Khalikov   //   Elem *allocated;
2698f4ab8c7SDenis Khalikov   //   Elem *aligned;
2708f4ab8c7SDenis Khalikov   //   int64_t offset;
2718f4ab8c7SDenis Khalikov   //   int64_t sizes[Rank]; // omitted when rank == 0
2728f4ab8c7SDenis Khalikov   //   int64_t strides[Rank]; // omitted when rank == 0
2738f4ab8c7SDenis Khalikov   // };
2748de43b92SAlex Zinenko   if (!llvmDescriptorTy)
2758f4ab8c7SDenis Khalikov     return failure();
2761009177dSDenis Khalikov 
2778de43b92SAlex Zinenko   type = llvmDescriptorTy.getBody()[0]
2788de43b92SAlex Zinenko              .cast<LLVM::LLVMPointerType>()
2798de43b92SAlex Zinenko              .getElementType();
2808de43b92SAlex Zinenko   if (llvmDescriptorTy.getBody().size() == 3) {
2818f4ab8c7SDenis Khalikov     rank = 0;
2828f4ab8c7SDenis Khalikov     return success();
2838f4ab8c7SDenis Khalikov   }
2848de43b92SAlex Zinenko   rank = llvmDescriptorTy.getBody()[3]
2858de43b92SAlex Zinenko              .cast<LLVM::LLVMArrayType>()
2868de43b92SAlex Zinenko              .getNumElements();
2878f4ab8c7SDenis Khalikov   return success();
2888f4ab8c7SDenis Khalikov }
2898f4ab8c7SDenis Khalikov 
declareVulkanFunctions(Location loc)2901090a830SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
291722f909fSRiver Riddle   ModuleOp module = getOperation();
292973ddb7dSMehdi Amini   auto builder = OpBuilder::atBlockEnd(module.getBody());
293a062a3edSDenis Khalikov 
294a062a3edSDenis Khalikov   if (!module.lookupSymbol(kSetEntryPoint)) {
295a062a3edSDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
296a062a3edSDenis Khalikov         loc, kSetEntryPoint,
2977ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(getVoidType(),
2987ed9cfc7SAlex Zinenko                                     {getPointerType(), getPointerType()}));
299a062a3edSDenis Khalikov   }
300a062a3edSDenis Khalikov 
301a062a3edSDenis Khalikov   if (!module.lookupSymbol(kSetNumWorkGroups)) {
302a062a3edSDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
303a062a3edSDenis Khalikov         loc, kSetNumWorkGroups,
3047ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(getVoidType(),
3057ed9cfc7SAlex Zinenko                                     {getPointerType(), getInt64Type(),
3067ed9cfc7SAlex Zinenko                                      getInt64Type(), getInt64Type()}));
307a062a3edSDenis Khalikov   }
308a062a3edSDenis Khalikov 
309a062a3edSDenis Khalikov   if (!module.lookupSymbol(kSetBinaryShader)) {
310a062a3edSDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
311a062a3edSDenis Khalikov         loc, kSetBinaryShader,
3127ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(
3137ed9cfc7SAlex Zinenko             getVoidType(),
3147ed9cfc7SAlex Zinenko             {getPointerType(), getPointerType(), getInt32Type()}));
315a062a3edSDenis Khalikov   }
316a062a3edSDenis Khalikov 
317a062a3edSDenis Khalikov   if (!module.lookupSymbol(kRunOnVulkan)) {
318a062a3edSDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
319a062a3edSDenis Khalikov         loc, kRunOnVulkan,
3207ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
3211090a830SDenis Khalikov   }
3221090a830SDenis Khalikov 
323afd43a7aSThomas Raoux   for (unsigned i = 1; i <= 3; i++) {
324dd5165a9SAlex Zinenko     SmallVector<Type, 5> types{
325dd5165a9SAlex Zinenko         Float32Type::get(&getContext()), IntegerType::get(&getContext(), 32),
326dd5165a9SAlex Zinenko         IntegerType::get(&getContext(), 16), IntegerType::get(&getContext(), 8),
327dd5165a9SAlex Zinenko         Float16Type::get(&getContext())};
3287ed9cfc7SAlex Zinenko     for (auto type : types) {
329afd43a7aSThomas Raoux       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
330afd43a7aSThomas Raoux                            std::string(stringifyType(type));
331dd5165a9SAlex Zinenko       if (type.isa<Float16Type>())
3322230bf99SAlex Zinenko         type = IntegerType::get(&getContext(), 16);
333afd43a7aSThomas Raoux       if (!module.lookupSymbol(fnName)) {
3347ed9cfc7SAlex Zinenko         auto fnType = LLVM::LLVMFunctionType::get(
335afd43a7aSThomas Raoux             getVoidType(),
336afd43a7aSThomas Raoux             {getPointerType(), getInt32Type(), getInt32Type(),
3378de43b92SAlex Zinenko              LLVM::LLVMPointerType::get(getMemRefType(i, type))},
338afd43a7aSThomas Raoux             /*isVarArg=*/false);
339afd43a7aSThomas Raoux         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
3401090a830SDenis Khalikov       }
341afd43a7aSThomas Raoux     }
342afd43a7aSThomas Raoux   }
3430718e3aeSDenis Khalikov 
3441090a830SDenis Khalikov   if (!module.lookupSymbol(kInitVulkan)) {
3451090a830SDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
3467ed9cfc7SAlex Zinenko         loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {}));
3471090a830SDenis Khalikov   }
3481090a830SDenis Khalikov 
3491090a830SDenis Khalikov   if (!module.lookupSymbol(kDeinitVulkan)) {
3501090a830SDenis Khalikov     builder.create<LLVM::LLVMFuncOp>(
3511090a830SDenis Khalikov         loc, kDeinitVulkan,
3527ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()}));
353a062a3edSDenis Khalikov   }
354a062a3edSDenis Khalikov }
355a062a3edSDenis Khalikov 
createEntryPointNameConstant(StringRef name,Location loc,OpBuilder & builder)3561090a830SDenis Khalikov Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
357a062a3edSDenis Khalikov     StringRef name, Location loc, OpBuilder &builder) {
358896ee361SDenis Khalikov   SmallString<16> shaderName(name.begin(), name.end());
359a062a3edSDenis Khalikov   // Append `\0` to follow C style string given that LLVM::createGlobalString()
360a062a3edSDenis Khalikov   // won't handle this directly for us.
361a062a3edSDenis Khalikov   shaderName.push_back('\0');
362a062a3edSDenis Khalikov 
363896ee361SDenis Khalikov   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
364896ee361SDenis Khalikov   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
3655446ec85SAlex Zinenko                                   shaderName, LLVM::Linkage::Internal);
366a062a3edSDenis Khalikov }
367a062a3edSDenis Khalikov 
translateVulkanLaunchCall(LLVM::CallOp cInterfaceVulkanLaunchCallOp)3681090a830SDenis Khalikov void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
369bfb2ce02SDenis Khalikov     LLVM::CallOp cInterfaceVulkanLaunchCallOp) {
370bfb2ce02SDenis Khalikov   OpBuilder builder(cInterfaceVulkanLaunchCallOp);
371bfb2ce02SDenis Khalikov   Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
3721090a830SDenis Khalikov   // Create call to `initVulkan`.
3731090a830SDenis Khalikov   auto initVulkanCall = builder.create<LLVM::CallOp>(
374faf1c224SChris Lattner       loc, TypeRange{getPointerType()}, kInitVulkan);
3751090a830SDenis Khalikov   // The result of `initVulkan` function is a pointer to Vulkan runtime, we
3761090a830SDenis Khalikov   // need to pass that pointer to each Vulkan runtime call.
3771090a830SDenis Khalikov   auto vulkanRuntime = initVulkanCall.getResult(0);
378a062a3edSDenis Khalikov 
379a062a3edSDenis Khalikov   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
380a062a3edSDenis Khalikov   // that data to runtime call.
381a062a3edSDenis Khalikov   Value ptrToSPIRVBinary = LLVM::createGlobalString(
382bfb2ce02SDenis Khalikov       loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
3835446ec85SAlex Zinenko       LLVM::Linkage::Internal);
3841090a830SDenis Khalikov 
385a062a3edSDenis Khalikov   // Create LLVM constant for the size of SPIR-V binary shader.
386a062a3edSDenis Khalikov   Value binarySize = builder.create<LLVM::ConstantOp>(
3871090a830SDenis Khalikov       loc, getInt32Type(),
388bfb2ce02SDenis Khalikov       builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
3891090a830SDenis Khalikov 
390bfb2ce02SDenis Khalikov   // Create call to `bindMemRef` for each memref operand.
391bfb2ce02SDenis Khalikov   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
3921090a830SDenis Khalikov 
393a062a3edSDenis Khalikov   // Create call to `setBinaryShader` runtime function with the given pointer to
394a062a3edSDenis Khalikov   // SPIR-V binary and binary size.
3951090a830SDenis Khalikov   builder.create<LLVM::CallOp>(
396faf1c224SChris Lattner       loc, TypeRange(), kSetBinaryShader,
39708e4f078SRahul Joshi       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
398a062a3edSDenis Khalikov   // Create LLVM global with entry point name.
399bfb2ce02SDenis Khalikov   Value entryPointName = createEntryPointNameConstant(
400bfb2ce02SDenis Khalikov       spirvAttributes.second.getValue(), loc, builder);
401a062a3edSDenis Khalikov   // Create call to `setEntryPoint` runtime function with the given pointer to
402a062a3edSDenis Khalikov   // entry point name.
403faf1c224SChris Lattner   builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
40408e4f078SRahul Joshi                                ValueRange{vulkanRuntime, entryPointName});
405a062a3edSDenis Khalikov 
406a062a3edSDenis Khalikov   // Create number of local workgroup for each dimension.
407a062a3edSDenis Khalikov   builder.create<LLVM::CallOp>(
408faf1c224SChris Lattner       loc, TypeRange(), kSetNumWorkGroups,
40908e4f078SRahul Joshi       ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
410bfb2ce02SDenis Khalikov                  cInterfaceVulkanLaunchCallOp.getOperand(1),
411bfb2ce02SDenis Khalikov                  cInterfaceVulkanLaunchCallOp.getOperand(2)});
412a062a3edSDenis Khalikov 
413a062a3edSDenis Khalikov   // Create call to `runOnVulkan` runtime function.
414faf1c224SChris Lattner   builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
41508e4f078SRahul Joshi                                ValueRange{vulkanRuntime});
4161090a830SDenis Khalikov 
4171090a830SDenis Khalikov   // Create call to 'deinitVulkan' runtime function.
418faf1c224SChris Lattner   builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
41908e4f078SRahul Joshi                                ValueRange{vulkanRuntime});
420a062a3edSDenis Khalikov 
421a062a3edSDenis Khalikov   // Declare runtime functions.
422a062a3edSDenis Khalikov   declareVulkanFunctions(loc);
423a062a3edSDenis Khalikov 
424bfb2ce02SDenis Khalikov   cInterfaceVulkanLaunchCallOp.erase();
425a062a3edSDenis Khalikov }
426a062a3edSDenis Khalikov 
42780aca1eaSRiver Riddle std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertVulkanLaunchFuncToVulkanCallsPass()4281090a830SDenis Khalikov mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
4291090a830SDenis Khalikov   return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
430a062a3edSDenis Khalikov }
431