1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a library for running a module on a Vulkan device. 10 // Implements a Vulkan runtime. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VulkanRuntime.h" 15 16 #include <chrono> 17 #include <cstring> 18 // TODO(antiagainst): It's generally bad to access stdout/stderr in a library. 19 // Figure out a better way for error reporting. 20 #include <iomanip> 21 #include <iostream> 22 23 inline void emitVulkanError(const char *api, VkResult error) { 24 std::cerr << " failed with error code " << error << " when executing " << api; 25 } 26 27 #define RETURN_ON_VULKAN_ERROR(result, api) \ 28 if ((result) != VK_SUCCESS) { \ 29 emitVulkanError(api, (result)); \ 30 return failure(); \ 31 } 32 33 using namespace mlir; 34 35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) { 36 numWorkGroups = numberWorkGroups; 37 } 38 39 void VulkanRuntime::setResourceStorageClassBindingMap( 40 const ResourceStorageClassBindingMap &stClassData) { 41 resourceStorageClassData = stClassData; 42 } 43 44 void VulkanRuntime::setResourceData( 45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex, 46 const VulkanHostMemoryBuffer &hostMemBuffer) { 47 resourceData[desIndex][bindIndex] = hostMemBuffer; 48 resourceStorageClassData[desIndex][bindIndex] = 49 SPIRVStorageClass::StorageBuffer; 50 } 51 52 void VulkanRuntime::setEntryPoint(const char *entryPointName) { 53 entryPoint = entryPointName; 54 } 55 56 void VulkanRuntime::setResourceData(const ResourceData &resData) { 57 resourceData = resData; 58 } 59 60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) { 61 binary = shader; 62 binarySize = size; 63 } 64 65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType( 66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) { 67 switch (storageClass) { 68 case SPIRVStorageClass::StorageBuffer: 69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; 70 break; 71 case SPIRVStorageClass::Uniform: 72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; 73 break; 74 } 75 return success(); 76 } 77 78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag( 79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) { 80 switch (storageClass) { 81 case SPIRVStorageClass::StorageBuffer: 82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; 83 break; 84 case SPIRVStorageClass::Uniform: 85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; 86 break; 87 } 88 return success(); 89 } 90 91 LogicalResult VulkanRuntime::countDeviceMemorySize() { 92 for (const auto &resourceDataMapPair : resourceData) { 93 const auto &resourceDataMap = resourceDataMapPair.second; 94 for (const auto &resourceDataBindingPair : resourceDataMap) { 95 if (resourceDataBindingPair.second.size) { 96 memorySize += resourceDataBindingPair.second.size; 97 } else { 98 std::cerr << "expected buffer size greater than zero for resource data"; 99 return failure(); 100 } 101 } 102 } 103 return success(); 104 } 105 106 LogicalResult VulkanRuntime::initRuntime() { 107 if (!resourceData.size()) { 108 std::cerr << "Vulkan runtime needs at least one resource"; 109 return failure(); 110 } 111 if (!binarySize || !binary) { 112 std::cerr << "binary shader size must be greater than zero"; 113 return failure(); 114 } 115 if (failed(countDeviceMemorySize())) { 116 return failure(); 117 } 118 return success(); 119 } 120 121 LogicalResult VulkanRuntime::destroy() { 122 // According to Vulkan spec: 123 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be 124 // used to gate the destruction of the device. Prior to destroying a device, 125 // an application is responsible for destroying/freeing any Vulkan objects 126 // that were created using that device as the first parameter of the 127 // corresponding vkCreate* or vkAllocate* command." 128 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle"); 129 130 // Free and destroy. 131 vkFreeCommandBuffers(device, commandPool, commandBuffers.size(), 132 commandBuffers.data()); 133 vkDestroyQueryPool(device, queryPool, nullptr); 134 vkDestroyCommandPool(device, commandPool, nullptr); 135 vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(), 136 descriptorSets.data()); 137 vkDestroyDescriptorPool(device, descriptorPool, nullptr); 138 vkDestroyPipeline(device, pipeline, nullptr); 139 vkDestroyPipelineLayout(device, pipelineLayout, nullptr); 140 for (auto &descriptorSetLayout : descriptorSetLayouts) { 141 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); 142 } 143 vkDestroyShaderModule(device, shaderModule, nullptr); 144 145 // For each descriptor set. 146 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 147 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 148 // For each descriptor binding. 149 for (auto &memoryBuffer : deviceMemoryBuffers) { 150 vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr); 151 vkDestroyBuffer(device, memoryBuffer.buffer, nullptr); 152 } 153 } 154 155 vkDestroyDevice(device, nullptr); 156 vkDestroyInstance(instance, nullptr); 157 return success(); 158 } 159 160 LogicalResult VulkanRuntime::run() { 161 // Create logical device, shader module and memory buffers. 162 if (failed(createInstance()) || failed(createDevice()) || 163 failed(createMemoryBuffers()) || failed(createShaderModule())) { 164 return failure(); 165 } 166 167 // Descriptor bindings divided into sets. Each descriptor binding 168 // must have a layout binding attached into a descriptor set layout. 169 // Each layout set must be binded into a pipeline layout. 170 initDescriptorSetLayoutBindingMap(); 171 if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) || 172 // Each descriptor set must be allocated from a descriptor pool. 173 failed(createComputePipeline()) || failed(createDescriptorPool()) || 174 failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) || 175 // Create command buffer. 176 failed(createCommandPool()) || failed(createQueryPool()) || 177 failed(createComputeCommandBuffer())) { 178 return failure(); 179 } 180 181 // Get working queue. 182 vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue); 183 184 auto submitStart = std::chrono::high_resolution_clock::now(); 185 // Submit command buffer into the queue. 186 if (failed(submitCommandBuffersToQueue())) 187 return failure(); 188 auto submitEnd = std::chrono::high_resolution_clock::now(); 189 190 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 191 auto execEnd = std::chrono::high_resolution_clock::now(); 192 193 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>( 194 submitEnd - submitStart); 195 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>( 196 execEnd - submitEnd); 197 198 if (queryPool != VK_NULL_HANDLE) { 199 uint64_t timestamps[2]; 200 RETURN_ON_VULKAN_ERROR( 201 vkGetQueryPoolResults( 202 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2, 203 /*dataSize=*/sizeof(timestamps), 204 /*pData=*/reinterpret_cast<void *>(timestamps), 205 /*stride=*/sizeof(uint64_t), 206 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT), 207 "vkGetQueryPoolResults"); 208 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000; 209 std::cout << "Compute shader execution time: " << std::setprecision(3) 210 << microsec << "us\n"; 211 } 212 213 std::cout << "Command buffer submit time: " << submitDuration.count() 214 << "us\nWait idle time: " << execDuration.count() << "us\n"; 215 216 return success(); 217 } 218 219 LogicalResult VulkanRuntime::createInstance() { 220 VkApplicationInfo applicationInfo = {}; 221 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; 222 applicationInfo.pNext = nullptr; 223 applicationInfo.pApplicationName = "MLIR Vulkan runtime"; 224 applicationInfo.applicationVersion = 0; 225 applicationInfo.pEngineName = "mlir"; 226 applicationInfo.engineVersion = 0; 227 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0); 228 229 VkInstanceCreateInfo instanceCreateInfo = {}; 230 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; 231 instanceCreateInfo.pNext = nullptr; 232 instanceCreateInfo.flags = 0; 233 instanceCreateInfo.pApplicationInfo = &applicationInfo; 234 instanceCreateInfo.enabledLayerCount = 0; 235 instanceCreateInfo.ppEnabledLayerNames = 0; 236 instanceCreateInfo.enabledExtensionCount = 0; 237 instanceCreateInfo.ppEnabledExtensionNames = 0; 238 239 RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance), 240 "vkCreateInstance"); 241 return success(); 242 } 243 244 LogicalResult VulkanRuntime::createDevice() { 245 uint32_t physicalDeviceCount = 0; 246 RETURN_ON_VULKAN_ERROR( 247 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0), 248 "vkEnumeratePhysicalDevices"); 249 250 std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount); 251 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance, 252 &physicalDeviceCount, 253 physicalDevices.data()), 254 "vkEnumeratePhysicalDevices"); 255 256 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE, 257 "physicalDeviceCount"); 258 259 // TODO(denis0x0D): find the best device. 260 physicalDevice = physicalDevices.front(); 261 if (failed(getBestComputeQueue())) 262 return failure(); 263 264 const float queuePriority = 1.0f; 265 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {}; 266 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; 267 deviceQueueCreateInfo.pNext = nullptr; 268 deviceQueueCreateInfo.flags = 0; 269 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex; 270 deviceQueueCreateInfo.queueCount = 1; 271 deviceQueueCreateInfo.pQueuePriorities = &queuePriority; 272 273 // Structure specifying parameters of a newly created device. 274 VkDeviceCreateInfo deviceCreateInfo = {}; 275 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; 276 deviceCreateInfo.pNext = nullptr; 277 deviceCreateInfo.flags = 0; 278 deviceCreateInfo.queueCreateInfoCount = 1; 279 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo; 280 deviceCreateInfo.enabledLayerCount = 0; 281 deviceCreateInfo.ppEnabledLayerNames = nullptr; 282 deviceCreateInfo.enabledExtensionCount = 0; 283 deviceCreateInfo.ppEnabledExtensionNames = nullptr; 284 deviceCreateInfo.pEnabledFeatures = nullptr; 285 286 RETURN_ON_VULKAN_ERROR( 287 vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device), 288 "vkCreateDevice"); 289 290 VkPhysicalDeviceMemoryProperties properties = {}; 291 vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties); 292 293 // Try to find memory type with following properties: 294 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated 295 // with this type can be mapped for host access using vkMapMemory; 296 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache 297 // management commands vkFlushMappedMemoryRanges and 298 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the 299 // device or make device writes visible to the host, respectively. 300 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 301 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT & 302 properties.memoryTypes[i].propertyFlags) && 303 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT & 304 properties.memoryTypes[i].propertyFlags) && 305 (memorySize <= 306 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 307 memoryTypeIndex = i; 308 break; 309 } 310 } 311 312 RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE 313 : VK_SUCCESS, 314 "invalid memoryTypeIndex"); 315 return success(); 316 } 317 318 LogicalResult VulkanRuntime::getBestComputeQueue() { 319 uint32_t queueFamilyPropertiesCount = 0; 320 vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, 321 &queueFamilyPropertiesCount, 0); 322 323 std::vector<VkQueueFamilyProperties> familyProperties( 324 queueFamilyPropertiesCount); 325 vkGetPhysicalDeviceQueueFamilyProperties( 326 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data()); 327 328 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support 329 // compute operations. Try to find a compute-only queue first if possible. 330 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 331 auto flags = familyProperties[i].queueFlags; 332 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) { 333 queueFamilyIndex = i; 334 queueFamilyProperties = familyProperties[i]; 335 return success(); 336 } 337 } 338 339 // Otherwise use a queue that can also support graphics. 340 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 341 auto flags = familyProperties[i].queueFlags; 342 if ((flags & VK_QUEUE_COMPUTE_BIT)) { 343 queueFamilyIndex = i; 344 queueFamilyProperties = familyProperties[i]; 345 return success(); 346 } 347 } 348 349 std::cerr << "cannot find valid queue"; 350 return failure(); 351 } 352 353 LogicalResult VulkanRuntime::createMemoryBuffers() { 354 // For each descriptor set. 355 for (const auto &resourceDataMapPair : resourceData) { 356 std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers; 357 const auto descriptorSetIndex = resourceDataMapPair.first; 358 const auto &resourceDataMap = resourceDataMapPair.second; 359 360 // For each descriptor binding. 361 for (const auto &resourceDataBindingPair : resourceDataMap) { 362 // Create device memory buffer. 363 VulkanDeviceMemoryBuffer memoryBuffer; 364 memoryBuffer.bindingIndex = resourceDataBindingPair.first; 365 VkDescriptorType descriptorType = {}; 366 VkBufferUsageFlagBits bufferUsage = {}; 367 368 // Check that descriptor set has storage class map. 369 const auto resourceStorageClassMapIt = 370 resourceStorageClassData.find(descriptorSetIndex); 371 if (resourceStorageClassMapIt == resourceStorageClassData.end()) { 372 std::cerr 373 << "cannot find storage class for resource in descriptor set: " 374 << descriptorSetIndex; 375 return failure(); 376 } 377 378 // Check that specific descriptor binding has storage class. 379 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second; 380 const auto resourceStorageClassIt = 381 resourceStorageClassMap.find(resourceDataBindingPair.first); 382 if (resourceStorageClassIt == resourceStorageClassMap.end()) { 383 std::cerr 384 << "cannot find storage class for resource with descriptor index: " 385 << resourceDataBindingPair.first; 386 return failure(); 387 } 388 389 const auto resourceStorageClassBinding = resourceStorageClassIt->second; 390 if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding, 391 descriptorType)) || 392 failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding, 393 bufferUsage))) { 394 std::cerr << "storage class for resource with descriptor binding: " 395 << resourceDataBindingPair.first 396 << " in the descriptor set: " << descriptorSetIndex 397 << " is not supported "; 398 return failure(); 399 } 400 401 // Set descriptor type for the specific device memory buffer. 402 memoryBuffer.descriptorType = descriptorType; 403 const auto bufferSize = resourceDataBindingPair.second.size; 404 405 // Specify memory allocation info. 406 VkMemoryAllocateInfo memoryAllocateInfo = {}; 407 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; 408 memoryAllocateInfo.pNext = nullptr; 409 memoryAllocateInfo.allocationSize = bufferSize; 410 memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex; 411 412 // Allocate device memory. 413 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0, 414 &memoryBuffer.deviceMemory), 415 "vkAllocateMemory"); 416 void *payload; 417 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0, 418 bufferSize, 0, 419 reinterpret_cast<void **>(&payload)), 420 "vkMapMemory"); 421 422 // Copy host memory into the mapped area. 423 std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize); 424 vkUnmapMemory(device, memoryBuffer.deviceMemory); 425 426 VkBufferCreateInfo bufferCreateInfo = {}; 427 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; 428 bufferCreateInfo.pNext = nullptr; 429 bufferCreateInfo.flags = 0; 430 bufferCreateInfo.size = bufferSize; 431 bufferCreateInfo.usage = bufferUsage; 432 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; 433 bufferCreateInfo.queueFamilyIndexCount = 1; 434 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex; 435 RETURN_ON_VULKAN_ERROR( 436 vkCreateBuffer(device, &bufferCreateInfo, 0, &memoryBuffer.buffer), 437 "vkCreateBuffer"); 438 439 // Bind buffer and device memory. 440 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer, 441 memoryBuffer.deviceMemory, 0), 442 "vkBindBufferMemory"); 443 444 // Update buffer info. 445 memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer; 446 memoryBuffer.bufferInfo.offset = 0; 447 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE; 448 deviceMemoryBuffers.push_back(memoryBuffer); 449 } 450 451 // Associate device memory buffers with a descriptor set. 452 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers; 453 } 454 return success(); 455 } 456 457 LogicalResult VulkanRuntime::createShaderModule() { 458 VkShaderModuleCreateInfo shaderModuleCreateInfo = {}; 459 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; 460 shaderModuleCreateInfo.pNext = nullptr; 461 shaderModuleCreateInfo.flags = 0; 462 // Set size in bytes. 463 shaderModuleCreateInfo.codeSize = binarySize; 464 // Set pointer to the binary shader. 465 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary); 466 RETURN_ON_VULKAN_ERROR( 467 vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule), 468 "vkCreateShaderModule"); 469 return success(); 470 } 471 472 void VulkanRuntime::initDescriptorSetLayoutBindingMap() { 473 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 474 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 475 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 476 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 477 478 // Create a layout binding for each descriptor. 479 for (const auto &memBuffer : deviceMemoryBuffers) { 480 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {}; 481 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex; 482 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType; 483 descriptorSetLayoutBinding.descriptorCount = 1; 484 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; 485 descriptorSetLayoutBinding.pImmutableSamplers = 0; 486 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding); 487 } 488 descriptorSetLayoutBindingMap[descriptorSetIndex] = 489 descriptorSetLayoutBindings; 490 } 491 } 492 493 LogicalResult VulkanRuntime::createDescriptorSetLayout() { 494 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 495 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 496 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 497 // Each descriptor in a descriptor set must be the same type. 498 VkDescriptorType descriptorType = 499 deviceMemoryBuffers.front().descriptorType; 500 const uint32_t descriptorSize = deviceMemoryBuffers.size(); 501 const auto descriptorSetLayoutBindingIt = 502 descriptorSetLayoutBindingMap.find(descriptorSetIndex); 503 504 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { 505 std::cerr << "cannot find layout bindings for the set with number: " 506 << descriptorSetIndex; 507 return failure(); 508 } 509 510 const auto &descriptorSetLayoutBindings = 511 descriptorSetLayoutBindingIt->second; 512 // Create descriptor set layout. 513 VkDescriptorSetLayout descriptorSetLayout = {}; 514 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; 515 516 descriptorSetLayoutCreateInfo.sType = 517 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; 518 descriptorSetLayoutCreateInfo.pNext = nullptr; 519 descriptorSetLayoutCreateInfo.flags = 0; 520 // Amount of descriptor bindings in a layout set. 521 descriptorSetLayoutCreateInfo.bindingCount = 522 descriptorSetLayoutBindings.size(); 523 descriptorSetLayoutCreateInfo.pBindings = 524 descriptorSetLayoutBindings.data(); 525 RETURN_ON_VULKAN_ERROR( 526 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0, 527 &descriptorSetLayout), 528 "vkCreateDescriptorSetLayout"); 529 530 descriptorSetLayouts.push_back(descriptorSetLayout); 531 descriptorSetInfoPool.push_back( 532 {descriptorSetIndex, descriptorSize, descriptorType}); 533 } 534 return success(); 535 } 536 537 LogicalResult VulkanRuntime::createPipelineLayout() { 538 // Associate descriptor sets with a pipeline layout. 539 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; 540 pipelineLayoutCreateInfo.sType = 541 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; 542 pipelineLayoutCreateInfo.pNext = nullptr; 543 pipelineLayoutCreateInfo.flags = 0; 544 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); 545 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); 546 pipelineLayoutCreateInfo.pushConstantRangeCount = 0; 547 pipelineLayoutCreateInfo.pPushConstantRanges = 0; 548 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, 549 &pipelineLayoutCreateInfo, 0, 550 &pipelineLayout), 551 "vkCreatePipelineLayout"); 552 return success(); 553 } 554 555 LogicalResult VulkanRuntime::createComputePipeline() { 556 VkPipelineShaderStageCreateInfo stageInfo = {}; 557 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; 558 stageInfo.pNext = nullptr; 559 stageInfo.flags = 0; 560 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; 561 stageInfo.module = shaderModule; 562 // Set entry point. 563 stageInfo.pName = entryPoint; 564 stageInfo.pSpecializationInfo = 0; 565 566 VkComputePipelineCreateInfo computePipelineCreateInfo = {}; 567 computePipelineCreateInfo.sType = 568 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; 569 computePipelineCreateInfo.pNext = nullptr; 570 computePipelineCreateInfo.flags = 0; 571 computePipelineCreateInfo.stage = stageInfo; 572 computePipelineCreateInfo.layout = pipelineLayout; 573 computePipelineCreateInfo.basePipelineHandle = 0; 574 computePipelineCreateInfo.basePipelineIndex = 0; 575 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1, 576 &computePipelineCreateInfo, 0, 577 &pipeline), 578 "vkCreateComputePipelines"); 579 return success(); 580 } 581 582 LogicalResult VulkanRuntime::createDescriptorPool() { 583 std::vector<VkDescriptorPoolSize> descriptorPoolSizes; 584 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 585 // For each descriptor set populate descriptor pool size. 586 VkDescriptorPoolSize descriptorPoolSize = {}; 587 descriptorPoolSize.type = descriptorSetInfo.descriptorType; 588 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; 589 descriptorPoolSizes.push_back(descriptorPoolSize); 590 } 591 592 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; 593 descriptorPoolCreateInfo.sType = 594 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; 595 descriptorPoolCreateInfo.pNext = nullptr; 596 descriptorPoolCreateInfo.flags = 0; 597 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); 598 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); 599 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); 600 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, 601 &descriptorPoolCreateInfo, 0, 602 &descriptorPool), 603 "vkCreateDescriptorPool"); 604 return success(); 605 } 606 607 LogicalResult VulkanRuntime::allocateDescriptorSets() { 608 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; 609 // Size of descriptor sets and descriptor layout sets is the same. 610 descriptorSets.resize(descriptorSetLayouts.size()); 611 descriptorSetAllocateInfo.sType = 612 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; 613 descriptorSetAllocateInfo.pNext = nullptr; 614 descriptorSetAllocateInfo.descriptorPool = descriptorPool; 615 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); 616 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); 617 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, 618 &descriptorSetAllocateInfo, 619 descriptorSets.data()), 620 "vkAllocateDescriptorSets"); 621 return success(); 622 } 623 624 LogicalResult VulkanRuntime::setWriteDescriptors() { 625 if (descriptorSets.size() != descriptorSetInfoPool.size()) { 626 std::cerr << "Each descriptor set must have descriptor set information"; 627 return failure(); 628 } 629 // For each descriptor set. 630 auto descriptorSetIt = descriptorSets.begin(); 631 // Each descriptor set is associated with descriptor set info. 632 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 633 // For each device memory buffer in the descriptor set. 634 const auto &deviceMemoryBuffers = 635 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; 636 for (const auto &memoryBuffer : deviceMemoryBuffers) { 637 // Structure describing descriptor sets to write to. 638 VkWriteDescriptorSet wSet = {}; 639 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; 640 wSet.pNext = nullptr; 641 // Descriptor set. 642 wSet.dstSet = *descriptorSetIt; 643 wSet.dstBinding = memoryBuffer.bindingIndex; 644 wSet.dstArrayElement = 0; 645 wSet.descriptorCount = 1; 646 wSet.descriptorType = memoryBuffer.descriptorType; 647 wSet.pImageInfo = nullptr; 648 wSet.pBufferInfo = &memoryBuffer.bufferInfo; 649 wSet.pTexelBufferView = nullptr; 650 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); 651 } 652 // Increment descriptor set iterator. 653 ++descriptorSetIt; 654 } 655 return success(); 656 } 657 658 LogicalResult VulkanRuntime::createCommandPool() { 659 VkCommandPoolCreateInfo commandPoolCreateInfo = {}; 660 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; 661 commandPoolCreateInfo.pNext = nullptr; 662 commandPoolCreateInfo.flags = 0; 663 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; 664 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo, 665 /*pAllocator=*/nullptr, 666 &commandPool), 667 "vkCreateCommandPool"); 668 return success(); 669 } 670 671 LogicalResult VulkanRuntime::createQueryPool() { 672 // Return directly if timestamp query is not supported. 673 if (queueFamilyProperties.timestampValidBits == 0) 674 return success(); 675 676 // Get timestamp period for this physical device. 677 VkPhysicalDeviceProperties deviceProperties = {}; 678 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties); 679 timestampPeriod = deviceProperties.limits.timestampPeriod; 680 681 // Create query pool. 682 VkQueryPoolCreateInfo queryPoolCreateInfo = {}; 683 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; 684 queryPoolCreateInfo.pNext = nullptr; 685 queryPoolCreateInfo.flags = 0; 686 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; 687 queryPoolCreateInfo.queryCount = 2; 688 queryPoolCreateInfo.pipelineStatistics = 0; 689 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo, 690 /*pAllocator=*/nullptr, &queryPool), 691 "vkCreateQueryPool"); 692 693 return success(); 694 } 695 696 LogicalResult VulkanRuntime::createComputeCommandBuffer() { 697 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; 698 commandBufferAllocateInfo.sType = 699 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; 700 commandBufferAllocateInfo.pNext = nullptr; 701 commandBufferAllocateInfo.commandPool = commandPool; 702 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; 703 commandBufferAllocateInfo.commandBufferCount = 1; 704 705 VkCommandBuffer commandBuffer; 706 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 707 &commandBufferAllocateInfo, 708 &commandBuffer), 709 "vkAllocateCommandBuffers"); 710 711 VkCommandBufferBeginInfo commandBufferBeginInfo = {}; 712 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; 713 commandBufferBeginInfo.pNext = nullptr; 714 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; 715 commandBufferBeginInfo.pInheritanceInfo = nullptr; 716 717 // Commands begin. 718 RETURN_ON_VULKAN_ERROR( 719 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 720 "vkBeginCommandBuffer"); 721 722 if (queryPool != VK_NULL_HANDLE) 723 vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2); 724 725 vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); 726 vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, 727 pipelineLayout, 0, descriptorSets.size(), 728 descriptorSets.data(), 0, 0); 729 // Get a timestamp before invoking the compute shader. 730 if (queryPool != VK_NULL_HANDLE) 731 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, 732 queryPool, 0); 733 vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, 734 numWorkGroups.z); 735 // Get another timestamp after invoking the compute shader. 736 if (queryPool != VK_NULL_HANDLE) 737 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, 738 queryPool, 1); 739 740 // Commands end. 741 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 742 "vkEndCommandBuffer"); 743 744 commandBuffers.push_back(commandBuffer); 745 return success(); 746 } 747 748 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { 749 VkSubmitInfo submitInfo = {}; 750 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; 751 submitInfo.pNext = nullptr; 752 submitInfo.waitSemaphoreCount = 0; 753 submitInfo.pWaitSemaphores = 0; 754 submitInfo.pWaitDstStageMask = 0; 755 submitInfo.commandBufferCount = commandBuffers.size(); 756 submitInfo.pCommandBuffers = commandBuffers.data(); 757 submitInfo.signalSemaphoreCount = 0; 758 submitInfo.pSignalSemaphores = nullptr; 759 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0), 760 "vkQueueSubmit"); 761 return success(); 762 } 763 764 LogicalResult VulkanRuntime::updateHostMemoryBuffers() { 765 // For each descriptor set. 766 for (auto &resourceDataMapPair : resourceData) { 767 auto &resourceDataMap = resourceDataMapPair.second; 768 auto &deviceMemoryBuffers = 769 deviceMemoryBufferMap[resourceDataMapPair.first]; 770 // For each device memory buffer in the set. 771 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { 772 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { 773 void *payload; 774 auto &hostMemoryBuffer = 775 resourceDataMap[deviceMemoryBuffer.bindingIndex]; 776 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, 777 deviceMemoryBuffer.deviceMemory, 0, 778 hostMemoryBuffer.size, 0, 779 reinterpret_cast<void **>(&payload)), 780 "vkMapMemory"); 781 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); 782 vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory); 783 } 784 } 785 } 786 return success(); 787 } 788