1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a library for running a module on a Vulkan device. 10 // Implements a Vulkan runtime. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VulkanRuntime.h" 15 16 #include "llvm/Support/Format.h" 17 #include <chrono> 18 19 using namespace mlir; 20 21 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) { 22 numWorkGroups = numberWorkGroups; 23 } 24 25 void VulkanRuntime::setResourceStorageClassBindingMap( 26 const ResourceStorageClassBindingMap &stClassData) { 27 resourceStorageClassData = stClassData; 28 } 29 30 void VulkanRuntime::setResourceData( 31 const DescriptorSetIndex desIndex, const BindingIndex bindIndex, 32 const VulkanHostMemoryBuffer &hostMemBuffer) { 33 resourceData[desIndex][bindIndex] = hostMemBuffer; 34 resourceStorageClassData[desIndex][bindIndex] = 35 spirv::StorageClass::StorageBuffer; 36 } 37 38 void VulkanRuntime::setEntryPoint(const char *entryPointName) { 39 entryPoint = entryPointName; 40 } 41 42 void VulkanRuntime::setResourceData(const ResourceData &resData) { 43 resourceData = resData; 44 } 45 46 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) { 47 binary = shader; 48 binarySize = size; 49 } 50 51 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType( 52 spirv::StorageClass storageClass, VkDescriptorType &descriptorType) { 53 switch (storageClass) { 54 case spirv::StorageClass::StorageBuffer: 55 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; 56 break; 57 case spirv::StorageClass::Uniform: 58 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; 59 break; 60 default: 61 llvm::errs() << "unsupported storage class"; 62 return failure(); 63 } 64 return success(); 65 } 66 67 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag( 68 spirv::StorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) { 69 switch (storageClass) { 70 case spirv::StorageClass::StorageBuffer: 71 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; 72 break; 73 case spirv::StorageClass::Uniform: 74 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; 75 break; 76 default: 77 llvm::errs() << "unsupported storage class"; 78 return failure(); 79 } 80 return success(); 81 } 82 83 LogicalResult VulkanRuntime::countDeviceMemorySize() { 84 for (const auto &resourceDataMapPair : resourceData) { 85 const auto &resourceDataMap = resourceDataMapPair.second; 86 for (const auto &resourceDataBindingPair : resourceDataMap) { 87 if (resourceDataBindingPair.second.size) { 88 memorySize += resourceDataBindingPair.second.size; 89 } else { 90 llvm::errs() 91 << "expected buffer size greater than zero for resource data"; 92 return failure(); 93 } 94 } 95 } 96 return success(); 97 } 98 99 LogicalResult VulkanRuntime::initRuntime() { 100 if (!resourceData.size()) { 101 llvm::errs() << "Vulkan runtime needs at least one resource"; 102 return failure(); 103 } 104 if (!binarySize || !binary) { 105 llvm::errs() << "binary shader size must be greater than zero"; 106 return failure(); 107 } 108 if (failed(countDeviceMemorySize())) { 109 return failure(); 110 } 111 return success(); 112 } 113 114 LogicalResult VulkanRuntime::destroy() { 115 // According to Vulkan spec: 116 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be 117 // used to gate the destruction of the device. Prior to destroying a device, 118 // an application is responsible for destroying/freeing any Vulkan objects 119 // that were created using that device as the first parameter of the 120 // corresponding vkCreate* or vkAllocate* command." 121 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle"); 122 123 // Free and destroy. 124 vkFreeCommandBuffers(device, commandPool, commandBuffers.size(), 125 commandBuffers.data()); 126 vkDestroyQueryPool(device, queryPool, nullptr); 127 vkDestroyCommandPool(device, commandPool, nullptr); 128 vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(), 129 descriptorSets.data()); 130 vkDestroyDescriptorPool(device, descriptorPool, nullptr); 131 vkDestroyPipeline(device, pipeline, nullptr); 132 vkDestroyPipelineLayout(device, pipelineLayout, nullptr); 133 for (auto &descriptorSetLayout: descriptorSetLayouts) { 134 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); 135 } 136 vkDestroyShaderModule(device, shaderModule, nullptr); 137 138 // For each descriptor set. 139 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 140 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 141 // For each descriptor binding. 142 for (auto &memoryBuffer : deviceMemoryBuffers) { 143 vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr); 144 vkDestroyBuffer(device, memoryBuffer.buffer, nullptr); 145 } 146 } 147 148 vkDestroyDevice(device, nullptr); 149 vkDestroyInstance(instance, nullptr); 150 return success(); 151 } 152 153 LogicalResult VulkanRuntime::run() { 154 // Create logical device, shader module and memory buffers. 155 if (failed(createInstance()) || failed(createDevice()) || 156 failed(createMemoryBuffers()) || failed(createShaderModule())) { 157 return failure(); 158 } 159 160 // Descriptor bindings divided into sets. Each descriptor binding 161 // must have a layout binding attached into a descriptor set layout. 162 // Each layout set must be binded into a pipeline layout. 163 initDescriptorSetLayoutBindingMap(); 164 if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) || 165 // Each descriptor set must be allocated from a descriptor pool. 166 failed(createComputePipeline()) || failed(createDescriptorPool()) || 167 failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) || 168 // Create command buffer. 169 failed(createCommandPool()) || failed(createQueryPool()) || 170 failed(createComputeCommandBuffer())) { 171 return failure(); 172 } 173 174 // Get working queue. 175 vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue); 176 177 auto submitStart = std::chrono::high_resolution_clock::now(); 178 // Submit command buffer into the queue. 179 if (failed(submitCommandBuffersToQueue())) 180 return failure(); 181 auto submitEnd = std::chrono::high_resolution_clock::now(); 182 183 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 184 auto execEnd = std::chrono::high_resolution_clock::now(); 185 186 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>( 187 submitEnd - submitStart); 188 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>( 189 execEnd - submitEnd); 190 191 if (queryPool != VK_NULL_HANDLE) { 192 uint64_t timestamps[2]; 193 RETURN_ON_VULKAN_ERROR( 194 vkGetQueryPoolResults( 195 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2, 196 /*dataSize=*/sizeof(timestamps), 197 /*pData=*/reinterpret_cast<void *>(timestamps), 198 /*stride=*/sizeof(uint64_t), 199 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT), 200 "vkGetQueryPoolResults"); 201 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000; 202 llvm::outs() << "Compute shader execution time: " 203 << llvm::format("%0.3fus\n", microsec); 204 } 205 206 llvm::outs() << "Command buffer submit time: " << submitDuration.count() 207 << "us\nWait idle time: " << execDuration.count() << "us\n"; 208 209 return success(); 210 } 211 212 LogicalResult VulkanRuntime::createInstance() { 213 VkApplicationInfo applicationInfo = {}; 214 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; 215 applicationInfo.pNext = nullptr; 216 applicationInfo.pApplicationName = "MLIR Vulkan runtime"; 217 applicationInfo.applicationVersion = 0; 218 applicationInfo.pEngineName = "mlir"; 219 applicationInfo.engineVersion = 0; 220 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0); 221 222 VkInstanceCreateInfo instanceCreateInfo = {}; 223 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; 224 instanceCreateInfo.pNext = nullptr; 225 instanceCreateInfo.flags = 0; 226 instanceCreateInfo.pApplicationInfo = &applicationInfo; 227 instanceCreateInfo.enabledLayerCount = 0; 228 instanceCreateInfo.ppEnabledLayerNames = 0; 229 instanceCreateInfo.enabledExtensionCount = 0; 230 instanceCreateInfo.ppEnabledExtensionNames = 0; 231 232 RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance), 233 "vkCreateInstance"); 234 return success(); 235 } 236 237 LogicalResult VulkanRuntime::createDevice() { 238 uint32_t physicalDeviceCount = 0; 239 RETURN_ON_VULKAN_ERROR( 240 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0), 241 "vkEnumeratePhysicalDevices"); 242 243 llvm::SmallVector<VkPhysicalDevice, 1> physicalDevices(physicalDeviceCount); 244 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance, 245 &physicalDeviceCount, 246 physicalDevices.data()), 247 "vkEnumeratePhysicalDevices"); 248 249 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE, 250 "physicalDeviceCount"); 251 252 // TODO(denis0x0D): find the best device. 253 physicalDevice = physicalDevices.front(); 254 if (failed(getBestComputeQueue())) 255 return failure(); 256 257 const float queuePriority = 1.0f; 258 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {}; 259 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; 260 deviceQueueCreateInfo.pNext = nullptr; 261 deviceQueueCreateInfo.flags = 0; 262 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex; 263 deviceQueueCreateInfo.queueCount = 1; 264 deviceQueueCreateInfo.pQueuePriorities = &queuePriority; 265 266 // Structure specifying parameters of a newly created device. 267 VkDeviceCreateInfo deviceCreateInfo = {}; 268 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; 269 deviceCreateInfo.pNext = nullptr; 270 deviceCreateInfo.flags = 0; 271 deviceCreateInfo.queueCreateInfoCount = 1; 272 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo; 273 deviceCreateInfo.enabledLayerCount = 0; 274 deviceCreateInfo.ppEnabledLayerNames = nullptr; 275 deviceCreateInfo.enabledExtensionCount = 0; 276 deviceCreateInfo.ppEnabledExtensionNames = nullptr; 277 deviceCreateInfo.pEnabledFeatures = nullptr; 278 279 RETURN_ON_VULKAN_ERROR( 280 vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device), 281 "vkCreateDevice"); 282 283 VkPhysicalDeviceMemoryProperties properties = {}; 284 vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties); 285 286 // Try to find memory type with following properties: 287 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated 288 // with this type can be mapped for host access using vkMapMemory; 289 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache 290 // management commands vkFlushMappedMemoryRanges and 291 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the 292 // device or make device writes visible to the host, respectively. 293 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 294 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT & 295 properties.memoryTypes[i].propertyFlags) && 296 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT & 297 properties.memoryTypes[i].propertyFlags) && 298 (memorySize <= 299 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 300 memoryTypeIndex = i; 301 break; 302 } 303 } 304 305 RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE 306 : VK_SUCCESS, 307 "invalid memoryTypeIndex"); 308 return success(); 309 } 310 311 LogicalResult VulkanRuntime::getBestComputeQueue() { 312 uint32_t queueFamilyPropertiesCount = 0; 313 vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, 314 &queueFamilyPropertiesCount, 0); 315 316 SmallVector<VkQueueFamilyProperties, 1> familyProperties( 317 queueFamilyPropertiesCount); 318 vkGetPhysicalDeviceQueueFamilyProperties( 319 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data()); 320 321 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support 322 // compute operations. Try to find a compute-only queue first if possible. 323 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 324 auto flags = familyProperties[i].queueFlags; 325 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) { 326 queueFamilyIndex = i; 327 queueFamilyProperties = familyProperties[i]; 328 return success(); 329 } 330 } 331 332 // Otherwise use a queue that can also support graphics. 333 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 334 auto flags = familyProperties[i].queueFlags; 335 if ((flags & VK_QUEUE_COMPUTE_BIT)) { 336 queueFamilyIndex = i; 337 queueFamilyProperties = familyProperties[i]; 338 return success(); 339 } 340 } 341 342 llvm::errs() << "cannot find valid queue"; 343 return failure(); 344 } 345 346 LogicalResult VulkanRuntime::createMemoryBuffers() { 347 // For each descriptor set. 348 for (const auto &resourceDataMapPair : resourceData) { 349 llvm::SmallVector<VulkanDeviceMemoryBuffer, 1> deviceMemoryBuffers; 350 const auto descriptorSetIndex = resourceDataMapPair.first; 351 const auto &resourceDataMap = resourceDataMapPair.second; 352 353 // For each descriptor binding. 354 for (const auto &resourceDataBindingPair : resourceDataMap) { 355 // Create device memory buffer. 356 VulkanDeviceMemoryBuffer memoryBuffer; 357 memoryBuffer.bindingIndex = resourceDataBindingPair.first; 358 VkDescriptorType descriptorType = {}; 359 VkBufferUsageFlagBits bufferUsage = {}; 360 361 // Check that descriptor set has storage class map. 362 const auto resourceStorageClassMapIt = 363 resourceStorageClassData.find(descriptorSetIndex); 364 if (resourceStorageClassMapIt == resourceStorageClassData.end()) { 365 llvm::errs() 366 << "cannot find storage class for resource in descriptor set: " 367 << descriptorSetIndex; 368 return failure(); 369 } 370 371 // Check that specific descriptor binding has storage class. 372 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second; 373 const auto resourceStorageClassIt = 374 resourceStorageClassMap.find(resourceDataBindingPair.first); 375 if (resourceStorageClassIt == resourceStorageClassMap.end()) { 376 llvm::errs() 377 << "cannot find storage class for resource with descriptor index: " 378 << resourceDataBindingPair.first; 379 return failure(); 380 } 381 382 const auto resourceStorageClassBinding = resourceStorageClassIt->second; 383 if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding, 384 descriptorType)) || 385 failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding, 386 bufferUsage))) { 387 llvm::errs() << "storage class for resource with descriptor binding: " 388 << resourceDataBindingPair.first 389 << " in the descriptor set: " << descriptorSetIndex 390 << " is not supported "; 391 return failure(); 392 } 393 394 // Set descriptor type for the specific device memory buffer. 395 memoryBuffer.descriptorType = descriptorType; 396 const auto bufferSize = resourceDataBindingPair.second.size; 397 398 // Specify memory allocation info. 399 VkMemoryAllocateInfo memoryAllocateInfo = {}; 400 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; 401 memoryAllocateInfo.pNext = nullptr; 402 memoryAllocateInfo.allocationSize = bufferSize; 403 memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex; 404 405 // Allocate device memory. 406 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0, 407 &memoryBuffer.deviceMemory), 408 "vkAllocateMemory"); 409 void *payload; 410 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0, 411 bufferSize, 0, 412 reinterpret_cast<void **>(&payload)), 413 "vkMapMemory"); 414 415 // Copy host memory into the mapped area. 416 std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize); 417 vkUnmapMemory(device, memoryBuffer.deviceMemory); 418 419 VkBufferCreateInfo bufferCreateInfo = {}; 420 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; 421 bufferCreateInfo.pNext = nullptr; 422 bufferCreateInfo.flags = 0; 423 bufferCreateInfo.size = bufferSize; 424 bufferCreateInfo.usage = bufferUsage; 425 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; 426 bufferCreateInfo.queueFamilyIndexCount = 1; 427 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex; 428 RETURN_ON_VULKAN_ERROR( 429 vkCreateBuffer(device, &bufferCreateInfo, 0, &memoryBuffer.buffer), 430 "vkCreateBuffer"); 431 432 // Bind buffer and device memory. 433 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer, 434 memoryBuffer.deviceMemory, 0), 435 "vkBindBufferMemory"); 436 437 // Update buffer info. 438 memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer; 439 memoryBuffer.bufferInfo.offset = 0; 440 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE; 441 deviceMemoryBuffers.push_back(memoryBuffer); 442 } 443 444 // Associate device memory buffers with a descriptor set. 445 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers; 446 } 447 return success(); 448 } 449 450 LogicalResult VulkanRuntime::createShaderModule() { 451 VkShaderModuleCreateInfo shaderModuleCreateInfo = {}; 452 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; 453 shaderModuleCreateInfo.pNext = nullptr; 454 shaderModuleCreateInfo.flags = 0; 455 // Set size in bytes. 456 shaderModuleCreateInfo.codeSize = binarySize; 457 // Set pointer to the binary shader. 458 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary); 459 RETURN_ON_VULKAN_ERROR( 460 vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule), 461 "vkCreateShaderModule"); 462 return success(); 463 } 464 465 void VulkanRuntime::initDescriptorSetLayoutBindingMap() { 466 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 467 SmallVector<VkDescriptorSetLayoutBinding, 1> descriptorSetLayoutBindings; 468 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 469 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 470 471 // Create a layout binding for each descriptor. 472 for (const auto &memBuffer : deviceMemoryBuffers) { 473 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {}; 474 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex; 475 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType; 476 descriptorSetLayoutBinding.descriptorCount = 1; 477 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; 478 descriptorSetLayoutBinding.pImmutableSamplers = 0; 479 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding); 480 } 481 descriptorSetLayoutBindingMap[descriptorSetIndex] = 482 descriptorSetLayoutBindings; 483 } 484 } 485 486 LogicalResult VulkanRuntime::createDescriptorSetLayout() { 487 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 488 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 489 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 490 // Each descriptor in a descriptor set must be the same type. 491 VkDescriptorType descriptorType = 492 deviceMemoryBuffers.front().descriptorType; 493 const uint32_t descriptorSize = deviceMemoryBuffers.size(); 494 const auto descriptorSetLayoutBindingIt = 495 descriptorSetLayoutBindingMap.find(descriptorSetIndex); 496 497 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { 498 llvm::errs() << "cannot find layout bindings for the set with number: " 499 << descriptorSetIndex; 500 return failure(); 501 } 502 503 const auto &descriptorSetLayoutBindings = 504 descriptorSetLayoutBindingIt->second; 505 // Create descriptor set layout. 506 VkDescriptorSetLayout descriptorSetLayout = {}; 507 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; 508 509 descriptorSetLayoutCreateInfo.sType = 510 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; 511 descriptorSetLayoutCreateInfo.pNext = nullptr; 512 descriptorSetLayoutCreateInfo.flags = 0; 513 // Amount of descriptor bindings in a layout set. 514 descriptorSetLayoutCreateInfo.bindingCount = 515 descriptorSetLayoutBindings.size(); 516 descriptorSetLayoutCreateInfo.pBindings = 517 descriptorSetLayoutBindings.data(); 518 RETURN_ON_VULKAN_ERROR( 519 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0, 520 &descriptorSetLayout), 521 "vkCreateDescriptorSetLayout"); 522 523 descriptorSetLayouts.push_back(descriptorSetLayout); 524 descriptorSetInfoPool.push_back( 525 {descriptorSetIndex, descriptorSize, descriptorType}); 526 } 527 return success(); 528 } 529 530 LogicalResult VulkanRuntime::createPipelineLayout() { 531 // Associate descriptor sets with a pipeline layout. 532 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; 533 pipelineLayoutCreateInfo.sType = 534 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; 535 pipelineLayoutCreateInfo.pNext = nullptr; 536 pipelineLayoutCreateInfo.flags = 0; 537 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); 538 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); 539 pipelineLayoutCreateInfo.pushConstantRangeCount = 0; 540 pipelineLayoutCreateInfo.pPushConstantRanges = 0; 541 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, 542 &pipelineLayoutCreateInfo, 0, 543 &pipelineLayout), 544 "vkCreatePipelineLayout"); 545 return success(); 546 } 547 548 LogicalResult VulkanRuntime::createComputePipeline() { 549 VkPipelineShaderStageCreateInfo stageInfo = {}; 550 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; 551 stageInfo.pNext = nullptr; 552 stageInfo.flags = 0; 553 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; 554 stageInfo.module = shaderModule; 555 // Set entry point. 556 stageInfo.pName = entryPoint; 557 stageInfo.pSpecializationInfo = 0; 558 559 VkComputePipelineCreateInfo computePipelineCreateInfo = {}; 560 computePipelineCreateInfo.sType = 561 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; 562 computePipelineCreateInfo.pNext = nullptr; 563 computePipelineCreateInfo.flags = 0; 564 computePipelineCreateInfo.stage = stageInfo; 565 computePipelineCreateInfo.layout = pipelineLayout; 566 computePipelineCreateInfo.basePipelineHandle = 0; 567 computePipelineCreateInfo.basePipelineIndex = 0; 568 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1, 569 &computePipelineCreateInfo, 0, 570 &pipeline), 571 "vkCreateComputePipelines"); 572 return success(); 573 } 574 575 LogicalResult VulkanRuntime::createDescriptorPool() { 576 llvm::SmallVector<VkDescriptorPoolSize, 1> descriptorPoolSizes; 577 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 578 // For each descriptor set populate descriptor pool size. 579 VkDescriptorPoolSize descriptorPoolSize = {}; 580 descriptorPoolSize.type = descriptorSetInfo.descriptorType; 581 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; 582 descriptorPoolSizes.push_back(descriptorPoolSize); 583 } 584 585 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; 586 descriptorPoolCreateInfo.sType = 587 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; 588 descriptorPoolCreateInfo.pNext = nullptr; 589 descriptorPoolCreateInfo.flags = 0; 590 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); 591 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); 592 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); 593 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, 594 &descriptorPoolCreateInfo, 0, 595 &descriptorPool), 596 "vkCreateDescriptorPool"); 597 return success(); 598 } 599 600 LogicalResult VulkanRuntime::allocateDescriptorSets() { 601 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; 602 // Size of descriptor sets and descriptor layout sets is the same. 603 descriptorSets.resize(descriptorSetLayouts.size()); 604 descriptorSetAllocateInfo.sType = 605 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; 606 descriptorSetAllocateInfo.pNext = nullptr; 607 descriptorSetAllocateInfo.descriptorPool = descriptorPool; 608 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); 609 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); 610 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, 611 &descriptorSetAllocateInfo, 612 descriptorSets.data()), 613 "vkAllocateDescriptorSets"); 614 return success(); 615 } 616 617 LogicalResult VulkanRuntime::setWriteDescriptors() { 618 if (descriptorSets.size() != descriptorSetInfoPool.size()) { 619 llvm::errs() << "Each descriptor set must have descriptor set information"; 620 return failure(); 621 } 622 // For each descriptor set. 623 auto descriptorSetIt = descriptorSets.begin(); 624 // Each descriptor set is associated with descriptor set info. 625 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 626 // For each device memory buffer in the descriptor set. 627 const auto &deviceMemoryBuffers = 628 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; 629 for (const auto &memoryBuffer : deviceMemoryBuffers) { 630 // Structure describing descriptor sets to write to. 631 VkWriteDescriptorSet wSet = {}; 632 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; 633 wSet.pNext = nullptr; 634 // Descriptor set. 635 wSet.dstSet = *descriptorSetIt; 636 wSet.dstBinding = memoryBuffer.bindingIndex; 637 wSet.dstArrayElement = 0; 638 wSet.descriptorCount = 1; 639 wSet.descriptorType = memoryBuffer.descriptorType; 640 wSet.pImageInfo = nullptr; 641 wSet.pBufferInfo = &memoryBuffer.bufferInfo; 642 wSet.pTexelBufferView = nullptr; 643 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); 644 } 645 // Increment descriptor set iterator. 646 ++descriptorSetIt; 647 } 648 return success(); 649 } 650 651 LogicalResult VulkanRuntime::createCommandPool() { 652 VkCommandPoolCreateInfo commandPoolCreateInfo = {}; 653 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; 654 commandPoolCreateInfo.pNext = nullptr; 655 commandPoolCreateInfo.flags = 0; 656 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; 657 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo, 658 /*pAllocator=*/nullptr, 659 &commandPool), 660 "vkCreateCommandPool"); 661 return success(); 662 } 663 664 LogicalResult VulkanRuntime::createQueryPool() { 665 // Return directly if timestamp query is not supported. 666 if (queueFamilyProperties.timestampValidBits == 0) 667 return success(); 668 669 // Get timestamp period for this physical device. 670 VkPhysicalDeviceProperties deviceProperties = {}; 671 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties); 672 timestampPeriod = deviceProperties.limits.timestampPeriod; 673 674 // Create query pool. 675 VkQueryPoolCreateInfo queryPoolCreateInfo = {}; 676 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; 677 queryPoolCreateInfo.pNext = nullptr; 678 queryPoolCreateInfo.flags = 0; 679 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; 680 queryPoolCreateInfo.queryCount = 2; 681 queryPoolCreateInfo.pipelineStatistics = 0; 682 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo, 683 /*pAllocator=*/nullptr, &queryPool), 684 "vkCreateQueryPool"); 685 686 return success(); 687 } 688 689 LogicalResult VulkanRuntime::createComputeCommandBuffer() { 690 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; 691 commandBufferAllocateInfo.sType = 692 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; 693 commandBufferAllocateInfo.pNext = nullptr; 694 commandBufferAllocateInfo.commandPool = commandPool; 695 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; 696 commandBufferAllocateInfo.commandBufferCount = 1; 697 698 VkCommandBuffer commandBuffer; 699 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 700 &commandBufferAllocateInfo, 701 &commandBuffer), 702 "vkAllocateCommandBuffers"); 703 704 VkCommandBufferBeginInfo commandBufferBeginInfo = {}; 705 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; 706 commandBufferBeginInfo.pNext = nullptr; 707 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; 708 commandBufferBeginInfo.pInheritanceInfo = nullptr; 709 710 // Commands begin. 711 RETURN_ON_VULKAN_ERROR( 712 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 713 "vkBeginCommandBuffer"); 714 715 if (queryPool != VK_NULL_HANDLE) 716 vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2); 717 718 vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); 719 vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, 720 pipelineLayout, 0, descriptorSets.size(), 721 descriptorSets.data(), 0, 0); 722 // Get a timestamp before invoking the compute shader. 723 if (queryPool != VK_NULL_HANDLE) 724 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 725 queryPool, 0); 726 vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, 727 numWorkGroups.z); 728 // Get another timestamp after invoking the compute shader. 729 if (queryPool != VK_NULL_HANDLE) 730 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, 731 queryPool, 1); 732 733 // Commands end. 734 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 735 "vkEndCommandBuffer"); 736 737 commandBuffers.push_back(commandBuffer); 738 return success(); 739 } 740 741 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { 742 VkSubmitInfo submitInfo = {}; 743 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; 744 submitInfo.pNext = nullptr; 745 submitInfo.waitSemaphoreCount = 0; 746 submitInfo.pWaitSemaphores = 0; 747 submitInfo.pWaitDstStageMask = 0; 748 submitInfo.commandBufferCount = commandBuffers.size(); 749 submitInfo.pCommandBuffers = commandBuffers.data(); 750 submitInfo.signalSemaphoreCount = 0; 751 submitInfo.pSignalSemaphores = nullptr; 752 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0), 753 "vkQueueSubmit"); 754 return success(); 755 } 756 757 LogicalResult VulkanRuntime::updateHostMemoryBuffers() { 758 // For each descriptor set. 759 for (auto &resourceDataMapPair : resourceData) { 760 auto &resourceDataMap = resourceDataMapPair.second; 761 auto &deviceMemoryBuffers = 762 deviceMemoryBufferMap[resourceDataMapPair.first]; 763 // For each device memory buffer in the set. 764 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { 765 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { 766 void *payload; 767 auto &hostMemoryBuffer = 768 resourceDataMap[deviceMemoryBuffer.bindingIndex]; 769 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, 770 deviceMemoryBuffer.deviceMemory, 0, 771 hostMemoryBuffer.size, 0, 772 reinterpret_cast<void **>(&payload)), 773 "vkMapMemory"); 774 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); 775 vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory); 776 } 777 } 778 } 779 return success(); 780 } 781