1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a library for running a module on a Vulkan device. 10 // Implements a Vulkan runtime. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VulkanRuntime.h" 15 16 #include <chrono> 17 #include <cstring> 18 // TODO: It's generally bad to access stdout/stderr in a library. 19 // Figure out a better way for error reporting. 20 #include <iomanip> 21 #include <iostream> 22 23 inline void emitVulkanError(const char *api, VkResult error) { 24 std::cerr << " failed with error code " << error << " when executing " << api; 25 } 26 27 #define RETURN_ON_VULKAN_ERROR(result, api) \ 28 if ((result) != VK_SUCCESS) { \ 29 emitVulkanError(api, (result)); \ 30 return failure(); \ 31 } 32 33 using namespace mlir; 34 35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) { 36 numWorkGroups = numberWorkGroups; 37 } 38 39 void VulkanRuntime::setResourceStorageClassBindingMap( 40 const ResourceStorageClassBindingMap &stClassData) { 41 resourceStorageClassData = stClassData; 42 } 43 44 void VulkanRuntime::setResourceData( 45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex, 46 const VulkanHostMemoryBuffer &hostMemBuffer) { 47 resourceData[desIndex][bindIndex] = hostMemBuffer; 48 resourceStorageClassData[desIndex][bindIndex] = 49 SPIRVStorageClass::StorageBuffer; 50 } 51 52 void VulkanRuntime::setEntryPoint(const char *entryPointName) { 53 entryPoint = entryPointName; 54 } 55 56 void VulkanRuntime::setResourceData(const ResourceData &resData) { 57 resourceData = resData; 58 } 59 60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) { 61 binary = shader; 62 binarySize = size; 63 } 64 65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType( 66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) { 67 switch (storageClass) { 68 case SPIRVStorageClass::StorageBuffer: 69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; 70 break; 71 case SPIRVStorageClass::Uniform: 72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; 73 break; 74 } 75 return success(); 76 } 77 78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag( 79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) { 80 switch (storageClass) { 81 case SPIRVStorageClass::StorageBuffer: 82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; 83 break; 84 case SPIRVStorageClass::Uniform: 85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; 86 break; 87 } 88 return success(); 89 } 90 91 LogicalResult VulkanRuntime::countDeviceMemorySize() { 92 for (const auto &resourceDataMapPair : resourceData) { 93 const auto &resourceDataMap = resourceDataMapPair.second; 94 for (const auto &resourceDataBindingPair : resourceDataMap) { 95 if (resourceDataBindingPair.second.size) { 96 memorySize += resourceDataBindingPair.second.size; 97 } else { 98 std::cerr << "expected buffer size greater than zero for resource data"; 99 return failure(); 100 } 101 } 102 } 103 return success(); 104 } 105 106 LogicalResult VulkanRuntime::initRuntime() { 107 if (!resourceData.size()) { 108 std::cerr << "Vulkan runtime needs at least one resource"; 109 return failure(); 110 } 111 if (!binarySize || !binary) { 112 std::cerr << "binary shader size must be greater than zero"; 113 return failure(); 114 } 115 if (failed(countDeviceMemorySize())) { 116 return failure(); 117 } 118 return success(); 119 } 120 121 LogicalResult VulkanRuntime::destroy() { 122 // According to Vulkan spec: 123 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be 124 // used to gate the destruction of the device. Prior to destroying a device, 125 // an application is responsible for destroying/freeing any Vulkan objects 126 // that were created using that device as the first parameter of the 127 // corresponding vkCreate* or vkAllocate* command." 128 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle"); 129 130 // Free and destroy. 131 vkFreeCommandBuffers(device, commandPool, commandBuffers.size(), 132 commandBuffers.data()); 133 vkDestroyQueryPool(device, queryPool, nullptr); 134 vkDestroyCommandPool(device, commandPool, nullptr); 135 vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(), 136 descriptorSets.data()); 137 vkDestroyDescriptorPool(device, descriptorPool, nullptr); 138 vkDestroyPipeline(device, pipeline, nullptr); 139 vkDestroyPipelineLayout(device, pipelineLayout, nullptr); 140 for (auto &descriptorSetLayout : descriptorSetLayouts) { 141 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); 142 } 143 vkDestroyShaderModule(device, shaderModule, nullptr); 144 145 // For each descriptor set. 146 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 147 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 148 // For each descriptor binding. 149 for (auto &memoryBuffer : deviceMemoryBuffers) { 150 vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr); 151 vkFreeMemory(device, memoryBuffer.hostMemory, nullptr); 152 vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr); 153 vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr); 154 } 155 } 156 157 vkDestroyDevice(device, nullptr); 158 vkDestroyInstance(instance, nullptr); 159 return success(); 160 } 161 162 LogicalResult VulkanRuntime::run() { 163 // Create logical device, shader module and memory buffers. 164 if (failed(createInstance()) || failed(createDevice()) || 165 failed(createMemoryBuffers()) || failed(createShaderModule())) { 166 return failure(); 167 } 168 169 // Descriptor bindings divided into sets. Each descriptor binding 170 // must have a layout binding attached into a descriptor set layout. 171 // Each layout set must be binded into a pipeline layout. 172 initDescriptorSetLayoutBindingMap(); 173 if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) || 174 // Each descriptor set must be allocated from a descriptor pool. 175 failed(createComputePipeline()) || failed(createDescriptorPool()) || 176 failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) || 177 // Create command buffer. 178 failed(createCommandPool()) || failed(createQueryPool()) || 179 failed(createComputeCommandBuffer())) { 180 return failure(); 181 } 182 183 // Get working queue. 184 vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue); 185 186 if (failed(copyResource(/*deviceToHost=*/false))) 187 return failure(); 188 189 auto submitStart = std::chrono::high_resolution_clock::now(); 190 // Submit command buffer into the queue. 191 if (failed(submitCommandBuffersToQueue())) 192 return failure(); 193 auto submitEnd = std::chrono::high_resolution_clock::now(); 194 195 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 196 auto execEnd = std::chrono::high_resolution_clock::now(); 197 198 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>( 199 submitEnd - submitStart); 200 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>( 201 execEnd - submitEnd); 202 203 if (queryPool != VK_NULL_HANDLE) { 204 uint64_t timestamps[2]; 205 RETURN_ON_VULKAN_ERROR( 206 vkGetQueryPoolResults( 207 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2, 208 /*dataSize=*/sizeof(timestamps), 209 /*pData=*/reinterpret_cast<void *>(timestamps), 210 /*stride=*/sizeof(uint64_t), 211 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT), 212 "vkGetQueryPoolResults"); 213 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000; 214 std::cout << "Compute shader execution time: " << std::setprecision(3) 215 << microsec << "us\n"; 216 } 217 218 std::cout << "Command buffer submit time: " << submitDuration.count() 219 << "us\nWait idle time: " << execDuration.count() << "us\n"; 220 221 return success(); 222 } 223 224 LogicalResult VulkanRuntime::createInstance() { 225 VkApplicationInfo applicationInfo = {}; 226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; 227 applicationInfo.pNext = nullptr; 228 applicationInfo.pApplicationName = "MLIR Vulkan runtime"; 229 applicationInfo.applicationVersion = 0; 230 applicationInfo.pEngineName = "mlir"; 231 applicationInfo.engineVersion = 0; 232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0); 233 234 VkInstanceCreateInfo instanceCreateInfo = {}; 235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; 236 instanceCreateInfo.pNext = nullptr; 237 instanceCreateInfo.flags = 0; 238 instanceCreateInfo.pApplicationInfo = &applicationInfo; 239 instanceCreateInfo.enabledLayerCount = 0; 240 instanceCreateInfo.ppEnabledLayerNames = 0; 241 instanceCreateInfo.enabledExtensionCount = 0; 242 instanceCreateInfo.ppEnabledExtensionNames = 0; 243 244 RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance), 245 "vkCreateInstance"); 246 return success(); 247 } 248 249 LogicalResult VulkanRuntime::createDevice() { 250 uint32_t physicalDeviceCount = 0; 251 RETURN_ON_VULKAN_ERROR( 252 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0), 253 "vkEnumeratePhysicalDevices"); 254 255 std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount); 256 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance, 257 &physicalDeviceCount, 258 physicalDevices.data()), 259 "vkEnumeratePhysicalDevices"); 260 261 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE, 262 "physicalDeviceCount"); 263 264 // TODO: find the best device. 265 physicalDevice = physicalDevices.front(); 266 if (failed(getBestComputeQueue())) 267 return failure(); 268 269 const float queuePriority = 1.0f; 270 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {}; 271 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; 272 deviceQueueCreateInfo.pNext = nullptr; 273 deviceQueueCreateInfo.flags = 0; 274 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex; 275 deviceQueueCreateInfo.queueCount = 1; 276 deviceQueueCreateInfo.pQueuePriorities = &queuePriority; 277 278 // Structure specifying parameters of a newly created device. 279 VkDeviceCreateInfo deviceCreateInfo = {}; 280 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; 281 deviceCreateInfo.pNext = nullptr; 282 deviceCreateInfo.flags = 0; 283 deviceCreateInfo.queueCreateInfoCount = 1; 284 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo; 285 deviceCreateInfo.enabledLayerCount = 0; 286 deviceCreateInfo.ppEnabledLayerNames = nullptr; 287 deviceCreateInfo.enabledExtensionCount = 0; 288 deviceCreateInfo.ppEnabledExtensionNames = nullptr; 289 deviceCreateInfo.pEnabledFeatures = nullptr; 290 291 RETURN_ON_VULKAN_ERROR( 292 vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device), 293 "vkCreateDevice"); 294 295 VkPhysicalDeviceMemoryProperties properties = {}; 296 vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties); 297 298 // Try to find memory type with following properties: 299 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated 300 // with this type can be mapped for host access using vkMapMemory; 301 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache 302 // management commands vkFlushMappedMemoryRanges and 303 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the 304 // device or make device writes visible to the host, respectively. 305 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 306 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT & 307 properties.memoryTypes[i].propertyFlags) && 308 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT & 309 properties.memoryTypes[i].propertyFlags) && 310 (memorySize <= 311 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 312 hostMemoryTypeIndex = i; 313 break; 314 } 315 } 316 317 // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be 318 // used on the device. This will allow better performance access for GPU with 319 // on device memory. 320 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 321 if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & 322 properties.memoryTypes[i].propertyFlags) && 323 (memorySize <= 324 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 325 deviceMemoryTypeIndex = i; 326 break; 327 } 328 } 329 330 RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES || 331 deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES) 332 ? VK_INCOMPLETE 333 : VK_SUCCESS, 334 "invalid memoryTypeIndex"); 335 return success(); 336 } 337 338 LogicalResult VulkanRuntime::getBestComputeQueue() { 339 uint32_t queueFamilyPropertiesCount = 0; 340 vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, 341 &queueFamilyPropertiesCount, 0); 342 343 std::vector<VkQueueFamilyProperties> familyProperties( 344 queueFamilyPropertiesCount); 345 vkGetPhysicalDeviceQueueFamilyProperties( 346 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data()); 347 348 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support 349 // compute operations. Try to find a compute-only queue first if possible. 350 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 351 auto flags = familyProperties[i].queueFlags; 352 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) { 353 queueFamilyIndex = i; 354 queueFamilyProperties = familyProperties[i]; 355 return success(); 356 } 357 } 358 359 // Otherwise use a queue that can also support graphics. 360 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 361 auto flags = familyProperties[i].queueFlags; 362 if ((flags & VK_QUEUE_COMPUTE_BIT)) { 363 queueFamilyIndex = i; 364 queueFamilyProperties = familyProperties[i]; 365 return success(); 366 } 367 } 368 369 std::cerr << "cannot find valid queue"; 370 return failure(); 371 } 372 373 LogicalResult VulkanRuntime::createMemoryBuffers() { 374 // For each descriptor set. 375 for (const auto &resourceDataMapPair : resourceData) { 376 std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers; 377 const auto descriptorSetIndex = resourceDataMapPair.first; 378 const auto &resourceDataMap = resourceDataMapPair.second; 379 380 // For each descriptor binding. 381 for (const auto &resourceDataBindingPair : resourceDataMap) { 382 // Create device memory buffer. 383 VulkanDeviceMemoryBuffer memoryBuffer; 384 memoryBuffer.bindingIndex = resourceDataBindingPair.first; 385 VkDescriptorType descriptorType = {}; 386 VkBufferUsageFlagBits bufferUsage = {}; 387 388 // Check that descriptor set has storage class map. 389 const auto resourceStorageClassMapIt = 390 resourceStorageClassData.find(descriptorSetIndex); 391 if (resourceStorageClassMapIt == resourceStorageClassData.end()) { 392 std::cerr 393 << "cannot find storage class for resource in descriptor set: " 394 << descriptorSetIndex; 395 return failure(); 396 } 397 398 // Check that specific descriptor binding has storage class. 399 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second; 400 const auto resourceStorageClassIt = 401 resourceStorageClassMap.find(resourceDataBindingPair.first); 402 if (resourceStorageClassIt == resourceStorageClassMap.end()) { 403 std::cerr 404 << "cannot find storage class for resource with descriptor index: " 405 << resourceDataBindingPair.first; 406 return failure(); 407 } 408 409 const auto resourceStorageClassBinding = resourceStorageClassIt->second; 410 if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding, 411 descriptorType)) || 412 failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding, 413 bufferUsage))) { 414 std::cerr << "storage class for resource with descriptor binding: " 415 << resourceDataBindingPair.first 416 << " in the descriptor set: " << descriptorSetIndex 417 << " is not supported "; 418 return failure(); 419 } 420 421 // Set descriptor type for the specific device memory buffer. 422 memoryBuffer.descriptorType = descriptorType; 423 const auto bufferSize = resourceDataBindingPair.second.size; 424 memoryBuffer.bufferSize = bufferSize; 425 // Specify memory allocation info. 426 VkMemoryAllocateInfo memoryAllocateInfo = {}; 427 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; 428 memoryAllocateInfo.pNext = nullptr; 429 memoryAllocateInfo.allocationSize = bufferSize; 430 memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex; 431 432 // Allocate device memory. 433 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0, 434 &memoryBuffer.hostMemory), 435 "vkAllocateMemory"); 436 memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex; 437 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0, 438 &memoryBuffer.deviceMemory), 439 "vkAllocateMemory"); 440 void *payload; 441 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0, 442 bufferSize, 0, 443 reinterpret_cast<void **>(&payload)), 444 "vkMapMemory"); 445 446 // Copy host memory into the mapped area. 447 std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize); 448 vkUnmapMemory(device, memoryBuffer.hostMemory); 449 450 VkBufferCreateInfo bufferCreateInfo = {}; 451 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; 452 bufferCreateInfo.pNext = nullptr; 453 bufferCreateInfo.flags = 0; 454 bufferCreateInfo.size = bufferSize; 455 bufferCreateInfo.usage = bufferUsage; 456 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; 457 bufferCreateInfo.queueFamilyIndexCount = 1; 458 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex; 459 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0, 460 &memoryBuffer.hostBuffer), 461 "vkCreateBuffer"); 462 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, 0, 463 &memoryBuffer.deviceBuffer), 464 "vkCreateBuffer"); 465 466 // Bind buffer and device memory. 467 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer, 468 memoryBuffer.hostMemory, 0), 469 "vkBindBufferMemory"); 470 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, 471 memoryBuffer.deviceBuffer, 472 memoryBuffer.deviceMemory, 0), 473 "vkBindBufferMemory"); 474 475 // Update buffer info. 476 memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer; 477 memoryBuffer.bufferInfo.offset = 0; 478 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE; 479 deviceMemoryBuffers.push_back(memoryBuffer); 480 } 481 482 // Associate device memory buffers with a descriptor set. 483 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers; 484 } 485 return success(); 486 } 487 488 LogicalResult VulkanRuntime::copyResource(bool deviceToHost) { 489 VkCommandBufferAllocateInfo commandBufferAllocateInfo = { 490 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, 491 NULL, 492 commandPool, 493 VK_COMMAND_BUFFER_LEVEL_PRIMARY, 494 1, 495 }; 496 VkCommandBuffer commandBuffer; 497 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 498 &commandBufferAllocateInfo, 499 &commandBuffer), 500 "vkAllocateCommandBuffers"); 501 502 VkCommandBufferBeginInfo commandBufferBeginInfo = { 503 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, 504 NULL, 505 0, 506 NULL, 507 }; 508 RETURN_ON_VULKAN_ERROR( 509 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 510 "vkBeginCommandBuffer"); 511 512 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 513 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 514 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 515 for (const auto &memBuffer : deviceMemoryBuffers) { 516 VkBufferCopy copy = {0, 0, memBuffer.bufferSize}; 517 if (deviceToHost) 518 vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer, 519 memBuffer.hostBuffer, 1, ©); 520 else 521 vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer, 522 memBuffer.deviceBuffer, 1, ©); 523 } 524 } 525 526 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 527 "vkEndCommandBuffer"); 528 VkSubmitInfo submitInfo = { 529 VK_STRUCTURE_TYPE_SUBMIT_INFO, 530 NULL, 531 0, 532 NULL, 533 NULL, 534 1, 535 &commandBuffer, 536 0, 537 NULL, 538 }; 539 submitInfo.pCommandBuffers = &commandBuffer; 540 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE), 541 "vkQueueSubmit"); 542 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 543 544 vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer); 545 return success(); 546 } 547 548 LogicalResult VulkanRuntime::createShaderModule() { 549 VkShaderModuleCreateInfo shaderModuleCreateInfo = {}; 550 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; 551 shaderModuleCreateInfo.pNext = nullptr; 552 shaderModuleCreateInfo.flags = 0; 553 // Set size in bytes. 554 shaderModuleCreateInfo.codeSize = binarySize; 555 // Set pointer to the binary shader. 556 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary); 557 RETURN_ON_VULKAN_ERROR( 558 vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule), 559 "vkCreateShaderModule"); 560 return success(); 561 } 562 563 void VulkanRuntime::initDescriptorSetLayoutBindingMap() { 564 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 565 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 566 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 567 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 568 569 // Create a layout binding for each descriptor. 570 for (const auto &memBuffer : deviceMemoryBuffers) { 571 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {}; 572 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex; 573 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType; 574 descriptorSetLayoutBinding.descriptorCount = 1; 575 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; 576 descriptorSetLayoutBinding.pImmutableSamplers = 0; 577 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding); 578 } 579 descriptorSetLayoutBindingMap[descriptorSetIndex] = 580 descriptorSetLayoutBindings; 581 } 582 } 583 584 LogicalResult VulkanRuntime::createDescriptorSetLayout() { 585 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 586 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 587 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 588 // Each descriptor in a descriptor set must be the same type. 589 VkDescriptorType descriptorType = 590 deviceMemoryBuffers.front().descriptorType; 591 const uint32_t descriptorSize = deviceMemoryBuffers.size(); 592 const auto descriptorSetLayoutBindingIt = 593 descriptorSetLayoutBindingMap.find(descriptorSetIndex); 594 595 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { 596 std::cerr << "cannot find layout bindings for the set with number: " 597 << descriptorSetIndex; 598 return failure(); 599 } 600 601 const auto &descriptorSetLayoutBindings = 602 descriptorSetLayoutBindingIt->second; 603 // Create descriptor set layout. 604 VkDescriptorSetLayout descriptorSetLayout = {}; 605 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; 606 607 descriptorSetLayoutCreateInfo.sType = 608 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; 609 descriptorSetLayoutCreateInfo.pNext = nullptr; 610 descriptorSetLayoutCreateInfo.flags = 0; 611 // Amount of descriptor bindings in a layout set. 612 descriptorSetLayoutCreateInfo.bindingCount = 613 descriptorSetLayoutBindings.size(); 614 descriptorSetLayoutCreateInfo.pBindings = 615 descriptorSetLayoutBindings.data(); 616 RETURN_ON_VULKAN_ERROR( 617 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0, 618 &descriptorSetLayout), 619 "vkCreateDescriptorSetLayout"); 620 621 descriptorSetLayouts.push_back(descriptorSetLayout); 622 descriptorSetInfoPool.push_back( 623 {descriptorSetIndex, descriptorSize, descriptorType}); 624 } 625 return success(); 626 } 627 628 LogicalResult VulkanRuntime::createPipelineLayout() { 629 // Associate descriptor sets with a pipeline layout. 630 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; 631 pipelineLayoutCreateInfo.sType = 632 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; 633 pipelineLayoutCreateInfo.pNext = nullptr; 634 pipelineLayoutCreateInfo.flags = 0; 635 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); 636 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); 637 pipelineLayoutCreateInfo.pushConstantRangeCount = 0; 638 pipelineLayoutCreateInfo.pPushConstantRanges = 0; 639 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, 640 &pipelineLayoutCreateInfo, 0, 641 &pipelineLayout), 642 "vkCreatePipelineLayout"); 643 return success(); 644 } 645 646 LogicalResult VulkanRuntime::createComputePipeline() { 647 VkPipelineShaderStageCreateInfo stageInfo = {}; 648 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; 649 stageInfo.pNext = nullptr; 650 stageInfo.flags = 0; 651 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; 652 stageInfo.module = shaderModule; 653 // Set entry point. 654 stageInfo.pName = entryPoint; 655 stageInfo.pSpecializationInfo = 0; 656 657 VkComputePipelineCreateInfo computePipelineCreateInfo = {}; 658 computePipelineCreateInfo.sType = 659 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; 660 computePipelineCreateInfo.pNext = nullptr; 661 computePipelineCreateInfo.flags = 0; 662 computePipelineCreateInfo.stage = stageInfo; 663 computePipelineCreateInfo.layout = pipelineLayout; 664 computePipelineCreateInfo.basePipelineHandle = 0; 665 computePipelineCreateInfo.basePipelineIndex = 0; 666 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1, 667 &computePipelineCreateInfo, 0, 668 &pipeline), 669 "vkCreateComputePipelines"); 670 return success(); 671 } 672 673 LogicalResult VulkanRuntime::createDescriptorPool() { 674 std::vector<VkDescriptorPoolSize> descriptorPoolSizes; 675 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 676 // For each descriptor set populate descriptor pool size. 677 VkDescriptorPoolSize descriptorPoolSize = {}; 678 descriptorPoolSize.type = descriptorSetInfo.descriptorType; 679 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; 680 descriptorPoolSizes.push_back(descriptorPoolSize); 681 } 682 683 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; 684 descriptorPoolCreateInfo.sType = 685 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; 686 descriptorPoolCreateInfo.pNext = nullptr; 687 descriptorPoolCreateInfo.flags = 0; 688 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); 689 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); 690 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); 691 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, 692 &descriptorPoolCreateInfo, 0, 693 &descriptorPool), 694 "vkCreateDescriptorPool"); 695 return success(); 696 } 697 698 LogicalResult VulkanRuntime::allocateDescriptorSets() { 699 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; 700 // Size of descriptor sets and descriptor layout sets is the same. 701 descriptorSets.resize(descriptorSetLayouts.size()); 702 descriptorSetAllocateInfo.sType = 703 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; 704 descriptorSetAllocateInfo.pNext = nullptr; 705 descriptorSetAllocateInfo.descriptorPool = descriptorPool; 706 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); 707 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); 708 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, 709 &descriptorSetAllocateInfo, 710 descriptorSets.data()), 711 "vkAllocateDescriptorSets"); 712 return success(); 713 } 714 715 LogicalResult VulkanRuntime::setWriteDescriptors() { 716 if (descriptorSets.size() != descriptorSetInfoPool.size()) { 717 std::cerr << "Each descriptor set must have descriptor set information"; 718 return failure(); 719 } 720 // For each descriptor set. 721 auto descriptorSetIt = descriptorSets.begin(); 722 // Each descriptor set is associated with descriptor set info. 723 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 724 // For each device memory buffer in the descriptor set. 725 const auto &deviceMemoryBuffers = 726 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; 727 for (const auto &memoryBuffer : deviceMemoryBuffers) { 728 // Structure describing descriptor sets to write to. 729 VkWriteDescriptorSet wSet = {}; 730 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; 731 wSet.pNext = nullptr; 732 // Descriptor set. 733 wSet.dstSet = *descriptorSetIt; 734 wSet.dstBinding = memoryBuffer.bindingIndex; 735 wSet.dstArrayElement = 0; 736 wSet.descriptorCount = 1; 737 wSet.descriptorType = memoryBuffer.descriptorType; 738 wSet.pImageInfo = nullptr; 739 wSet.pBufferInfo = &memoryBuffer.bufferInfo; 740 wSet.pTexelBufferView = nullptr; 741 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); 742 } 743 // Increment descriptor set iterator. 744 ++descriptorSetIt; 745 } 746 return success(); 747 } 748 749 LogicalResult VulkanRuntime::createCommandPool() { 750 VkCommandPoolCreateInfo commandPoolCreateInfo = {}; 751 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; 752 commandPoolCreateInfo.pNext = nullptr; 753 commandPoolCreateInfo.flags = 0; 754 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; 755 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo, 756 /*pAllocator=*/nullptr, 757 &commandPool), 758 "vkCreateCommandPool"); 759 return success(); 760 } 761 762 LogicalResult VulkanRuntime::createQueryPool() { 763 // Return directly if timestamp query is not supported. 764 if (queueFamilyProperties.timestampValidBits == 0) 765 return success(); 766 767 // Get timestamp period for this physical device. 768 VkPhysicalDeviceProperties deviceProperties = {}; 769 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties); 770 timestampPeriod = deviceProperties.limits.timestampPeriod; 771 772 // Create query pool. 773 VkQueryPoolCreateInfo queryPoolCreateInfo = {}; 774 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; 775 queryPoolCreateInfo.pNext = nullptr; 776 queryPoolCreateInfo.flags = 0; 777 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; 778 queryPoolCreateInfo.queryCount = 2; 779 queryPoolCreateInfo.pipelineStatistics = 0; 780 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo, 781 /*pAllocator=*/nullptr, &queryPool), 782 "vkCreateQueryPool"); 783 784 return success(); 785 } 786 787 LogicalResult VulkanRuntime::createComputeCommandBuffer() { 788 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; 789 commandBufferAllocateInfo.sType = 790 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; 791 commandBufferAllocateInfo.pNext = nullptr; 792 commandBufferAllocateInfo.commandPool = commandPool; 793 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; 794 commandBufferAllocateInfo.commandBufferCount = 1; 795 796 VkCommandBuffer commandBuffer; 797 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 798 &commandBufferAllocateInfo, 799 &commandBuffer), 800 "vkAllocateCommandBuffers"); 801 802 VkCommandBufferBeginInfo commandBufferBeginInfo = {}; 803 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; 804 commandBufferBeginInfo.pNext = nullptr; 805 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; 806 commandBufferBeginInfo.pInheritanceInfo = nullptr; 807 808 // Commands begin. 809 RETURN_ON_VULKAN_ERROR( 810 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 811 "vkBeginCommandBuffer"); 812 813 if (queryPool != VK_NULL_HANDLE) 814 vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2); 815 816 vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); 817 vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, 818 pipelineLayout, 0, descriptorSets.size(), 819 descriptorSets.data(), 0, 0); 820 // Get a timestamp before invoking the compute shader. 821 if (queryPool != VK_NULL_HANDLE) 822 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, 823 queryPool, 0); 824 vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, 825 numWorkGroups.z); 826 // Get another timestamp after invoking the compute shader. 827 if (queryPool != VK_NULL_HANDLE) 828 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, 829 queryPool, 1); 830 831 // Commands end. 832 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 833 "vkEndCommandBuffer"); 834 835 commandBuffers.push_back(commandBuffer); 836 return success(); 837 } 838 839 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { 840 VkSubmitInfo submitInfo = {}; 841 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; 842 submitInfo.pNext = nullptr; 843 submitInfo.waitSemaphoreCount = 0; 844 submitInfo.pWaitSemaphores = 0; 845 submitInfo.pWaitDstStageMask = 0; 846 submitInfo.commandBufferCount = commandBuffers.size(); 847 submitInfo.pCommandBuffers = commandBuffers.data(); 848 submitInfo.signalSemaphoreCount = 0; 849 submitInfo.pSignalSemaphores = nullptr; 850 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0), 851 "vkQueueSubmit"); 852 return success(); 853 } 854 855 LogicalResult VulkanRuntime::updateHostMemoryBuffers() { 856 // First copy back the data to the staging buffer. 857 copyResource(/*deviceToHost=*/true); 858 859 // For each descriptor set. 860 for (auto &resourceDataMapPair : resourceData) { 861 auto &resourceDataMap = resourceDataMapPair.second; 862 auto &deviceMemoryBuffers = 863 deviceMemoryBufferMap[resourceDataMapPair.first]; 864 // For each device memory buffer in the set. 865 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { 866 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { 867 void *payload; 868 auto &hostMemoryBuffer = 869 resourceDataMap[deviceMemoryBuffer.bindingIndex]; 870 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, 871 deviceMemoryBuffer.hostMemory, 0, 872 hostMemoryBuffer.size, 0, 873 reinterpret_cast<void **>(&payload)), 874 "vkMapMemory"); 875 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); 876 vkUnmapMemory(device, deviceMemoryBuffer.hostMemory); 877 } 878 } 879 } 880 return success(); 881 } 882