1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a library for running a module on a Vulkan device.
10 // Implements a Vulkan runtime.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "VulkanRuntime.h"
15 
16 #include "llvm/Support/Format.h"
17 #include <chrono>
18 
19 using namespace mlir;
20 
21 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
22   numWorkGroups = numberWorkGroups;
23 }
24 
25 void VulkanRuntime::setResourceStorageClassBindingMap(
26     const ResourceStorageClassBindingMap &stClassData) {
27   resourceStorageClassData = stClassData;
28 }
29 
30 void VulkanRuntime::setResourceData(
31     const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
32     const VulkanHostMemoryBuffer &hostMemBuffer) {
33   resourceData[desIndex][bindIndex] = hostMemBuffer;
34   resourceStorageClassData[desIndex][bindIndex] =
35       spirv::StorageClass::StorageBuffer;
36 }
37 
38 void VulkanRuntime::setEntryPoint(const char *entryPointName) {
39   entryPoint = entryPointName;
40 }
41 
42 void VulkanRuntime::setResourceData(const ResourceData &resData) {
43   resourceData = resData;
44 }
45 
46 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
47   binary = shader;
48   binarySize = size;
49 }
50 
51 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
52     spirv::StorageClass storageClass, VkDescriptorType &descriptorType) {
53   switch (storageClass) {
54   case spirv::StorageClass::StorageBuffer:
55     descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
56     break;
57   case spirv::StorageClass::Uniform:
58     descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
59     break;
60   default:
61     llvm::errs() << "unsupported storage class";
62     return failure();
63   }
64   return success();
65 }
66 
67 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
68     spirv::StorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
69   switch (storageClass) {
70   case spirv::StorageClass::StorageBuffer:
71     bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
72     break;
73   case spirv::StorageClass::Uniform:
74     bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
75     break;
76   default:
77     llvm::errs() << "unsupported storage class";
78     return failure();
79   }
80   return success();
81 }
82 
83 LogicalResult VulkanRuntime::countDeviceMemorySize() {
84   for (const auto &resourceDataMapPair : resourceData) {
85     const auto &resourceDataMap = resourceDataMapPair.second;
86     for (const auto &resourceDataBindingPair : resourceDataMap) {
87       if (resourceDataBindingPair.second.size) {
88         memorySize += resourceDataBindingPair.second.size;
89       } else {
90         llvm::errs()
91             << "expected buffer size greater than zero for resource data";
92         return failure();
93       }
94     }
95   }
96   return success();
97 }
98 
99 LogicalResult VulkanRuntime::initRuntime() {
100   if (!resourceData.size()) {
101     llvm::errs() << "Vulkan runtime needs at least one resource";
102     return failure();
103   }
104   if (!binarySize || !binary) {
105     llvm::errs() << "binary shader size must be greater than zero";
106     return failure();
107   }
108   if (failed(countDeviceMemorySize())) {
109     return failure();
110   }
111   return success();
112 }
113 
114 LogicalResult VulkanRuntime::destroy() {
115   // According to Vulkan spec:
116   // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
117   // used to gate the destruction of the device. Prior to destroying a device,
118   // an application is responsible for destroying/freeing any Vulkan objects
119   // that were created using that device as the first parameter of the
120   // corresponding vkCreate* or vkAllocate* command."
121   RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
122 
123   // Free and destroy.
124   vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
125                        commandBuffers.data());
126   vkDestroyQueryPool(device, queryPool, nullptr);
127   vkDestroyCommandPool(device, commandPool, nullptr);
128   vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
129                        descriptorSets.data());
130   vkDestroyDescriptorPool(device, descriptorPool, nullptr);
131   vkDestroyPipeline(device, pipeline, nullptr);
132   vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
133   for (auto &descriptorSetLayout: descriptorSetLayouts) {
134     vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
135   }
136   vkDestroyShaderModule(device, shaderModule, nullptr);
137 
138   // For each descriptor set.
139   for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
140     auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
141     // For each descriptor binding.
142     for (auto &memoryBuffer : deviceMemoryBuffers) {
143       vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
144       vkDestroyBuffer(device, memoryBuffer.buffer, nullptr);
145     }
146   }
147 
148   vkDestroyDevice(device, nullptr);
149   vkDestroyInstance(instance, nullptr);
150   return success();
151 }
152 
153 LogicalResult VulkanRuntime::run() {
154   // Create logical device, shader module and memory buffers.
155   if (failed(createInstance()) || failed(createDevice()) ||
156       failed(createMemoryBuffers()) || failed(createShaderModule())) {
157     return failure();
158   }
159 
160   // Descriptor bindings divided into sets. Each descriptor binding
161   // must have a layout binding attached into a descriptor set layout.
162   // Each layout set must be binded into a pipeline layout.
163   initDescriptorSetLayoutBindingMap();
164   if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
165       // Each descriptor set must be allocated from a descriptor pool.
166       failed(createComputePipeline()) || failed(createDescriptorPool()) ||
167       failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
168       // Create command buffer.
169       failed(createCommandPool()) || failed(createQueryPool()) ||
170       failed(createComputeCommandBuffer())) {
171     return failure();
172   }
173 
174   // Get working queue.
175   vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
176 
177   auto submitStart = std::chrono::high_resolution_clock::now();
178   // Submit command buffer into the queue.
179   if (failed(submitCommandBuffersToQueue()))
180     return failure();
181   auto submitEnd = std::chrono::high_resolution_clock::now();
182 
183   RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
184   auto execEnd = std::chrono::high_resolution_clock::now();
185 
186   auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
187       submitEnd - submitStart);
188   auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
189       execEnd - submitEnd);
190 
191   if (queryPool != VK_NULL_HANDLE) {
192     uint64_t timestamps[2];
193     RETURN_ON_VULKAN_ERROR(
194         vkGetQueryPoolResults(
195             device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
196             /*dataSize=*/sizeof(timestamps),
197             /*pData=*/reinterpret_cast<void *>(timestamps),
198             /*stride=*/sizeof(uint64_t),
199             VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
200         "vkGetQueryPoolResults");
201     float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
202     llvm::outs() << "Compute shader execution time: "
203                  << llvm::format("%0.3fus\n", microsec);
204   }
205 
206   llvm::outs() << "Command buffer submit time: " << submitDuration.count()
207                << "us\nWait idle time: " << execDuration.count() << "us\n";
208 
209   return success();
210 }
211 
212 LogicalResult VulkanRuntime::createInstance() {
213   VkApplicationInfo applicationInfo = {};
214   applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
215   applicationInfo.pNext = nullptr;
216   applicationInfo.pApplicationName = "MLIR Vulkan runtime";
217   applicationInfo.applicationVersion = 0;
218   applicationInfo.pEngineName = "mlir";
219   applicationInfo.engineVersion = 0;
220   applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
221 
222   VkInstanceCreateInfo instanceCreateInfo = {};
223   instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
224   instanceCreateInfo.pNext = nullptr;
225   instanceCreateInfo.flags = 0;
226   instanceCreateInfo.pApplicationInfo = &applicationInfo;
227   instanceCreateInfo.enabledLayerCount = 0;
228   instanceCreateInfo.ppEnabledLayerNames = 0;
229   instanceCreateInfo.enabledExtensionCount = 0;
230   instanceCreateInfo.ppEnabledExtensionNames = 0;
231 
232   RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
233                          "vkCreateInstance");
234   return success();
235 }
236 
237 LogicalResult VulkanRuntime::createDevice() {
238   uint32_t physicalDeviceCount = 0;
239   RETURN_ON_VULKAN_ERROR(
240       vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
241       "vkEnumeratePhysicalDevices");
242 
243   llvm::SmallVector<VkPhysicalDevice, 1> physicalDevices(physicalDeviceCount);
244   RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
245                                                     &physicalDeviceCount,
246                                                     physicalDevices.data()),
247                          "vkEnumeratePhysicalDevices");
248 
249   RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
250                          "physicalDeviceCount");
251 
252   // TODO(denis0x0D): find the best device.
253   physicalDevice = physicalDevices.front();
254   if (failed(getBestComputeQueue()))
255     return failure();
256 
257   const float queuePriority = 1.0f;
258   VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
259   deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
260   deviceQueueCreateInfo.pNext = nullptr;
261   deviceQueueCreateInfo.flags = 0;
262   deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
263   deviceQueueCreateInfo.queueCount = 1;
264   deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
265 
266   // Structure specifying parameters of a newly created device.
267   VkDeviceCreateInfo deviceCreateInfo = {};
268   deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
269   deviceCreateInfo.pNext = nullptr;
270   deviceCreateInfo.flags = 0;
271   deviceCreateInfo.queueCreateInfoCount = 1;
272   deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
273   deviceCreateInfo.enabledLayerCount = 0;
274   deviceCreateInfo.ppEnabledLayerNames = nullptr;
275   deviceCreateInfo.enabledExtensionCount = 0;
276   deviceCreateInfo.ppEnabledExtensionNames = nullptr;
277   deviceCreateInfo.pEnabledFeatures = nullptr;
278 
279   RETURN_ON_VULKAN_ERROR(
280       vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
281       "vkCreateDevice");
282 
283   VkPhysicalDeviceMemoryProperties properties = {};
284   vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
285 
286   // Try to find memory type with following properties:
287   // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
288   // with this type can be mapped for host access using vkMapMemory;
289   // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
290   // management commands vkFlushMappedMemoryRanges and
291   // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
292   // device or make device writes visible to the host, respectively.
293   for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
294     if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
295          properties.memoryTypes[i].propertyFlags) &&
296         (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
297          properties.memoryTypes[i].propertyFlags) &&
298         (memorySize <=
299          properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
300       memoryTypeIndex = i;
301       break;
302     }
303   }
304 
305   RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE
306                                                                 : VK_SUCCESS,
307                          "invalid memoryTypeIndex");
308   return success();
309 }
310 
311 LogicalResult VulkanRuntime::getBestComputeQueue() {
312   uint32_t queueFamilyPropertiesCount = 0;
313   vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
314                                            &queueFamilyPropertiesCount, 0);
315 
316   SmallVector<VkQueueFamilyProperties, 1> familyProperties(
317       queueFamilyPropertiesCount);
318   vkGetPhysicalDeviceQueueFamilyProperties(
319       physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());
320 
321   // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
322   // compute operations. Try to find a compute-only queue first if possible.
323   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
324     auto flags = familyProperties[i].queueFlags;
325     if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
326       queueFamilyIndex = i;
327       queueFamilyProperties = familyProperties[i];
328       return success();
329     }
330   }
331 
332   // Otherwise use a queue that can also support graphics.
333   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
334     auto flags = familyProperties[i].queueFlags;
335     if ((flags & VK_QUEUE_COMPUTE_BIT)) {
336       queueFamilyIndex = i;
337       queueFamilyProperties = familyProperties[i];
338       return success();
339     }
340   }
341 
342   llvm::errs() << "cannot find valid queue";
343   return failure();
344 }
345 
346 LogicalResult VulkanRuntime::createMemoryBuffers() {
347   // For each descriptor set.
348   for (const auto &resourceDataMapPair : resourceData) {
349     llvm::SmallVector<VulkanDeviceMemoryBuffer, 1> deviceMemoryBuffers;
350     const auto descriptorSetIndex = resourceDataMapPair.first;
351     const auto &resourceDataMap = resourceDataMapPair.second;
352 
353     // For each descriptor binding.
354     for (const auto &resourceDataBindingPair : resourceDataMap) {
355       // Create device memory buffer.
356       VulkanDeviceMemoryBuffer memoryBuffer;
357       memoryBuffer.bindingIndex = resourceDataBindingPair.first;
358       VkDescriptorType descriptorType = {};
359       VkBufferUsageFlagBits bufferUsage = {};
360 
361       // Check that descriptor set has storage class map.
362       const auto resourceStorageClassMapIt =
363           resourceStorageClassData.find(descriptorSetIndex);
364       if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
365         llvm::errs()
366             << "cannot find storage class for resource in descriptor set: "
367             << descriptorSetIndex;
368         return failure();
369       }
370 
371       // Check that specific descriptor binding has storage class.
372       const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
373       const auto resourceStorageClassIt =
374           resourceStorageClassMap.find(resourceDataBindingPair.first);
375       if (resourceStorageClassIt == resourceStorageClassMap.end()) {
376         llvm::errs()
377             << "cannot find storage class for resource with descriptor index: "
378             << resourceDataBindingPair.first;
379         return failure();
380       }
381 
382       const auto resourceStorageClassBinding = resourceStorageClassIt->second;
383       if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
384                                                  descriptorType)) ||
385           failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
386                                                   bufferUsage))) {
387         llvm::errs() << "storage class for resource with descriptor binding: "
388                      << resourceDataBindingPair.first
389                      << " in the descriptor set: " << descriptorSetIndex
390                      << " is not supported ";
391         return failure();
392       }
393 
394       // Set descriptor type for the specific device memory buffer.
395       memoryBuffer.descriptorType = descriptorType;
396       const auto bufferSize = resourceDataBindingPair.second.size;
397 
398       // Specify memory allocation info.
399       VkMemoryAllocateInfo memoryAllocateInfo = {};
400       memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
401       memoryAllocateInfo.pNext = nullptr;
402       memoryAllocateInfo.allocationSize = bufferSize;
403       memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex;
404 
405       // Allocate device memory.
406       RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
407                                               &memoryBuffer.deviceMemory),
408                              "vkAllocateMemory");
409       void *payload;
410       RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0,
411                                          bufferSize, 0,
412                                          reinterpret_cast<void **>(&payload)),
413                              "vkMapMemory");
414 
415       // Copy host memory into the mapped area.
416       std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
417       vkUnmapMemory(device, memoryBuffer.deviceMemory);
418 
419       VkBufferCreateInfo bufferCreateInfo = {};
420       bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
421       bufferCreateInfo.pNext = nullptr;
422       bufferCreateInfo.flags = 0;
423       bufferCreateInfo.size = bufferSize;
424       bufferCreateInfo.usage = bufferUsage;
425       bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
426       bufferCreateInfo.queueFamilyIndexCount = 1;
427       bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
428       RETURN_ON_VULKAN_ERROR(
429           vkCreateBuffer(device, &bufferCreateInfo, 0, &memoryBuffer.buffer),
430           "vkCreateBuffer");
431 
432       // Bind buffer and device memory.
433       RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer,
434                                                 memoryBuffer.deviceMemory, 0),
435                              "vkBindBufferMemory");
436 
437       // Update buffer info.
438       memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer;
439       memoryBuffer.bufferInfo.offset = 0;
440       memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
441       deviceMemoryBuffers.push_back(memoryBuffer);
442     }
443 
444     // Associate device memory buffers with a descriptor set.
445     deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
446   }
447   return success();
448 }
449 
450 LogicalResult VulkanRuntime::createShaderModule() {
451   VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
452   shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
453   shaderModuleCreateInfo.pNext = nullptr;
454   shaderModuleCreateInfo.flags = 0;
455   // Set size in bytes.
456   shaderModuleCreateInfo.codeSize = binarySize;
457   // Set pointer to the binary shader.
458   shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
459   RETURN_ON_VULKAN_ERROR(
460       vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
461       "vkCreateShaderModule");
462   return success();
463 }
464 
465 void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
466   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
467     SmallVector<VkDescriptorSetLayoutBinding, 1> descriptorSetLayoutBindings;
468     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
469     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
470 
471     // Create a layout binding for each descriptor.
472     for (const auto &memBuffer : deviceMemoryBuffers) {
473       VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
474       descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
475       descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
476       descriptorSetLayoutBinding.descriptorCount = 1;
477       descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
478       descriptorSetLayoutBinding.pImmutableSamplers = 0;
479       descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
480     }
481     descriptorSetLayoutBindingMap[descriptorSetIndex] =
482         descriptorSetLayoutBindings;
483   }
484 }
485 
486 LogicalResult VulkanRuntime::createDescriptorSetLayout() {
487   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
488     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
489     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
490     // Each descriptor in a descriptor set must be the same type.
491     VkDescriptorType descriptorType =
492         deviceMemoryBuffers.front().descriptorType;
493     const uint32_t descriptorSize = deviceMemoryBuffers.size();
494     const auto descriptorSetLayoutBindingIt =
495         descriptorSetLayoutBindingMap.find(descriptorSetIndex);
496 
497     if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
498       llvm::errs() << "cannot find layout bindings for the set with number: "
499                    << descriptorSetIndex;
500       return failure();
501     }
502 
503     const auto &descriptorSetLayoutBindings =
504         descriptorSetLayoutBindingIt->second;
505     // Create descriptor set layout.
506     VkDescriptorSetLayout descriptorSetLayout = {};
507     VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
508 
509     descriptorSetLayoutCreateInfo.sType =
510         VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
511     descriptorSetLayoutCreateInfo.pNext = nullptr;
512     descriptorSetLayoutCreateInfo.flags = 0;
513     // Amount of descriptor bindings in a layout set.
514     descriptorSetLayoutCreateInfo.bindingCount =
515         descriptorSetLayoutBindings.size();
516     descriptorSetLayoutCreateInfo.pBindings =
517         descriptorSetLayoutBindings.data();
518     RETURN_ON_VULKAN_ERROR(
519         vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
520                                     &descriptorSetLayout),
521         "vkCreateDescriptorSetLayout");
522 
523     descriptorSetLayouts.push_back(descriptorSetLayout);
524     descriptorSetInfoPool.push_back(
525         {descriptorSetIndex, descriptorSize, descriptorType});
526   }
527   return success();
528 }
529 
530 LogicalResult VulkanRuntime::createPipelineLayout() {
531   // Associate descriptor sets with a pipeline layout.
532   VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
533   pipelineLayoutCreateInfo.sType =
534       VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
535   pipelineLayoutCreateInfo.pNext = nullptr;
536   pipelineLayoutCreateInfo.flags = 0;
537   pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
538   pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
539   pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
540   pipelineLayoutCreateInfo.pPushConstantRanges = 0;
541   RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
542                                                 &pipelineLayoutCreateInfo, 0,
543                                                 &pipelineLayout),
544                          "vkCreatePipelineLayout");
545   return success();
546 }
547 
548 LogicalResult VulkanRuntime::createComputePipeline() {
549   VkPipelineShaderStageCreateInfo stageInfo = {};
550   stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
551   stageInfo.pNext = nullptr;
552   stageInfo.flags = 0;
553   stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
554   stageInfo.module = shaderModule;
555   // Set entry point.
556   stageInfo.pName = entryPoint;
557   stageInfo.pSpecializationInfo = 0;
558 
559   VkComputePipelineCreateInfo computePipelineCreateInfo = {};
560   computePipelineCreateInfo.sType =
561       VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
562   computePipelineCreateInfo.pNext = nullptr;
563   computePipelineCreateInfo.flags = 0;
564   computePipelineCreateInfo.stage = stageInfo;
565   computePipelineCreateInfo.layout = pipelineLayout;
566   computePipelineCreateInfo.basePipelineHandle = 0;
567   computePipelineCreateInfo.basePipelineIndex = 0;
568   RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
569                                                   &computePipelineCreateInfo, 0,
570                                                   &pipeline),
571                          "vkCreateComputePipelines");
572   return success();
573 }
574 
575 LogicalResult VulkanRuntime::createDescriptorPool() {
576   llvm::SmallVector<VkDescriptorPoolSize, 1> descriptorPoolSizes;
577   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
578     // For each descriptor set populate descriptor pool size.
579     VkDescriptorPoolSize descriptorPoolSize = {};
580     descriptorPoolSize.type = descriptorSetInfo.descriptorType;
581     descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
582     descriptorPoolSizes.push_back(descriptorPoolSize);
583   }
584 
585   VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
586   descriptorPoolCreateInfo.sType =
587       VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
588   descriptorPoolCreateInfo.pNext = nullptr;
589   descriptorPoolCreateInfo.flags = 0;
590   descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
591   descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
592   descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
593   RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
594                                                 &descriptorPoolCreateInfo, 0,
595                                                 &descriptorPool),
596                          "vkCreateDescriptorPool");
597   return success();
598 }
599 
600 LogicalResult VulkanRuntime::allocateDescriptorSets() {
601   VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
602   // Size of descriptor sets and descriptor layout sets is the same.
603   descriptorSets.resize(descriptorSetLayouts.size());
604   descriptorSetAllocateInfo.sType =
605       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
606   descriptorSetAllocateInfo.pNext = nullptr;
607   descriptorSetAllocateInfo.descriptorPool = descriptorPool;
608   descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
609   descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
610   RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
611                                                   &descriptorSetAllocateInfo,
612                                                   descriptorSets.data()),
613                          "vkAllocateDescriptorSets");
614   return success();
615 }
616 
617 LogicalResult VulkanRuntime::setWriteDescriptors() {
618   if (descriptorSets.size() != descriptorSetInfoPool.size()) {
619     llvm::errs() << "Each descriptor set must have descriptor set information";
620     return failure();
621   }
622   // For each descriptor set.
623   auto descriptorSetIt = descriptorSets.begin();
624   // Each descriptor set is associated with descriptor set info.
625   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
626     // For each device memory buffer in the descriptor set.
627     const auto &deviceMemoryBuffers =
628         deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
629     for (const auto &memoryBuffer : deviceMemoryBuffers) {
630       // Structure describing descriptor sets to write to.
631       VkWriteDescriptorSet wSet = {};
632       wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
633       wSet.pNext = nullptr;
634       // Descriptor set.
635       wSet.dstSet = *descriptorSetIt;
636       wSet.dstBinding = memoryBuffer.bindingIndex;
637       wSet.dstArrayElement = 0;
638       wSet.descriptorCount = 1;
639       wSet.descriptorType = memoryBuffer.descriptorType;
640       wSet.pImageInfo = nullptr;
641       wSet.pBufferInfo = &memoryBuffer.bufferInfo;
642       wSet.pTexelBufferView = nullptr;
643       vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
644     }
645     // Increment descriptor set iterator.
646     ++descriptorSetIt;
647   }
648   return success();
649 }
650 
651 LogicalResult VulkanRuntime::createCommandPool() {
652   VkCommandPoolCreateInfo commandPoolCreateInfo = {};
653   commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
654   commandPoolCreateInfo.pNext = nullptr;
655   commandPoolCreateInfo.flags = 0;
656   commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
657   RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
658                                              /*pAllocator=*/nullptr,
659                                              &commandPool),
660                          "vkCreateCommandPool");
661   return success();
662 }
663 
664 LogicalResult VulkanRuntime::createQueryPool() {
665   // Return directly if timestamp query is not supported.
666   if (queueFamilyProperties.timestampValidBits == 0)
667     return success();
668 
669   // Get timestamp period for this physical device.
670   VkPhysicalDeviceProperties deviceProperties = {};
671   vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
672   timestampPeriod = deviceProperties.limits.timestampPeriod;
673 
674   // Create query pool.
675   VkQueryPoolCreateInfo queryPoolCreateInfo = {};
676   queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
677   queryPoolCreateInfo.pNext = nullptr;
678   queryPoolCreateInfo.flags = 0;
679   queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
680   queryPoolCreateInfo.queryCount = 2;
681   queryPoolCreateInfo.pipelineStatistics = 0;
682   RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
683                                            /*pAllocator=*/nullptr, &queryPool),
684                          "vkCreateQueryPool");
685 
686   return success();
687 }
688 
689 LogicalResult VulkanRuntime::createComputeCommandBuffer() {
690   VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
691   commandBufferAllocateInfo.sType =
692       VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
693   commandBufferAllocateInfo.pNext = nullptr;
694   commandBufferAllocateInfo.commandPool = commandPool;
695   commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
696   commandBufferAllocateInfo.commandBufferCount = 1;
697 
698   VkCommandBuffer commandBuffer;
699   RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
700                                                   &commandBufferAllocateInfo,
701                                                   &commandBuffer),
702                          "vkAllocateCommandBuffers");
703 
704   VkCommandBufferBeginInfo commandBufferBeginInfo = {};
705   commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
706   commandBufferBeginInfo.pNext = nullptr;
707   commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
708   commandBufferBeginInfo.pInheritanceInfo = nullptr;
709 
710   // Commands begin.
711   RETURN_ON_VULKAN_ERROR(
712       vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
713       "vkBeginCommandBuffer");
714 
715   if (queryPool != VK_NULL_HANDLE)
716     vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);
717 
718   vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
719   vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
720                           pipelineLayout, 0, descriptorSets.size(),
721                           descriptorSets.data(), 0, 0);
722   // Get a timestamp before invoking the compute shader.
723   if (queryPool != VK_NULL_HANDLE)
724     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
725                         queryPool, 0);
726   vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
727                 numWorkGroups.z);
728   // Get another timestamp after invoking the compute shader.
729   if (queryPool != VK_NULL_HANDLE)
730     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
731                         queryPool, 1);
732 
733   // Commands end.
734   RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
735                          "vkEndCommandBuffer");
736 
737   commandBuffers.push_back(commandBuffer);
738   return success();
739 }
740 
741 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
742   VkSubmitInfo submitInfo = {};
743   submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
744   submitInfo.pNext = nullptr;
745   submitInfo.waitSemaphoreCount = 0;
746   submitInfo.pWaitSemaphores = 0;
747   submitInfo.pWaitDstStageMask = 0;
748   submitInfo.commandBufferCount = commandBuffers.size();
749   submitInfo.pCommandBuffers = commandBuffers.data();
750   submitInfo.signalSemaphoreCount = 0;
751   submitInfo.pSignalSemaphores = nullptr;
752   RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
753                          "vkQueueSubmit");
754   return success();
755 }
756 
757 LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
758   // For each descriptor set.
759   for (auto &resourceDataMapPair : resourceData) {
760     auto &resourceDataMap = resourceDataMapPair.second;
761     auto &deviceMemoryBuffers =
762         deviceMemoryBufferMap[resourceDataMapPair.first];
763     // For each device memory buffer in the set.
764     for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
765       if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
766         void *payload;
767         auto &hostMemoryBuffer =
768             resourceDataMap[deviceMemoryBuffer.bindingIndex];
769         RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
770                                            deviceMemoryBuffer.deviceMemory, 0,
771                                            hostMemoryBuffer.size, 0,
772                                            reinterpret_cast<void **>(&payload)),
773                                "vkMapMemory");
774         std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
775         vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory);
776       }
777     }
778   }
779   return success();
780 }
781