1 //===-- LayoutUtils.cpp - Decorate composite type with layout information -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Utilities used to get alignment and layout information 10 // for types in SPIR-V dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 17 18 using namespace mlir; 19 20 spirv::StructType 21 VulkanLayoutUtils::decorateType(spirv::StructType structType) { 22 Size size = 0; 23 Size alignment = 1; 24 return decorateType(structType, size, alignment); 25 } 26 27 spirv::StructType 28 VulkanLayoutUtils::decorateType(spirv::StructType structType, 29 VulkanLayoutUtils::Size &size, 30 VulkanLayoutUtils::Size &alignment) { 31 if (structType.getNumElements() == 0) { 32 return structType; 33 } 34 35 SmallVector<Type, 4> memberTypes; 36 SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo; 37 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 38 39 Size structMemberOffset = 0; 40 Size maxMemberAlignment = 1; 41 42 for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { 43 Size memberSize = 0; 44 Size memberAlignment = 1; 45 46 auto memberType = 47 decorateType(structType.getElementType(i), memberSize, memberAlignment); 48 structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); 49 memberTypes.push_back(memberType); 50 offsetInfo.push_back( 51 static_cast<spirv::StructType::OffsetInfo>(structMemberOffset)); 52 // If the member's size is the max value, it must be the last member and it 53 // must be a runtime array. 54 assert(memberSize != std::numeric_limits<Size>().max() || 55 (i + 1 == e && 56 structType.getElementType(i).isa<spirv::RuntimeArrayType>())); 57 // According to the Vulkan spec: 58 // "A structure has a base alignment equal to the largest base alignment of 59 // any of its members." 60 structMemberOffset += memberSize; 61 maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment); 62 } 63 64 // According to the Vulkan spec: 65 // "The Offset decoration of a member must not place it between the end of a 66 // structure or an array and the next multiple of the alignment of that 67 // structure or array." 68 size = llvm::alignTo(structMemberOffset, maxMemberAlignment); 69 alignment = maxMemberAlignment; 70 structType.getMemberDecorations(memberDecorations); 71 72 if (!structType.isIdentified()) 73 return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); 74 75 // Identified structs are uniqued by identifier so it is not possible 76 // to create 2 structs with the same name but different decorations. 77 return nullptr; 78 } 79 80 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, 81 VulkanLayoutUtils::Size &alignment) { 82 if (type.isa<spirv::ScalarType>()) { 83 alignment = getScalarTypeAlignment(type); 84 // Vulkan spec does not specify any padding for a scalar type. 85 size = alignment; 86 return type; 87 } 88 if (auto structType = type.dyn_cast<spirv::StructType>()) 89 return decorateType(structType, size, alignment); 90 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) 91 return decorateType(arrayType, size, alignment); 92 if (auto vectorType = type.dyn_cast<VectorType>()) 93 return decorateType(vectorType, size, alignment); 94 if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 95 size = std::numeric_limits<Size>().max(); 96 return decorateType(arrayType, alignment); 97 } 98 llvm_unreachable("unhandled SPIR-V type"); 99 } 100 101 Type VulkanLayoutUtils::decorateType(VectorType vectorType, 102 VulkanLayoutUtils::Size &size, 103 VulkanLayoutUtils::Size &alignment) { 104 const auto numElements = vectorType.getNumElements(); 105 auto elementType = vectorType.getElementType(); 106 Size elementSize = 0; 107 Size elementAlignment = 1; 108 109 auto memberType = decorateType(elementType, elementSize, elementAlignment); 110 // According to the Vulkan spec: 111 // 1. "A two-component vector has a base alignment equal to twice its scalar 112 // alignment." 113 // 2. "A three- or four-component vector has a base alignment equal to four 114 // times its scalar alignment." 115 size = elementSize * numElements; 116 alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; 117 return VectorType::get(numElements, memberType); 118 } 119 120 Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, 121 VulkanLayoutUtils::Size &size, 122 VulkanLayoutUtils::Size &alignment) { 123 const auto numElements = arrayType.getNumElements(); 124 auto elementType = arrayType.getElementType(); 125 Size elementSize = 0; 126 Size elementAlignment = 1; 127 128 auto memberType = decorateType(elementType, elementSize, elementAlignment); 129 // According to the Vulkan spec: 130 // "An array has a base alignment equal to the base alignment of its element 131 // type." 132 size = elementSize * numElements; 133 alignment = elementAlignment; 134 return spirv::ArrayType::get(memberType, numElements, elementSize); 135 } 136 137 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType, 138 VulkanLayoutUtils::Size &alignment) { 139 auto elementType = arrayType.getElementType(); 140 Size elementSize = 0; 141 142 auto memberType = decorateType(elementType, elementSize, alignment); 143 return spirv::RuntimeArrayType::get(memberType, elementSize); 144 } 145 146 VulkanLayoutUtils::Size 147 VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { 148 // According to the Vulkan spec: 149 // 1. "A scalar of size N has a scalar alignment of N." 150 // 2. "A scalar has a base alignment equal to its scalar alignment." 151 // 3. "A scalar, vector or matrix type has an extended alignment equal to its 152 // base alignment." 153 auto bitWidth = scalarType.getIntOrFloatBitWidth(); 154 if (bitWidth == 1) 155 return 1; 156 return bitWidth / 8; 157 } 158 159 bool VulkanLayoutUtils::isLegalType(Type type) { 160 auto ptrType = type.dyn_cast<spirv::PointerType>(); 161 if (!ptrType) { 162 return true; 163 } 164 165 auto storageClass = ptrType.getStorageClass(); 166 auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 167 if (!structType) { 168 return true; 169 } 170 171 switch (storageClass) { 172 case spirv::StorageClass::Uniform: 173 case spirv::StorageClass::StorageBuffer: 174 case spirv::StorageClass::PushConstant: 175 case spirv::StorageClass::PhysicalStorageBuffer: 176 return structType.hasOffset() || !structType.getNumElements(); 177 default: 178 return true; 179 } 180 } 181