1 //===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Utilities used to get alignment and layout information
10 // for types in SPIR-V dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17
18 using namespace mlir;
19
20 spirv::StructType
decorateType(spirv::StructType structType)21 VulkanLayoutUtils::decorateType(spirv::StructType structType) {
22 Size size = 0;
23 Size alignment = 1;
24 return decorateType(structType, size, alignment);
25 }
26
27 spirv::StructType
decorateType(spirv::StructType structType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)28 VulkanLayoutUtils::decorateType(spirv::StructType structType,
29 VulkanLayoutUtils::Size &size,
30 VulkanLayoutUtils::Size &alignment) {
31 if (structType.getNumElements() == 0) {
32 return structType;
33 }
34
35 SmallVector<Type, 4> memberTypes;
36 SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
37 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
38
39 Size structMemberOffset = 0;
40 Size maxMemberAlignment = 1;
41
42 for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
43 Size memberSize = 0;
44 Size memberAlignment = 1;
45
46 auto memberType =
47 decorateType(structType.getElementType(i), memberSize, memberAlignment);
48 structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
49 memberTypes.push_back(memberType);
50 offsetInfo.push_back(
51 static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
52 // If the member's size is the max value, it must be the last member and it
53 // must be a runtime array.
54 assert(memberSize != std::numeric_limits<Size>().max() ||
55 (i + 1 == e &&
56 structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
57 // According to the Vulkan spec:
58 // "A structure has a base alignment equal to the largest base alignment of
59 // any of its members."
60 structMemberOffset += memberSize;
61 maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
62 }
63
64 // According to the Vulkan spec:
65 // "The Offset decoration of a member must not place it between the end of a
66 // structure or an array and the next multiple of the alignment of that
67 // structure or array."
68 size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
69 alignment = maxMemberAlignment;
70 structType.getMemberDecorations(memberDecorations);
71
72 if (!structType.isIdentified())
73 return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
74
75 // Identified structs are uniqued by identifier so it is not possible
76 // to create 2 structs with the same name but different decorations.
77 return nullptr;
78 }
79
decorateType(Type type,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)80 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
81 VulkanLayoutUtils::Size &alignment) {
82 if (type.isa<spirv::ScalarType>()) {
83 alignment = getScalarTypeAlignment(type);
84 // Vulkan spec does not specify any padding for a scalar type.
85 size = alignment;
86 return type;
87 }
88 if (auto structType = type.dyn_cast<spirv::StructType>())
89 return decorateType(structType, size, alignment);
90 if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
91 return decorateType(arrayType, size, alignment);
92 if (auto vectorType = type.dyn_cast<VectorType>())
93 return decorateType(vectorType, size, alignment);
94 if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
95 size = std::numeric_limits<Size>().max();
96 return decorateType(arrayType, alignment);
97 }
98 llvm_unreachable("unhandled SPIR-V type");
99 }
100
decorateType(VectorType vectorType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)101 Type VulkanLayoutUtils::decorateType(VectorType vectorType,
102 VulkanLayoutUtils::Size &size,
103 VulkanLayoutUtils::Size &alignment) {
104 const auto numElements = vectorType.getNumElements();
105 auto elementType = vectorType.getElementType();
106 Size elementSize = 0;
107 Size elementAlignment = 1;
108
109 auto memberType = decorateType(elementType, elementSize, elementAlignment);
110 // According to the Vulkan spec:
111 // 1. "A two-component vector has a base alignment equal to twice its scalar
112 // alignment."
113 // 2. "A three- or four-component vector has a base alignment equal to four
114 // times its scalar alignment."
115 size = elementSize * numElements;
116 alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
117 return VectorType::get(numElements, memberType);
118 }
119
decorateType(spirv::ArrayType arrayType,VulkanLayoutUtils::Size & size,VulkanLayoutUtils::Size & alignment)120 Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
121 VulkanLayoutUtils::Size &size,
122 VulkanLayoutUtils::Size &alignment) {
123 const auto numElements = arrayType.getNumElements();
124 auto elementType = arrayType.getElementType();
125 Size elementSize = 0;
126 Size elementAlignment = 1;
127
128 auto memberType = decorateType(elementType, elementSize, elementAlignment);
129 // According to the Vulkan spec:
130 // "An array has a base alignment equal to the base alignment of its element
131 // type."
132 size = elementSize * numElements;
133 alignment = elementAlignment;
134 return spirv::ArrayType::get(memberType, numElements, elementSize);
135 }
136
decorateType(spirv::RuntimeArrayType arrayType,VulkanLayoutUtils::Size & alignment)137 Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
138 VulkanLayoutUtils::Size &alignment) {
139 auto elementType = arrayType.getElementType();
140 Size elementSize = 0;
141
142 auto memberType = decorateType(elementType, elementSize, alignment);
143 return spirv::RuntimeArrayType::get(memberType, elementSize);
144 }
145
146 VulkanLayoutUtils::Size
getScalarTypeAlignment(Type scalarType)147 VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
148 // According to the Vulkan spec:
149 // 1. "A scalar of size N has a scalar alignment of N."
150 // 2. "A scalar has a base alignment equal to its scalar alignment."
151 // 3. "A scalar, vector or matrix type has an extended alignment equal to its
152 // base alignment."
153 auto bitWidth = scalarType.getIntOrFloatBitWidth();
154 if (bitWidth == 1)
155 return 1;
156 return bitWidth / 8;
157 }
158
isLegalType(Type type)159 bool VulkanLayoutUtils::isLegalType(Type type) {
160 auto ptrType = type.dyn_cast<spirv::PointerType>();
161 if (!ptrType) {
162 return true;
163 }
164
165 auto storageClass = ptrType.getStorageClass();
166 auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
167 if (!structType) {
168 return true;
169 }
170
171 switch (storageClass) {
172 case spirv::StorageClass::Uniform:
173 case spirv::StorageClass::StorageBuffer:
174 case spirv::StorageClass::PushConstant:
175 case spirv::StorageClass::PhysicalStorageBuffer:
176 return structType.hasOffset() || !structType.getNumElements();
177 default:
178 return true;
179 }
180 }
181