1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a library for running a module on a Vulkan device. 10 // Implements a Vulkan runtime. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VulkanRuntime.h" 15 16 #include <chrono> 17 #include <cstring> 18 // TODO: It's generally bad to access stdout/stderr in a library. 19 // Figure out a better way for error reporting. 20 #include <iomanip> 21 #include <iostream> 22 23 inline void emitVulkanError(const char *api, VkResult error) { 24 std::cerr << " failed with error code " << error << " when executing " << api; 25 } 26 27 #define RETURN_ON_VULKAN_ERROR(result, api) \ 28 if ((result) != VK_SUCCESS) { \ 29 emitVulkanError(api, (result)); \ 30 return failure(); \ 31 } 32 33 using namespace mlir; 34 35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) { 36 numWorkGroups = numberWorkGroups; 37 } 38 39 void VulkanRuntime::setResourceStorageClassBindingMap( 40 const ResourceStorageClassBindingMap &stClassData) { 41 resourceStorageClassData = stClassData; 42 } 43 44 void VulkanRuntime::setResourceData( 45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex, 46 const VulkanHostMemoryBuffer &hostMemBuffer) { 47 resourceData[desIndex][bindIndex] = hostMemBuffer; 48 resourceStorageClassData[desIndex][bindIndex] = 49 SPIRVStorageClass::StorageBuffer; 50 } 51 52 void VulkanRuntime::setEntryPoint(const char *entryPointName) { 53 entryPoint = entryPointName; 54 } 55 56 void VulkanRuntime::setResourceData(const ResourceData &resData) { 57 resourceData = resData; 58 } 59 60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) { 61 binary = shader; 62 binarySize = size; 63 } 64 65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType( 66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) { 67 switch (storageClass) { 68 case SPIRVStorageClass::StorageBuffer: 69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; 70 break; 71 case SPIRVStorageClass::Uniform: 72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; 73 break; 74 } 75 return success(); 76 } 77 78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag( 79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) { 80 switch (storageClass) { 81 case SPIRVStorageClass::StorageBuffer: 82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; 83 break; 84 case SPIRVStorageClass::Uniform: 85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; 86 break; 87 } 88 return success(); 89 } 90 91 LogicalResult VulkanRuntime::countDeviceMemorySize() { 92 for (const auto &resourceDataMapPair : resourceData) { 93 const auto &resourceDataMap = resourceDataMapPair.second; 94 for (const auto &resourceDataBindingPair : resourceDataMap) { 95 if (resourceDataBindingPair.second.size) { 96 memorySize += resourceDataBindingPair.second.size; 97 } else { 98 std::cerr << "expected buffer size greater than zero for resource data"; 99 return failure(); 100 } 101 } 102 } 103 return success(); 104 } 105 106 LogicalResult VulkanRuntime::initRuntime() { 107 if (resourceData.empty()) { 108 std::cerr << "Vulkan runtime needs at least one resource"; 109 return failure(); 110 } 111 if (!binarySize || !binary) { 112 std::cerr << "binary shader size must be greater than zero"; 113 return failure(); 114 } 115 if (failed(countDeviceMemorySize())) { 116 return failure(); 117 } 118 return success(); 119 } 120 121 LogicalResult VulkanRuntime::destroy() { 122 // According to Vulkan spec: 123 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be 124 // used to gate the destruction of the device. Prior to destroying a device, 125 // an application is responsible for destroying/freeing any Vulkan objects 126 // that were created using that device as the first parameter of the 127 // corresponding vkCreate* or vkAllocate* command." 128 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle"); 129 130 // Free and destroy. 131 vkFreeCommandBuffers(device, commandPool, commandBuffers.size(), 132 commandBuffers.data()); 133 vkDestroyQueryPool(device, queryPool, nullptr); 134 vkDestroyCommandPool(device, commandPool, nullptr); 135 vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(), 136 descriptorSets.data()); 137 vkDestroyDescriptorPool(device, descriptorPool, nullptr); 138 vkDestroyPipeline(device, pipeline, nullptr); 139 vkDestroyPipelineLayout(device, pipelineLayout, nullptr); 140 for (auto &descriptorSetLayout : descriptorSetLayouts) { 141 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); 142 } 143 vkDestroyShaderModule(device, shaderModule, nullptr); 144 145 // For each descriptor set. 146 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 147 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 148 // For each descriptor binding. 149 for (auto &memoryBuffer : deviceMemoryBuffers) { 150 vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr); 151 vkFreeMemory(device, memoryBuffer.hostMemory, nullptr); 152 vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr); 153 vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr); 154 } 155 } 156 157 vkDestroyDevice(device, nullptr); 158 vkDestroyInstance(instance, nullptr); 159 return success(); 160 } 161 162 LogicalResult VulkanRuntime::run() { 163 // Create logical device, shader module and memory buffers. 164 if (failed(createInstance()) || failed(createDevice()) || 165 failed(createMemoryBuffers()) || failed(createShaderModule())) { 166 return failure(); 167 } 168 169 // Descriptor bindings divided into sets. Each descriptor binding 170 // must have a layout binding attached into a descriptor set layout. 171 // Each layout set must be binded into a pipeline layout. 172 initDescriptorSetLayoutBindingMap(); 173 if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) || 174 // Each descriptor set must be allocated from a descriptor pool. 175 failed(createComputePipeline()) || failed(createDescriptorPool()) || 176 failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) || 177 // Create command buffer. 178 failed(createCommandPool()) || failed(createQueryPool()) || 179 failed(createComputeCommandBuffer())) { 180 return failure(); 181 } 182 183 // Get working queue. 184 vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue); 185 186 if (failed(copyResource(/*deviceToHost=*/false))) 187 return failure(); 188 189 auto submitStart = std::chrono::high_resolution_clock::now(); 190 // Submit command buffer into the queue. 191 if (failed(submitCommandBuffersToQueue())) 192 return failure(); 193 auto submitEnd = std::chrono::high_resolution_clock::now(); 194 195 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 196 auto execEnd = std::chrono::high_resolution_clock::now(); 197 198 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>( 199 submitEnd - submitStart); 200 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>( 201 execEnd - submitEnd); 202 203 if (queryPool != VK_NULL_HANDLE) { 204 uint64_t timestamps[2]; 205 RETURN_ON_VULKAN_ERROR( 206 vkGetQueryPoolResults( 207 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2, 208 /*dataSize=*/sizeof(timestamps), 209 /*pData=*/reinterpret_cast<void *>(timestamps), 210 /*stride=*/sizeof(uint64_t), 211 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT), 212 "vkGetQueryPoolResults"); 213 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000; 214 std::cout << "Compute shader execution time: " << std::setprecision(3) 215 << microsec << "us\n"; 216 } 217 218 std::cout << "Command buffer submit time: " << submitDuration.count() 219 << "us\nWait idle time: " << execDuration.count() << "us\n"; 220 221 return success(); 222 } 223 224 LogicalResult VulkanRuntime::createInstance() { 225 VkApplicationInfo applicationInfo = {}; 226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; 227 applicationInfo.pNext = nullptr; 228 applicationInfo.pApplicationName = "MLIR Vulkan runtime"; 229 applicationInfo.applicationVersion = 0; 230 applicationInfo.pEngineName = "mlir"; 231 applicationInfo.engineVersion = 0; 232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0); 233 234 VkInstanceCreateInfo instanceCreateInfo = {}; 235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; 236 instanceCreateInfo.pNext = nullptr; 237 instanceCreateInfo.flags = 0; 238 instanceCreateInfo.pApplicationInfo = &applicationInfo; 239 instanceCreateInfo.enabledLayerCount = 0; 240 instanceCreateInfo.ppEnabledLayerNames = nullptr; 241 instanceCreateInfo.enabledExtensionCount = 0; 242 instanceCreateInfo.ppEnabledExtensionNames = nullptr; 243 244 RETURN_ON_VULKAN_ERROR( 245 vkCreateInstance(&instanceCreateInfo, nullptr, &instance), 246 "vkCreateInstance"); 247 return success(); 248 } 249 250 LogicalResult VulkanRuntime::createDevice() { 251 uint32_t physicalDeviceCount = 0; 252 RETURN_ON_VULKAN_ERROR( 253 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr), 254 "vkEnumeratePhysicalDevices"); 255 256 std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount); 257 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance, 258 &physicalDeviceCount, 259 physicalDevices.data()), 260 "vkEnumeratePhysicalDevices"); 261 262 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE, 263 "physicalDeviceCount"); 264 265 // TODO: find the best device. 266 physicalDevice = physicalDevices.front(); 267 if (failed(getBestComputeQueue())) 268 return failure(); 269 270 const float queuePriority = 1.0f; 271 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {}; 272 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; 273 deviceQueueCreateInfo.pNext = nullptr; 274 deviceQueueCreateInfo.flags = 0; 275 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex; 276 deviceQueueCreateInfo.queueCount = 1; 277 deviceQueueCreateInfo.pQueuePriorities = &queuePriority; 278 279 // Structure specifying parameters of a newly created device. 280 VkDeviceCreateInfo deviceCreateInfo = {}; 281 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; 282 deviceCreateInfo.pNext = nullptr; 283 deviceCreateInfo.flags = 0; 284 deviceCreateInfo.queueCreateInfoCount = 1; 285 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo; 286 deviceCreateInfo.enabledLayerCount = 0; 287 deviceCreateInfo.ppEnabledLayerNames = nullptr; 288 deviceCreateInfo.enabledExtensionCount = 0; 289 deviceCreateInfo.ppEnabledExtensionNames = nullptr; 290 deviceCreateInfo.pEnabledFeatures = nullptr; 291 292 RETURN_ON_VULKAN_ERROR( 293 vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device), 294 "vkCreateDevice"); 295 296 VkPhysicalDeviceMemoryProperties properties = {}; 297 vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties); 298 299 // Try to find memory type with following properties: 300 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated 301 // with this type can be mapped for host access using vkMapMemory; 302 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache 303 // management commands vkFlushMappedMemoryRanges and 304 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the 305 // device or make device writes visible to the host, respectively. 306 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 307 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT & 308 properties.memoryTypes[i].propertyFlags) && 309 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT & 310 properties.memoryTypes[i].propertyFlags) && 311 (memorySize <= 312 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 313 hostMemoryTypeIndex = i; 314 break; 315 } 316 } 317 318 // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be 319 // used on the device. This will allow better performance access for GPU with 320 // on device memory. 321 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 322 if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & 323 properties.memoryTypes[i].propertyFlags) && 324 (memorySize <= 325 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 326 deviceMemoryTypeIndex = i; 327 break; 328 } 329 } 330 331 RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES || 332 deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES) 333 ? VK_INCOMPLETE 334 : VK_SUCCESS, 335 "invalid memoryTypeIndex"); 336 return success(); 337 } 338 339 LogicalResult VulkanRuntime::getBestComputeQueue() { 340 uint32_t queueFamilyPropertiesCount = 0; 341 vkGetPhysicalDeviceQueueFamilyProperties( 342 physicalDevice, &queueFamilyPropertiesCount, nullptr); 343 344 std::vector<VkQueueFamilyProperties> familyProperties( 345 queueFamilyPropertiesCount); 346 vkGetPhysicalDeviceQueueFamilyProperties( 347 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data()); 348 349 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support 350 // compute operations. Try to find a compute-only queue first if possible. 351 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 352 auto flags = familyProperties[i].queueFlags; 353 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) { 354 queueFamilyIndex = i; 355 queueFamilyProperties = familyProperties[i]; 356 return success(); 357 } 358 } 359 360 // Otherwise use a queue that can also support graphics. 361 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 362 auto flags = familyProperties[i].queueFlags; 363 if ((flags & VK_QUEUE_COMPUTE_BIT)) { 364 queueFamilyIndex = i; 365 queueFamilyProperties = familyProperties[i]; 366 return success(); 367 } 368 } 369 370 std::cerr << "cannot find valid queue"; 371 return failure(); 372 } 373 374 LogicalResult VulkanRuntime::createMemoryBuffers() { 375 // For each descriptor set. 376 for (const auto &resourceDataMapPair : resourceData) { 377 std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers; 378 const auto descriptorSetIndex = resourceDataMapPair.first; 379 const auto &resourceDataMap = resourceDataMapPair.second; 380 381 // For each descriptor binding. 382 for (const auto &resourceDataBindingPair : resourceDataMap) { 383 // Create device memory buffer. 384 VulkanDeviceMemoryBuffer memoryBuffer; 385 memoryBuffer.bindingIndex = resourceDataBindingPair.first; 386 VkDescriptorType descriptorType = {}; 387 VkBufferUsageFlagBits bufferUsage = {}; 388 389 // Check that descriptor set has storage class map. 390 const auto resourceStorageClassMapIt = 391 resourceStorageClassData.find(descriptorSetIndex); 392 if (resourceStorageClassMapIt == resourceStorageClassData.end()) { 393 std::cerr 394 << "cannot find storage class for resource in descriptor set: " 395 << descriptorSetIndex; 396 return failure(); 397 } 398 399 // Check that specific descriptor binding has storage class. 400 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second; 401 const auto resourceStorageClassIt = 402 resourceStorageClassMap.find(resourceDataBindingPair.first); 403 if (resourceStorageClassIt == resourceStorageClassMap.end()) { 404 std::cerr 405 << "cannot find storage class for resource with descriptor index: " 406 << resourceDataBindingPair.first; 407 return failure(); 408 } 409 410 const auto resourceStorageClassBinding = resourceStorageClassIt->second; 411 if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding, 412 descriptorType)) || 413 failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding, 414 bufferUsage))) { 415 std::cerr << "storage class for resource with descriptor binding: " 416 << resourceDataBindingPair.first 417 << " in the descriptor set: " << descriptorSetIndex 418 << " is not supported "; 419 return failure(); 420 } 421 422 // Set descriptor type for the specific device memory buffer. 423 memoryBuffer.descriptorType = descriptorType; 424 const auto bufferSize = resourceDataBindingPair.second.size; 425 memoryBuffer.bufferSize = bufferSize; 426 // Specify memory allocation info. 427 VkMemoryAllocateInfo memoryAllocateInfo = {}; 428 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; 429 memoryAllocateInfo.pNext = nullptr; 430 memoryAllocateInfo.allocationSize = bufferSize; 431 memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex; 432 433 // Allocate device memory. 434 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 435 nullptr, 436 &memoryBuffer.hostMemory), 437 "vkAllocateMemory"); 438 memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex; 439 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 440 nullptr, 441 &memoryBuffer.deviceMemory), 442 "vkAllocateMemory"); 443 void *payload; 444 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0, 445 bufferSize, 0, 446 reinterpret_cast<void **>(&payload)), 447 "vkMapMemory"); 448 449 // Copy host memory into the mapped area. 450 std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize); 451 vkUnmapMemory(device, memoryBuffer.hostMemory); 452 453 VkBufferCreateInfo bufferCreateInfo = {}; 454 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; 455 bufferCreateInfo.pNext = nullptr; 456 bufferCreateInfo.flags = 0; 457 bufferCreateInfo.size = bufferSize; 458 bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT | 459 VK_BUFFER_USAGE_TRANSFER_SRC_BIT; 460 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; 461 bufferCreateInfo.queueFamilyIndexCount = 1; 462 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex; 463 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr, 464 &memoryBuffer.hostBuffer), 465 "vkCreateBuffer"); 466 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr, 467 &memoryBuffer.deviceBuffer), 468 "vkCreateBuffer"); 469 470 // Bind buffer and device memory. 471 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer, 472 memoryBuffer.hostMemory, 0), 473 "vkBindBufferMemory"); 474 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, 475 memoryBuffer.deviceBuffer, 476 memoryBuffer.deviceMemory, 0), 477 "vkBindBufferMemory"); 478 479 // Update buffer info. 480 memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer; 481 memoryBuffer.bufferInfo.offset = 0; 482 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE; 483 deviceMemoryBuffers.push_back(memoryBuffer); 484 } 485 486 // Associate device memory buffers with a descriptor set. 487 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers; 488 } 489 return success(); 490 } 491 492 LogicalResult VulkanRuntime::copyResource(bool deviceToHost) { 493 VkCommandBufferAllocateInfo commandBufferAllocateInfo = { 494 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, 495 nullptr, 496 commandPool, 497 VK_COMMAND_BUFFER_LEVEL_PRIMARY, 498 1, 499 }; 500 VkCommandBuffer commandBuffer; 501 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 502 &commandBufferAllocateInfo, 503 &commandBuffer), 504 "vkAllocateCommandBuffers"); 505 506 VkCommandBufferBeginInfo commandBufferBeginInfo = { 507 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, 508 nullptr, 509 0, 510 nullptr, 511 }; 512 RETURN_ON_VULKAN_ERROR( 513 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 514 "vkBeginCommandBuffer"); 515 516 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 517 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 518 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 519 for (const auto &memBuffer : deviceMemoryBuffers) { 520 VkBufferCopy copy = {0, 0, memBuffer.bufferSize}; 521 if (deviceToHost) 522 vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer, 523 memBuffer.hostBuffer, 1, ©); 524 else 525 vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer, 526 memBuffer.deviceBuffer, 1, ©); 527 } 528 } 529 530 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 531 "vkEndCommandBuffer"); 532 VkSubmitInfo submitInfo = { 533 VK_STRUCTURE_TYPE_SUBMIT_INFO, 534 nullptr, 535 0, 536 nullptr, 537 nullptr, 538 1, 539 &commandBuffer, 540 0, 541 nullptr, 542 }; 543 submitInfo.pCommandBuffers = &commandBuffer; 544 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE), 545 "vkQueueSubmit"); 546 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 547 548 vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer); 549 return success(); 550 } 551 552 LogicalResult VulkanRuntime::createShaderModule() { 553 VkShaderModuleCreateInfo shaderModuleCreateInfo = {}; 554 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; 555 shaderModuleCreateInfo.pNext = nullptr; 556 shaderModuleCreateInfo.flags = 0; 557 // Set size in bytes. 558 shaderModuleCreateInfo.codeSize = binarySize; 559 // Set pointer to the binary shader. 560 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary); 561 RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo, 562 nullptr, &shaderModule), 563 "vkCreateShaderModule"); 564 return success(); 565 } 566 567 void VulkanRuntime::initDescriptorSetLayoutBindingMap() { 568 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 569 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 570 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 571 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 572 573 // Create a layout binding for each descriptor. 574 for (const auto &memBuffer : deviceMemoryBuffers) { 575 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {}; 576 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex; 577 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType; 578 descriptorSetLayoutBinding.descriptorCount = 1; 579 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; 580 descriptorSetLayoutBinding.pImmutableSamplers = nullptr; 581 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding); 582 } 583 descriptorSetLayoutBindingMap[descriptorSetIndex] = 584 descriptorSetLayoutBindings; 585 } 586 } 587 588 LogicalResult VulkanRuntime::createDescriptorSetLayout() { 589 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 590 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 591 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 592 // Each descriptor in a descriptor set must be the same type. 593 VkDescriptorType descriptorType = 594 deviceMemoryBuffers.front().descriptorType; 595 const uint32_t descriptorSize = deviceMemoryBuffers.size(); 596 const auto descriptorSetLayoutBindingIt = 597 descriptorSetLayoutBindingMap.find(descriptorSetIndex); 598 599 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { 600 std::cerr << "cannot find layout bindings for the set with number: " 601 << descriptorSetIndex; 602 return failure(); 603 } 604 605 const auto &descriptorSetLayoutBindings = 606 descriptorSetLayoutBindingIt->second; 607 // Create descriptor set layout. 608 VkDescriptorSetLayout descriptorSetLayout = {}; 609 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; 610 611 descriptorSetLayoutCreateInfo.sType = 612 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; 613 descriptorSetLayoutCreateInfo.pNext = nullptr; 614 descriptorSetLayoutCreateInfo.flags = 0; 615 // Amount of descriptor bindings in a layout set. 616 descriptorSetLayoutCreateInfo.bindingCount = 617 descriptorSetLayoutBindings.size(); 618 descriptorSetLayoutCreateInfo.pBindings = 619 descriptorSetLayoutBindings.data(); 620 RETURN_ON_VULKAN_ERROR( 621 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 622 nullptr, &descriptorSetLayout), 623 "vkCreateDescriptorSetLayout"); 624 625 descriptorSetLayouts.push_back(descriptorSetLayout); 626 descriptorSetInfoPool.push_back( 627 {descriptorSetIndex, descriptorSize, descriptorType}); 628 } 629 return success(); 630 } 631 632 LogicalResult VulkanRuntime::createPipelineLayout() { 633 // Associate descriptor sets with a pipeline layout. 634 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; 635 pipelineLayoutCreateInfo.sType = 636 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; 637 pipelineLayoutCreateInfo.pNext = nullptr; 638 pipelineLayoutCreateInfo.flags = 0; 639 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); 640 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); 641 pipelineLayoutCreateInfo.pushConstantRangeCount = 0; 642 pipelineLayoutCreateInfo.pPushConstantRanges = nullptr; 643 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, 644 &pipelineLayoutCreateInfo, 645 nullptr, &pipelineLayout), 646 "vkCreatePipelineLayout"); 647 return success(); 648 } 649 650 LogicalResult VulkanRuntime::createComputePipeline() { 651 VkPipelineShaderStageCreateInfo stageInfo = {}; 652 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; 653 stageInfo.pNext = nullptr; 654 stageInfo.flags = 0; 655 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; 656 stageInfo.module = shaderModule; 657 // Set entry point. 658 stageInfo.pName = entryPoint; 659 stageInfo.pSpecializationInfo = nullptr; 660 661 VkComputePipelineCreateInfo computePipelineCreateInfo = {}; 662 computePipelineCreateInfo.sType = 663 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; 664 computePipelineCreateInfo.pNext = nullptr; 665 computePipelineCreateInfo.flags = 0; 666 computePipelineCreateInfo.stage = stageInfo; 667 computePipelineCreateInfo.layout = pipelineLayout; 668 computePipelineCreateInfo.basePipelineHandle = nullptr; 669 computePipelineCreateInfo.basePipelineIndex = 0; 670 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1, 671 &computePipelineCreateInfo, 672 nullptr, &pipeline), 673 "vkCreateComputePipelines"); 674 return success(); 675 } 676 677 LogicalResult VulkanRuntime::createDescriptorPool() { 678 std::vector<VkDescriptorPoolSize> descriptorPoolSizes; 679 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 680 // For each descriptor set populate descriptor pool size. 681 VkDescriptorPoolSize descriptorPoolSize = {}; 682 descriptorPoolSize.type = descriptorSetInfo.descriptorType; 683 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; 684 descriptorPoolSizes.push_back(descriptorPoolSize); 685 } 686 687 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; 688 descriptorPoolCreateInfo.sType = 689 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; 690 descriptorPoolCreateInfo.pNext = nullptr; 691 descriptorPoolCreateInfo.flags = 0; 692 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); 693 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); 694 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); 695 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, 696 &descriptorPoolCreateInfo, 697 nullptr, &descriptorPool), 698 "vkCreateDescriptorPool"); 699 return success(); 700 } 701 702 LogicalResult VulkanRuntime::allocateDescriptorSets() { 703 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; 704 // Size of descriptor sets and descriptor layout sets is the same. 705 descriptorSets.resize(descriptorSetLayouts.size()); 706 descriptorSetAllocateInfo.sType = 707 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; 708 descriptorSetAllocateInfo.pNext = nullptr; 709 descriptorSetAllocateInfo.descriptorPool = descriptorPool; 710 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); 711 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); 712 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, 713 &descriptorSetAllocateInfo, 714 descriptorSets.data()), 715 "vkAllocateDescriptorSets"); 716 return success(); 717 } 718 719 LogicalResult VulkanRuntime::setWriteDescriptors() { 720 if (descriptorSets.size() != descriptorSetInfoPool.size()) { 721 std::cerr << "Each descriptor set must have descriptor set information"; 722 return failure(); 723 } 724 // For each descriptor set. 725 auto descriptorSetIt = descriptorSets.begin(); 726 // Each descriptor set is associated with descriptor set info. 727 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 728 // For each device memory buffer in the descriptor set. 729 const auto &deviceMemoryBuffers = 730 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; 731 for (const auto &memoryBuffer : deviceMemoryBuffers) { 732 // Structure describing descriptor sets to write to. 733 VkWriteDescriptorSet wSet = {}; 734 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; 735 wSet.pNext = nullptr; 736 // Descriptor set. 737 wSet.dstSet = *descriptorSetIt; 738 wSet.dstBinding = memoryBuffer.bindingIndex; 739 wSet.dstArrayElement = 0; 740 wSet.descriptorCount = 1; 741 wSet.descriptorType = memoryBuffer.descriptorType; 742 wSet.pImageInfo = nullptr; 743 wSet.pBufferInfo = &memoryBuffer.bufferInfo; 744 wSet.pTexelBufferView = nullptr; 745 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); 746 } 747 // Increment descriptor set iterator. 748 ++descriptorSetIt; 749 } 750 return success(); 751 } 752 753 LogicalResult VulkanRuntime::createCommandPool() { 754 VkCommandPoolCreateInfo commandPoolCreateInfo = {}; 755 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; 756 commandPoolCreateInfo.pNext = nullptr; 757 commandPoolCreateInfo.flags = 0; 758 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; 759 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo, 760 /*pAllocator=*/nullptr, 761 &commandPool), 762 "vkCreateCommandPool"); 763 return success(); 764 } 765 766 LogicalResult VulkanRuntime::createQueryPool() { 767 // Return directly if timestamp query is not supported. 768 if (queueFamilyProperties.timestampValidBits == 0) 769 return success(); 770 771 // Get timestamp period for this physical device. 772 VkPhysicalDeviceProperties deviceProperties = {}; 773 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties); 774 timestampPeriod = deviceProperties.limits.timestampPeriod; 775 776 // Create query pool. 777 VkQueryPoolCreateInfo queryPoolCreateInfo = {}; 778 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; 779 queryPoolCreateInfo.pNext = nullptr; 780 queryPoolCreateInfo.flags = 0; 781 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; 782 queryPoolCreateInfo.queryCount = 2; 783 queryPoolCreateInfo.pipelineStatistics = 0; 784 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo, 785 /*pAllocator=*/nullptr, &queryPool), 786 "vkCreateQueryPool"); 787 788 return success(); 789 } 790 791 LogicalResult VulkanRuntime::createComputeCommandBuffer() { 792 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; 793 commandBufferAllocateInfo.sType = 794 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; 795 commandBufferAllocateInfo.pNext = nullptr; 796 commandBufferAllocateInfo.commandPool = commandPool; 797 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; 798 commandBufferAllocateInfo.commandBufferCount = 1; 799 800 VkCommandBuffer commandBuffer; 801 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 802 &commandBufferAllocateInfo, 803 &commandBuffer), 804 "vkAllocateCommandBuffers"); 805 806 VkCommandBufferBeginInfo commandBufferBeginInfo = {}; 807 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; 808 commandBufferBeginInfo.pNext = nullptr; 809 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; 810 commandBufferBeginInfo.pInheritanceInfo = nullptr; 811 812 // Commands begin. 813 RETURN_ON_VULKAN_ERROR( 814 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 815 "vkBeginCommandBuffer"); 816 817 if (queryPool != VK_NULL_HANDLE) 818 vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2); 819 820 vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); 821 vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, 822 pipelineLayout, 0, descriptorSets.size(), 823 descriptorSets.data(), 0, nullptr); 824 // Get a timestamp before invoking the compute shader. 825 if (queryPool != VK_NULL_HANDLE) 826 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, 827 queryPool, 0); 828 vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, 829 numWorkGroups.z); 830 // Get another timestamp after invoking the compute shader. 831 if (queryPool != VK_NULL_HANDLE) 832 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, 833 queryPool, 1); 834 835 // Commands end. 836 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 837 "vkEndCommandBuffer"); 838 839 commandBuffers.push_back(commandBuffer); 840 return success(); 841 } 842 843 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { 844 VkSubmitInfo submitInfo = {}; 845 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; 846 submitInfo.pNext = nullptr; 847 submitInfo.waitSemaphoreCount = 0; 848 submitInfo.pWaitSemaphores = nullptr; 849 submitInfo.pWaitDstStageMask = nullptr; 850 submitInfo.commandBufferCount = commandBuffers.size(); 851 submitInfo.pCommandBuffers = commandBuffers.data(); 852 submitInfo.signalSemaphoreCount = 0; 853 submitInfo.pSignalSemaphores = nullptr; 854 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr), 855 "vkQueueSubmit"); 856 return success(); 857 } 858 859 LogicalResult VulkanRuntime::updateHostMemoryBuffers() { 860 // First copy back the data to the staging buffer. 861 (void)copyResource(/*deviceToHost=*/true); 862 863 // For each descriptor set. 864 for (auto &resourceDataMapPair : resourceData) { 865 auto &resourceDataMap = resourceDataMapPair.second; 866 auto &deviceMemoryBuffers = 867 deviceMemoryBufferMap[resourceDataMapPair.first]; 868 // For each device memory buffer in the set. 869 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { 870 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { 871 void *payload; 872 auto &hostMemoryBuffer = 873 resourceDataMap[deviceMemoryBuffer.bindingIndex]; 874 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, 875 deviceMemoryBuffer.hostMemory, 0, 876 hostMemoryBuffer.size, 0, 877 reinterpret_cast<void **>(&payload)), 878 "vkMapMemory"); 879 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); 880 vkUnmapMemory(device, deviceMemoryBuffer.hostMemory); 881 } 882 } 883 } 884 return success(); 885 } 886