1 //===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements C runtime wrappers around the VulkanRuntime. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <iostream> 14 #include <mutex> 15 #include <numeric> 16 17 #include "VulkanRuntime.h" 18 19 // Explicitly export entry points to the vulkan-runtime-wrapper. 20 21 #ifdef _WIN32 22 #define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport) 23 #else 24 #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default"))) 25 #endif // _WIN32 26 27 namespace { 28 29 class VulkanRuntimeManager { 30 public: 31 VulkanRuntimeManager() = default; 32 VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; 33 VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; 34 ~VulkanRuntimeManager() = default; 35 36 void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, 37 const VulkanHostMemoryBuffer &memBuffer) { 38 std::lock_guard<std::mutex> lock(mutex); 39 vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer); 40 } 41 42 void setEntryPoint(const char *entryPoint) { 43 std::lock_guard<std::mutex> lock(mutex); 44 vulkanRuntime.setEntryPoint(entryPoint); 45 } 46 47 void setNumWorkGroups(NumWorkGroups numWorkGroups) { 48 std::lock_guard<std::mutex> lock(mutex); 49 vulkanRuntime.setNumWorkGroups(numWorkGroups); 50 } 51 52 void setShaderModule(uint8_t *shader, uint32_t size) { 53 std::lock_guard<std::mutex> lock(mutex); 54 vulkanRuntime.setShaderModule(shader, size); 55 } 56 57 void runOnVulkan() { 58 std::lock_guard<std::mutex> lock(mutex); 59 if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) || 60 failed(vulkanRuntime.updateHostMemoryBuffers()) || 61 failed(vulkanRuntime.destroy())) { 62 std::cerr << "runOnVulkan failed"; 63 } 64 } 65 66 private: 67 VulkanRuntime vulkanRuntime; 68 std::mutex mutex; 69 }; 70 71 } // namespace 72 73 template <typename T, int N> 74 struct MemRefDescriptor { 75 T *allocated; 76 T *aligned; 77 int64_t offset; 78 int64_t sizes[N]; 79 int64_t strides[N]; 80 }; 81 82 template <typename T, uint32_t S> 83 void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex, 84 BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) { 85 uint32_t size = sizeof(T); 86 for (unsigned i = 0; i < S; i++) 87 size *= ptr->sizes[i]; 88 VulkanHostMemoryBuffer memBuffer{ptr->allocated, size}; 89 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) 90 ->setResourceData(setIndex, bindIndex, memBuffer); 91 } 92 93 extern "C" { 94 /// Initializes `VulkanRuntimeManager` and returns a pointer to it. 95 VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() { 96 return new VulkanRuntimeManager(); 97 } 98 99 /// Deinitializes `VulkanRuntimeManager` by the given pointer. 100 VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) { 101 delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager); 102 } 103 104 VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) { 105 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan(); 106 } 107 108 VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager, 109 const char *entryPoint) { 110 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) 111 ->setEntryPoint(entryPoint); 112 } 113 114 VULKAN_WRAPPER_SYMBOL_EXPORT void 115 setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) { 116 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) 117 ->setNumWorkGroups({x, y, z}); 118 } 119 120 VULKAN_WRAPPER_SYMBOL_EXPORT void 121 setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) { 122 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) 123 ->setShaderModule(shader, size); 124 } 125 126 /// Binds the given memref to the given descriptor set and descriptor 127 /// index. 128 #define DECLARE_BIND_MEMREF(size, type, typeName) \ 129 VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \ 130 void *vkRuntimeManager, DescriptorSetIndex setIndex, \ 131 BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \ 132 bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \ 133 } 134 135 DECLARE_BIND_MEMREF(1, float, Float) 136 DECLARE_BIND_MEMREF(2, float, Float) 137 DECLARE_BIND_MEMREF(3, float, Float) 138 DECLARE_BIND_MEMREF(1, int32_t, Int32) 139 DECLARE_BIND_MEMREF(2, int32_t, Int32) 140 DECLARE_BIND_MEMREF(3, int32_t, Int32) 141 DECLARE_BIND_MEMREF(1, int16_t, Int16) 142 DECLARE_BIND_MEMREF(2, int16_t, Int16) 143 DECLARE_BIND_MEMREF(3, int16_t, Int16) 144 DECLARE_BIND_MEMREF(1, int8_t, Int8) 145 DECLARE_BIND_MEMREF(2, int8_t, Int8) 146 DECLARE_BIND_MEMREF(3, int8_t, Int8) 147 DECLARE_BIND_MEMREF(1, int16_t, Half) 148 DECLARE_BIND_MEMREF(2, int16_t, Half) 149 DECLARE_BIND_MEMREF(3, int16_t, Half) 150 151 /// Fills the given 1D float memref with the given float value. 152 VULKAN_WRAPPER_SYMBOL_EXPORT void 153 _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT 154 float value) { 155 std::fill_n(ptr->allocated, ptr->sizes[0], value); 156 } 157 158 /// Fills the given 2D float memref with the given float value. 159 VULKAN_WRAPPER_SYMBOL_EXPORT void 160 _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT 161 float value) { 162 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 163 } 164 165 /// Fills the given 3D float memref with the given float value. 166 VULKAN_WRAPPER_SYMBOL_EXPORT void 167 _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT 168 float value) { 169 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 170 value); 171 } 172 173 /// Fills the given 1D int memref with the given int value. 174 VULKAN_WRAPPER_SYMBOL_EXPORT void 175 _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT 176 int32_t value) { 177 std::fill_n(ptr->allocated, ptr->sizes[0], value); 178 } 179 180 /// Fills the given 2D int memref with the given int value. 181 VULKAN_WRAPPER_SYMBOL_EXPORT void 182 _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT 183 int32_t value) { 184 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 185 } 186 187 /// Fills the given 3D int memref with the given int value. 188 VULKAN_WRAPPER_SYMBOL_EXPORT void 189 _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT 190 int32_t value) { 191 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 192 value); 193 } 194 195 /// Fills the given 1D int memref with the given int8 value. 196 VULKAN_WRAPPER_SYMBOL_EXPORT void 197 _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT 198 int8_t value) { 199 std::fill_n(ptr->allocated, ptr->sizes[0], value); 200 } 201 202 /// Fills the given 2D int memref with the given int8 value. 203 VULKAN_WRAPPER_SYMBOL_EXPORT void 204 _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT 205 int8_t value) { 206 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 207 } 208 209 /// Fills the given 3D int memref with the given int8 value. 210 VULKAN_WRAPPER_SYMBOL_EXPORT void 211 _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT 212 int8_t value) { 213 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 214 value); 215 } 216 } 217