1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a library for running a module on a Vulkan device.
10 // Implements a Vulkan runtime.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "VulkanRuntime.h"
15 
16 #include <chrono>
17 #include <cstring>
18 // TODO(antiagainst): It's generally bad to access stdout/stderr in a library.
19 // Figure out a better way for error reporting.
20 #include <iomanip>
21 #include <iostream>
22 
23 inline void emitVulkanError(const char *api, VkResult error) {
24   std::cerr << " failed with error code " << error << " when executing " << api;
25 }
26 
// Evaluates `result` (typically a Vulkan call yielding a VkResult) exactly
// once. If the value is not VK_SUCCESS, the failure is reported for `api`
// through emitVulkanError and the enclosing function returns failure().
//
// The value is cached in a local so that a failing call is not issued a second
// time just to print its error code, and the body is wrapped in
// do { ... } while (false) so the macro behaves like a single statement (safe
// inside an unbraced if/else). Every use must be followed by a semicolon.
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  do {                                                                         \
    const VkResult vulkanCallResult_ = (result);                               \
    if (vulkanCallResult_ != VK_SUCCESS) {                                     \
      emitVulkanError(api, vulkanCallResult_);                                 \
      return failure();                                                        \
    }                                                                          \
  } while (false)
32 
33 using namespace mlir;
34 
35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
36   numWorkGroups = numberWorkGroups;
37 }
38 
39 void VulkanRuntime::setResourceStorageClassBindingMap(
40     const ResourceStorageClassBindingMap &stClassData) {
41   resourceStorageClassData = stClassData;
42 }
43 
44 void VulkanRuntime::setResourceData(
45     const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
46     const VulkanHostMemoryBuffer &hostMemBuffer) {
47   resourceData[desIndex][bindIndex] = hostMemBuffer;
48   resourceStorageClassData[desIndex][bindIndex] =
49       SPIRVStorageClass::StorageBuffer;
50 }
51 
52 void VulkanRuntime::setEntryPoint(const char *entryPointName) {
53   entryPoint = entryPointName;
54 }
55 
56 void VulkanRuntime::setResourceData(const ResourceData &resData) {
57   resourceData = resData;
58 }
59 
60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
61   binary = shader;
62   binarySize = size;
63 }
64 
65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
66     SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
67   switch (storageClass) {
68   case SPIRVStorageClass::StorageBuffer:
69     descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
70     break;
71   case SPIRVStorageClass::Uniform:
72     descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
73     break;
74   }
75   return success();
76 }
77 
78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
79     SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
80   switch (storageClass) {
81   case SPIRVStorageClass::StorageBuffer:
82     bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
83     break;
84   case SPIRVStorageClass::Uniform:
85     bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
86     break;
87   }
88   return success();
89 }
90 
91 LogicalResult VulkanRuntime::countDeviceMemorySize() {
92   for (const auto &resourceDataMapPair : resourceData) {
93     const auto &resourceDataMap = resourceDataMapPair.second;
94     for (const auto &resourceDataBindingPair : resourceDataMap) {
95       if (resourceDataBindingPair.second.size) {
96         memorySize += resourceDataBindingPair.second.size;
97       } else {
98         std::cerr << "expected buffer size greater than zero for resource data";
99         return failure();
100       }
101     }
102   }
103   return success();
104 }
105 
106 LogicalResult VulkanRuntime::initRuntime() {
107   if (!resourceData.size()) {
108     std::cerr << "Vulkan runtime needs at least one resource";
109     return failure();
110   }
111   if (!binarySize || !binary) {
112     std::cerr << "binary shader size must be greater than zero";
113     return failure();
114   }
115   if (failed(countDeviceMemorySize())) {
116     return failure();
117   }
118   return success();
119 }
120 
121 LogicalResult VulkanRuntime::destroy() {
122   // According to Vulkan spec:
123   // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
124   // used to gate the destruction of the device. Prior to destroying a device,
125   // an application is responsible for destroying/freeing any Vulkan objects
126   // that were created using that device as the first parameter of the
127   // corresponding vkCreate* or vkAllocate* command."
128   RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
129 
130   // Free and destroy.
131   vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
132                        commandBuffers.data());
133   vkDestroyQueryPool(device, queryPool, nullptr);
134   vkDestroyCommandPool(device, commandPool, nullptr);
135   vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
136                        descriptorSets.data());
137   vkDestroyDescriptorPool(device, descriptorPool, nullptr);
138   vkDestroyPipeline(device, pipeline, nullptr);
139   vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
140   for (auto &descriptorSetLayout : descriptorSetLayouts) {
141     vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
142   }
143   vkDestroyShaderModule(device, shaderModule, nullptr);
144 
145   // For each descriptor set.
146   for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
147     auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
148     // For each descriptor binding.
149     for (auto &memoryBuffer : deviceMemoryBuffers) {
150       vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
151       vkDestroyBuffer(device, memoryBuffer.buffer, nullptr);
152     }
153   }
154 
155   vkDestroyDevice(device, nullptr);
156   vkDestroyInstance(instance, nullptr);
157   return success();
158 }
159 
160 LogicalResult VulkanRuntime::run() {
161   // Create logical device, shader module and memory buffers.
162   if (failed(createInstance()) || failed(createDevice()) ||
163       failed(createMemoryBuffers()) || failed(createShaderModule())) {
164     return failure();
165   }
166 
167   // Descriptor bindings divided into sets. Each descriptor binding
168   // must have a layout binding attached into a descriptor set layout.
169   // Each layout set must be binded into a pipeline layout.
170   initDescriptorSetLayoutBindingMap();
171   if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
172       // Each descriptor set must be allocated from a descriptor pool.
173       failed(createComputePipeline()) || failed(createDescriptorPool()) ||
174       failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
175       // Create command buffer.
176       failed(createCommandPool()) || failed(createQueryPool()) ||
177       failed(createComputeCommandBuffer())) {
178     return failure();
179   }
180 
181   // Get working queue.
182   vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
183 
184   auto submitStart = std::chrono::high_resolution_clock::now();
185   // Submit command buffer into the queue.
186   if (failed(submitCommandBuffersToQueue()))
187     return failure();
188   auto submitEnd = std::chrono::high_resolution_clock::now();
189 
190   RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
191   auto execEnd = std::chrono::high_resolution_clock::now();
192 
193   auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
194       submitEnd - submitStart);
195   auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
196       execEnd - submitEnd);
197 
198   if (queryPool != VK_NULL_HANDLE) {
199     uint64_t timestamps[2];
200     RETURN_ON_VULKAN_ERROR(
201         vkGetQueryPoolResults(
202             device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
203             /*dataSize=*/sizeof(timestamps),
204             /*pData=*/reinterpret_cast<void *>(timestamps),
205             /*stride=*/sizeof(uint64_t),
206             VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
207         "vkGetQueryPoolResults");
208     float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
209     std::cout << "Compute shader execution time: " << std::setprecision(3)
210               << microsec << "us\n";
211   }
212 
213   std::cout << "Command buffer submit time: " << submitDuration.count()
214             << "us\nWait idle time: " << execDuration.count() << "us\n";
215 
216   return success();
217 }
218 
219 LogicalResult VulkanRuntime::createInstance() {
220   VkApplicationInfo applicationInfo = {};
221   applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
222   applicationInfo.pNext = nullptr;
223   applicationInfo.pApplicationName = "MLIR Vulkan runtime";
224   applicationInfo.applicationVersion = 0;
225   applicationInfo.pEngineName = "mlir";
226   applicationInfo.engineVersion = 0;
227   applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
228 
229   VkInstanceCreateInfo instanceCreateInfo = {};
230   instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
231   instanceCreateInfo.pNext = nullptr;
232   instanceCreateInfo.flags = 0;
233   instanceCreateInfo.pApplicationInfo = &applicationInfo;
234   instanceCreateInfo.enabledLayerCount = 0;
235   instanceCreateInfo.ppEnabledLayerNames = 0;
236   instanceCreateInfo.enabledExtensionCount = 0;
237   instanceCreateInfo.ppEnabledExtensionNames = 0;
238 
239   RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
240                          "vkCreateInstance");
241   return success();
242 }
243 
244 LogicalResult VulkanRuntime::createDevice() {
245   uint32_t physicalDeviceCount = 0;
246   RETURN_ON_VULKAN_ERROR(
247       vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
248       "vkEnumeratePhysicalDevices");
249 
250   std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
251   RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
252                                                     &physicalDeviceCount,
253                                                     physicalDevices.data()),
254                          "vkEnumeratePhysicalDevices");
255 
256   RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
257                          "physicalDeviceCount");
258 
259   // TODO(denis0x0D): find the best device.
260   physicalDevice = physicalDevices.front();
261   if (failed(getBestComputeQueue()))
262     return failure();
263 
264   const float queuePriority = 1.0f;
265   VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
266   deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
267   deviceQueueCreateInfo.pNext = nullptr;
268   deviceQueueCreateInfo.flags = 0;
269   deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
270   deviceQueueCreateInfo.queueCount = 1;
271   deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
272 
273   // Structure specifying parameters of a newly created device.
274   VkDeviceCreateInfo deviceCreateInfo = {};
275   deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
276   deviceCreateInfo.pNext = nullptr;
277   deviceCreateInfo.flags = 0;
278   deviceCreateInfo.queueCreateInfoCount = 1;
279   deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
280   deviceCreateInfo.enabledLayerCount = 0;
281   deviceCreateInfo.ppEnabledLayerNames = nullptr;
282   deviceCreateInfo.enabledExtensionCount = 0;
283   deviceCreateInfo.ppEnabledExtensionNames = nullptr;
284   deviceCreateInfo.pEnabledFeatures = nullptr;
285 
286   RETURN_ON_VULKAN_ERROR(
287       vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
288       "vkCreateDevice");
289 
290   VkPhysicalDeviceMemoryProperties properties = {};
291   vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
292 
293   // Try to find memory type with following properties:
294   // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
295   // with this type can be mapped for host access using vkMapMemory;
296   // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
297   // management commands vkFlushMappedMemoryRanges and
298   // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
299   // device or make device writes visible to the host, respectively.
300   for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
301     if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
302          properties.memoryTypes[i].propertyFlags) &&
303         (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
304          properties.memoryTypes[i].propertyFlags) &&
305         (memorySize <=
306          properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
307       memoryTypeIndex = i;
308       break;
309     }
310   }
311 
312   RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE
313                                                                 : VK_SUCCESS,
314                          "invalid memoryTypeIndex");
315   return success();
316 }
317 
318 LogicalResult VulkanRuntime::getBestComputeQueue() {
319   uint32_t queueFamilyPropertiesCount = 0;
320   vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
321                                            &queueFamilyPropertiesCount, 0);
322 
323   std::vector<VkQueueFamilyProperties> familyProperties(
324       queueFamilyPropertiesCount);
325   vkGetPhysicalDeviceQueueFamilyProperties(
326       physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());
327 
328   // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
329   // compute operations. Try to find a compute-only queue first if possible.
330   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
331     auto flags = familyProperties[i].queueFlags;
332     if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
333       queueFamilyIndex = i;
334       queueFamilyProperties = familyProperties[i];
335       return success();
336     }
337   }
338 
339   // Otherwise use a queue that can also support graphics.
340   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
341     auto flags = familyProperties[i].queueFlags;
342     if ((flags & VK_QUEUE_COMPUTE_BIT)) {
343       queueFamilyIndex = i;
344       queueFamilyProperties = familyProperties[i];
345       return success();
346     }
347   }
348 
349   std::cerr << "cannot find valid queue";
350   return failure();
351 }
352 
353 LogicalResult VulkanRuntime::createMemoryBuffers() {
354   // For each descriptor set.
355   for (const auto &resourceDataMapPair : resourceData) {
356     std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
357     const auto descriptorSetIndex = resourceDataMapPair.first;
358     const auto &resourceDataMap = resourceDataMapPair.second;
359 
360     // For each descriptor binding.
361     for (const auto &resourceDataBindingPair : resourceDataMap) {
362       // Create device memory buffer.
363       VulkanDeviceMemoryBuffer memoryBuffer;
364       memoryBuffer.bindingIndex = resourceDataBindingPair.first;
365       VkDescriptorType descriptorType = {};
366       VkBufferUsageFlagBits bufferUsage = {};
367 
368       // Check that descriptor set has storage class map.
369       const auto resourceStorageClassMapIt =
370           resourceStorageClassData.find(descriptorSetIndex);
371       if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
372         std::cerr
373             << "cannot find storage class for resource in descriptor set: "
374             << descriptorSetIndex;
375         return failure();
376       }
377 
378       // Check that specific descriptor binding has storage class.
379       const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
380       const auto resourceStorageClassIt =
381           resourceStorageClassMap.find(resourceDataBindingPair.first);
382       if (resourceStorageClassIt == resourceStorageClassMap.end()) {
383         std::cerr
384             << "cannot find storage class for resource with descriptor index: "
385             << resourceDataBindingPair.first;
386         return failure();
387       }
388 
389       const auto resourceStorageClassBinding = resourceStorageClassIt->second;
390       if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
391                                                  descriptorType)) ||
392           failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
393                                                   bufferUsage))) {
394         std::cerr << "storage class for resource with descriptor binding: "
395                   << resourceDataBindingPair.first
396                   << " in the descriptor set: " << descriptorSetIndex
397                   << " is not supported ";
398         return failure();
399       }
400 
401       // Set descriptor type for the specific device memory buffer.
402       memoryBuffer.descriptorType = descriptorType;
403       const auto bufferSize = resourceDataBindingPair.second.size;
404 
405       // Specify memory allocation info.
406       VkMemoryAllocateInfo memoryAllocateInfo = {};
407       memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
408       memoryAllocateInfo.pNext = nullptr;
409       memoryAllocateInfo.allocationSize = bufferSize;
410       memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex;
411 
412       // Allocate device memory.
413       RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
414                                               &memoryBuffer.deviceMemory),
415                              "vkAllocateMemory");
416       void *payload;
417       RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0,
418                                          bufferSize, 0,
419                                          reinterpret_cast<void **>(&payload)),
420                              "vkMapMemory");
421 
422       // Copy host memory into the mapped area.
423       std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
424       vkUnmapMemory(device, memoryBuffer.deviceMemory);
425 
426       VkBufferCreateInfo bufferCreateInfo = {};
427       bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
428       bufferCreateInfo.pNext = nullptr;
429       bufferCreateInfo.flags = 0;
430       bufferCreateInfo.size = bufferSize;
431       bufferCreateInfo.usage = bufferUsage;
432       bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
433       bufferCreateInfo.queueFamilyIndexCount = 1;
434       bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
435       RETURN_ON_VULKAN_ERROR(
436           vkCreateBuffer(device, &bufferCreateInfo, 0, &memoryBuffer.buffer),
437           "vkCreateBuffer");
438 
439       // Bind buffer and device memory.
440       RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer,
441                                                 memoryBuffer.deviceMemory, 0),
442                              "vkBindBufferMemory");
443 
444       // Update buffer info.
445       memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer;
446       memoryBuffer.bufferInfo.offset = 0;
447       memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
448       deviceMemoryBuffers.push_back(memoryBuffer);
449     }
450 
451     // Associate device memory buffers with a descriptor set.
452     deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
453   }
454   return success();
455 }
456 
457 LogicalResult VulkanRuntime::createShaderModule() {
458   VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
459   shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
460   shaderModuleCreateInfo.pNext = nullptr;
461   shaderModuleCreateInfo.flags = 0;
462   // Set size in bytes.
463   shaderModuleCreateInfo.codeSize = binarySize;
464   // Set pointer to the binary shader.
465   shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
466   RETURN_ON_VULKAN_ERROR(
467       vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
468       "vkCreateShaderModule");
469   return success();
470 }
471 
472 void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
473   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
474     std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
475     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
476     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
477 
478     // Create a layout binding for each descriptor.
479     for (const auto &memBuffer : deviceMemoryBuffers) {
480       VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
481       descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
482       descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
483       descriptorSetLayoutBinding.descriptorCount = 1;
484       descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
485       descriptorSetLayoutBinding.pImmutableSamplers = 0;
486       descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
487     }
488     descriptorSetLayoutBindingMap[descriptorSetIndex] =
489         descriptorSetLayoutBindings;
490   }
491 }
492 
493 LogicalResult VulkanRuntime::createDescriptorSetLayout() {
494   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
495     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
496     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
497     // Each descriptor in a descriptor set must be the same type.
498     VkDescriptorType descriptorType =
499         deviceMemoryBuffers.front().descriptorType;
500     const uint32_t descriptorSize = deviceMemoryBuffers.size();
501     const auto descriptorSetLayoutBindingIt =
502         descriptorSetLayoutBindingMap.find(descriptorSetIndex);
503 
504     if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
505       std::cerr << "cannot find layout bindings for the set with number: "
506                 << descriptorSetIndex;
507       return failure();
508     }
509 
510     const auto &descriptorSetLayoutBindings =
511         descriptorSetLayoutBindingIt->second;
512     // Create descriptor set layout.
513     VkDescriptorSetLayout descriptorSetLayout = {};
514     VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
515 
516     descriptorSetLayoutCreateInfo.sType =
517         VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
518     descriptorSetLayoutCreateInfo.pNext = nullptr;
519     descriptorSetLayoutCreateInfo.flags = 0;
520     // Amount of descriptor bindings in a layout set.
521     descriptorSetLayoutCreateInfo.bindingCount =
522         descriptorSetLayoutBindings.size();
523     descriptorSetLayoutCreateInfo.pBindings =
524         descriptorSetLayoutBindings.data();
525     RETURN_ON_VULKAN_ERROR(
526         vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
527                                     &descriptorSetLayout),
528         "vkCreateDescriptorSetLayout");
529 
530     descriptorSetLayouts.push_back(descriptorSetLayout);
531     descriptorSetInfoPool.push_back(
532         {descriptorSetIndex, descriptorSize, descriptorType});
533   }
534   return success();
535 }
536 
537 LogicalResult VulkanRuntime::createPipelineLayout() {
538   // Associate descriptor sets with a pipeline layout.
539   VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
540   pipelineLayoutCreateInfo.sType =
541       VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
542   pipelineLayoutCreateInfo.pNext = nullptr;
543   pipelineLayoutCreateInfo.flags = 0;
544   pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
545   pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
546   pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
547   pipelineLayoutCreateInfo.pPushConstantRanges = 0;
548   RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
549                                                 &pipelineLayoutCreateInfo, 0,
550                                                 &pipelineLayout),
551                          "vkCreatePipelineLayout");
552   return success();
553 }
554 
555 LogicalResult VulkanRuntime::createComputePipeline() {
556   VkPipelineShaderStageCreateInfo stageInfo = {};
557   stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
558   stageInfo.pNext = nullptr;
559   stageInfo.flags = 0;
560   stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
561   stageInfo.module = shaderModule;
562   // Set entry point.
563   stageInfo.pName = entryPoint;
564   stageInfo.pSpecializationInfo = 0;
565 
566   VkComputePipelineCreateInfo computePipelineCreateInfo = {};
567   computePipelineCreateInfo.sType =
568       VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
569   computePipelineCreateInfo.pNext = nullptr;
570   computePipelineCreateInfo.flags = 0;
571   computePipelineCreateInfo.stage = stageInfo;
572   computePipelineCreateInfo.layout = pipelineLayout;
573   computePipelineCreateInfo.basePipelineHandle = 0;
574   computePipelineCreateInfo.basePipelineIndex = 0;
575   RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
576                                                   &computePipelineCreateInfo, 0,
577                                                   &pipeline),
578                          "vkCreateComputePipelines");
579   return success();
580 }
581 
582 LogicalResult VulkanRuntime::createDescriptorPool() {
583   std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
584   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
585     // For each descriptor set populate descriptor pool size.
586     VkDescriptorPoolSize descriptorPoolSize = {};
587     descriptorPoolSize.type = descriptorSetInfo.descriptorType;
588     descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
589     descriptorPoolSizes.push_back(descriptorPoolSize);
590   }
591 
592   VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
593   descriptorPoolCreateInfo.sType =
594       VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
595   descriptorPoolCreateInfo.pNext = nullptr;
596   descriptorPoolCreateInfo.flags = 0;
597   descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
598   descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
599   descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
600   RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
601                                                 &descriptorPoolCreateInfo, 0,
602                                                 &descriptorPool),
603                          "vkCreateDescriptorPool");
604   return success();
605 }
606 
607 LogicalResult VulkanRuntime::allocateDescriptorSets() {
608   VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
609   // Size of descriptor sets and descriptor layout sets is the same.
610   descriptorSets.resize(descriptorSetLayouts.size());
611   descriptorSetAllocateInfo.sType =
612       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
613   descriptorSetAllocateInfo.pNext = nullptr;
614   descriptorSetAllocateInfo.descriptorPool = descriptorPool;
615   descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
616   descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
617   RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
618                                                   &descriptorSetAllocateInfo,
619                                                   descriptorSets.data()),
620                          "vkAllocateDescriptorSets");
621   return success();
622 }
623 
624 LogicalResult VulkanRuntime::setWriteDescriptors() {
625   if (descriptorSets.size() != descriptorSetInfoPool.size()) {
626     std::cerr << "Each descriptor set must have descriptor set information";
627     return failure();
628   }
629   // For each descriptor set.
630   auto descriptorSetIt = descriptorSets.begin();
631   // Each descriptor set is associated with descriptor set info.
632   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
633     // For each device memory buffer in the descriptor set.
634     const auto &deviceMemoryBuffers =
635         deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
636     for (const auto &memoryBuffer : deviceMemoryBuffers) {
637       // Structure describing descriptor sets to write to.
638       VkWriteDescriptorSet wSet = {};
639       wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
640       wSet.pNext = nullptr;
641       // Descriptor set.
642       wSet.dstSet = *descriptorSetIt;
643       wSet.dstBinding = memoryBuffer.bindingIndex;
644       wSet.dstArrayElement = 0;
645       wSet.descriptorCount = 1;
646       wSet.descriptorType = memoryBuffer.descriptorType;
647       wSet.pImageInfo = nullptr;
648       wSet.pBufferInfo = &memoryBuffer.bufferInfo;
649       wSet.pTexelBufferView = nullptr;
650       vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
651     }
652     // Increment descriptor set iterator.
653     ++descriptorSetIt;
654   }
655   return success();
656 }
657 
658 LogicalResult VulkanRuntime::createCommandPool() {
659   VkCommandPoolCreateInfo commandPoolCreateInfo = {};
660   commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
661   commandPoolCreateInfo.pNext = nullptr;
662   commandPoolCreateInfo.flags = 0;
663   commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
664   RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
665                                              /*pAllocator=*/nullptr,
666                                              &commandPool),
667                          "vkCreateCommandPool");
668   return success();
669 }
670 
671 LogicalResult VulkanRuntime::createQueryPool() {
672   // Return directly if timestamp query is not supported.
673   if (queueFamilyProperties.timestampValidBits == 0)
674     return success();
675 
676   // Get timestamp period for this physical device.
677   VkPhysicalDeviceProperties deviceProperties = {};
678   vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
679   timestampPeriod = deviceProperties.limits.timestampPeriod;
680 
681   // Create query pool.
682   VkQueryPoolCreateInfo queryPoolCreateInfo = {};
683   queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
684   queryPoolCreateInfo.pNext = nullptr;
685   queryPoolCreateInfo.flags = 0;
686   queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
687   queryPoolCreateInfo.queryCount = 2;
688   queryPoolCreateInfo.pipelineStatistics = 0;
689   RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
690                                            /*pAllocator=*/nullptr, &queryPool),
691                          "vkCreateQueryPool");
692 
693   return success();
694 }
695 
696 LogicalResult VulkanRuntime::createComputeCommandBuffer() {
697   VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
698   commandBufferAllocateInfo.sType =
699       VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
700   commandBufferAllocateInfo.pNext = nullptr;
701   commandBufferAllocateInfo.commandPool = commandPool;
702   commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
703   commandBufferAllocateInfo.commandBufferCount = 1;
704 
705   VkCommandBuffer commandBuffer;
706   RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
707                                                   &commandBufferAllocateInfo,
708                                                   &commandBuffer),
709                          "vkAllocateCommandBuffers");
710 
711   VkCommandBufferBeginInfo commandBufferBeginInfo = {};
712   commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
713   commandBufferBeginInfo.pNext = nullptr;
714   commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
715   commandBufferBeginInfo.pInheritanceInfo = nullptr;
716 
717   // Commands begin.
718   RETURN_ON_VULKAN_ERROR(
719       vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
720       "vkBeginCommandBuffer");
721 
722   if (queryPool != VK_NULL_HANDLE)
723     vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);
724 
725   vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
726   vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
727                           pipelineLayout, 0, descriptorSets.size(),
728                           descriptorSets.data(), 0, 0);
729   // Get a timestamp before invoking the compute shader.
730   if (queryPool != VK_NULL_HANDLE)
731     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
732                         queryPool, 0);
733   vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
734                 numWorkGroups.z);
735   // Get another timestamp after invoking the compute shader.
736   if (queryPool != VK_NULL_HANDLE)
737     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
738                         queryPool, 1);
739 
740   // Commands end.
741   RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
742                          "vkEndCommandBuffer");
743 
744   commandBuffers.push_back(commandBuffer);
745   return success();
746 }
747 
748 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
749   VkSubmitInfo submitInfo = {};
750   submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
751   submitInfo.pNext = nullptr;
752   submitInfo.waitSemaphoreCount = 0;
753   submitInfo.pWaitSemaphores = 0;
754   submitInfo.pWaitDstStageMask = 0;
755   submitInfo.commandBufferCount = commandBuffers.size();
756   submitInfo.pCommandBuffers = commandBuffers.data();
757   submitInfo.signalSemaphoreCount = 0;
758   submitInfo.pSignalSemaphores = nullptr;
759   RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
760                          "vkQueueSubmit");
761   return success();
762 }
763 
764 LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
765   // For each descriptor set.
766   for (auto &resourceDataMapPair : resourceData) {
767     auto &resourceDataMap = resourceDataMapPair.second;
768     auto &deviceMemoryBuffers =
769         deviceMemoryBufferMap[resourceDataMapPair.first];
770     // For each device memory buffer in the set.
771     for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
772       if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
773         void *payload;
774         auto &hostMemoryBuffer =
775             resourceDataMap[deviceMemoryBuffer.bindingIndex];
776         RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
777                                            deviceMemoryBuffer.deviceMemory, 0,
778                                            hostMemoryBuffer.size, 0,
779                                            reinterpret_cast<void **>(&payload)),
780                                "vkMapMemory");
781         std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
782         vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory);
783       }
784     }
785   }
786   return success();
787 }
788