1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19
20 using namespace mlir;
21 using namespace mlir::spirv;
22
23 //===----------------------------------------------------------------------===//
24 // ArrayType
25 //===----------------------------------------------------------------------===//
26
27 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
28 using KeyTy = std::tuple<Type, unsigned, unsigned>;
29
constructspirv::detail::ArrayTypeStorage30 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
31 const KeyTy &key) {
32 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
33 }
34
operator ==spirv::detail::ArrayTypeStorage35 bool operator==(const KeyTy &key) const {
36 return key == KeyTy(elementType, elementCount, stride);
37 }
38
ArrayTypeStoragespirv::detail::ArrayTypeStorage39 ArrayTypeStorage(const KeyTy &key)
40 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
41 stride(std::get<2>(key)) {}
42
43 Type elementType;
44 unsigned elementCount;
45 unsigned stride;
46 };
47
get(Type elementType,unsigned elementCount)48 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
49 assert(elementCount && "ArrayType needs at least one element");
50 return Base::get(elementType.getContext(), elementType, elementCount,
51 /*stride=*/0);
52 }
53
get(Type elementType,unsigned elementCount,unsigned stride)54 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
55 unsigned stride) {
56 assert(elementCount && "ArrayType needs at least one element");
57 return Base::get(elementType.getContext(), elementType, elementCount, stride);
58 }
59
getNumElements() const60 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
61
getElementType() const62 Type ArrayType::getElementType() const { return getImpl()->elementType; }
63
getArrayStride() const64 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
65
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)66 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
67 Optional<StorageClass> storage) {
68 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
69 }
70
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)71 void ArrayType::getCapabilities(
72 SPIRVType::CapabilityArrayRefVector &capabilities,
73 Optional<StorageClass> storage) {
74 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
75 }
76
getSizeInBytes()77 Optional<int64_t> ArrayType::getSizeInBytes() {
78 auto elementType = getElementType().cast<SPIRVType>();
79 Optional<int64_t> size = elementType.getSizeInBytes();
80 if (!size)
81 return llvm::None;
82 return (*size + getArrayStride()) * getNumElements();
83 }
84
85 //===----------------------------------------------------------------------===//
86 // CompositeType
87 //===----------------------------------------------------------------------===//
88
classof(Type type)89 bool CompositeType::classof(Type type) {
90 if (auto vectorType = type.dyn_cast<VectorType>())
91 return isValid(vectorType);
92 return type
93 .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
94 spirv::RuntimeArrayType, spirv::StructType>();
95 }
96
isValid(VectorType type)97 bool CompositeType::isValid(VectorType type) {
98 switch (type.getNumElements()) {
99 case 2:
100 case 3:
101 case 4:
102 case 8:
103 case 16:
104 break;
105 default:
106 return false;
107 }
108 return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
109 }
110
getElementType(unsigned index) const111 Type CompositeType::getElementType(unsigned index) const {
112 return TypeSwitch<Type, Type>(*this)
113 .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
114 [](auto type) { return type.getElementType(); })
115 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
116 .Case<StructType>(
117 [index](StructType type) { return type.getElementType(index); })
118 .Default(
119 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
120 }
121
getNumElements() const122 unsigned CompositeType::getNumElements() const {
123 if (auto arrayType = dyn_cast<ArrayType>())
124 return arrayType.getNumElements();
125 if (auto matrixType = dyn_cast<MatrixType>())
126 return matrixType.getNumColumns();
127 if (auto structType = dyn_cast<StructType>())
128 return structType.getNumElements();
129 if (auto vectorType = dyn_cast<VectorType>())
130 return vectorType.getNumElements();
131 if (isa<CooperativeMatrixNVType>()) {
132 llvm_unreachable(
133 "invalid to query number of elements of spirv::CooperativeMatrix type");
134 }
135 if (isa<RuntimeArrayType>()) {
136 llvm_unreachable(
137 "invalid to query number of elements of spirv::RuntimeArray type");
138 }
139 llvm_unreachable("invalid composite type");
140 }
141
hasCompileTimeKnownNumElements() const142 bool CompositeType::hasCompileTimeKnownNumElements() const {
143 return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
144 }
145
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)146 void CompositeType::getExtensions(
147 SPIRVType::ExtensionArrayRefVector &extensions,
148 Optional<StorageClass> storage) {
149 TypeSwitch<Type>(*this)
150 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
151 StructType>(
152 [&](auto type) { type.getExtensions(extensions, storage); })
153 .Case<VectorType>([&](VectorType type) {
154 return type.getElementType().cast<ScalarType>().getExtensions(
155 extensions, storage);
156 })
157 .Default([](Type) { llvm_unreachable("invalid composite type"); });
158 }
159
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)160 void CompositeType::getCapabilities(
161 SPIRVType::CapabilityArrayRefVector &capabilities,
162 Optional<StorageClass> storage) {
163 TypeSwitch<Type>(*this)
164 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
165 StructType>(
166 [&](auto type) { type.getCapabilities(capabilities, storage); })
167 .Case<VectorType>([&](VectorType type) {
168 auto vecSize = getNumElements();
169 if (vecSize == 8 || vecSize == 16) {
170 static const Capability caps[] = {Capability::Vector16};
171 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
172 capabilities.push_back(ref);
173 }
174 return type.getElementType().cast<ScalarType>().getCapabilities(
175 capabilities, storage);
176 })
177 .Default([](Type) { llvm_unreachable("invalid composite type"); });
178 }
179
getSizeInBytes()180 Optional<int64_t> CompositeType::getSizeInBytes() {
181 if (auto arrayType = dyn_cast<ArrayType>())
182 return arrayType.getSizeInBytes();
183 if (auto structType = dyn_cast<StructType>())
184 return structType.getSizeInBytes();
185 if (auto vectorType = dyn_cast<VectorType>()) {
186 Optional<int64_t> elementSize =
187 vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
188 if (!elementSize)
189 return llvm::None;
190 return *elementSize * vectorType.getNumElements();
191 }
192 return llvm::None;
193 }
194
195 //===----------------------------------------------------------------------===//
196 // CooperativeMatrixType
197 //===----------------------------------------------------------------------===//
198
199 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
200 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
201
202 static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage203 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
204 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
205 CooperativeMatrixTypeStorage(key);
206 }
207
operator ==spirv::detail::CooperativeMatrixTypeStorage208 bool operator==(const KeyTy &key) const {
209 return key == KeyTy(elementType, scope, rows, columns);
210 }
211
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage212 CooperativeMatrixTypeStorage(const KeyTy &key)
213 : elementType(std::get<0>(key)), rows(std::get<2>(key)),
214 columns(std::get<3>(key)), scope(std::get<1>(key)) {}
215
216 Type elementType;
217 unsigned rows;
218 unsigned columns;
219 Scope scope;
220 };
221
get(Type elementType,Scope scope,unsigned rows,unsigned columns)222 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
223 Scope scope, unsigned rows,
224 unsigned columns) {
225 return Base::get(elementType.getContext(), elementType, scope, rows, columns);
226 }
227
getElementType() const228 Type CooperativeMatrixNVType::getElementType() const {
229 return getImpl()->elementType;
230 }
231
getScope() const232 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
233
getRows() const234 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
235
getColumns() const236 unsigned CooperativeMatrixNVType::getColumns() const {
237 return getImpl()->columns;
238 }
239
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)240 void CooperativeMatrixNVType::getExtensions(
241 SPIRVType::ExtensionArrayRefVector &extensions,
242 Optional<StorageClass> storage) {
243 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
244 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
245 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
246 extensions.push_back(ref);
247 }
248
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)249 void CooperativeMatrixNVType::getCapabilities(
250 SPIRVType::CapabilityArrayRefVector &capabilities,
251 Optional<StorageClass> storage) {
252 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
253 static const Capability caps[] = {Capability::CooperativeMatrixNV};
254 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
255 capabilities.push_back(ref);
256 }
257
258 //===----------------------------------------------------------------------===//
259 // ImageType
260 //===----------------------------------------------------------------------===//
261
262 template <typename T>
getNumBits()263 static constexpr unsigned getNumBits() {
264 return 0;
265 }
266 template <>
getNumBits()267 constexpr unsigned getNumBits<Dim>() {
268 static_assert((1 << 3) > getMaxEnumValForDim(),
269 "Not enough bits to encode Dim value");
270 return 3;
271 }
272 template <>
getNumBits()273 constexpr unsigned getNumBits<ImageDepthInfo>() {
274 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
275 "Not enough bits to encode ImageDepthInfo value");
276 return 2;
277 }
278 template <>
getNumBits()279 constexpr unsigned getNumBits<ImageArrayedInfo>() {
280 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
281 "Not enough bits to encode ImageArrayedInfo value");
282 return 1;
283 }
284 template <>
getNumBits()285 constexpr unsigned getNumBits<ImageSamplingInfo>() {
286 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
287 "Not enough bits to encode ImageSamplingInfo value");
288 return 1;
289 }
290 template <>
getNumBits()291 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
292 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
293 "Not enough bits to encode ImageSamplerUseInfo value");
294 return 2;
295 }
296 template <>
getNumBits()297 constexpr unsigned getNumBits<ImageFormat>() {
298 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
299 "Not enough bits to encode ImageFormat value");
300 return 6;
301 }
302
303 struct spirv::detail::ImageTypeStorage : public TypeStorage {
304 public:
305 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
306 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
307
constructspirv::detail::ImageTypeStorage308 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
309 const KeyTy &key) {
310 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
311 }
312
operator ==spirv::detail::ImageTypeStorage313 bool operator==(const KeyTy &key) const {
314 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
315 samplerUseInfo, format);
316 }
317
ImageTypeStoragespirv::detail::ImageTypeStorage318 ImageTypeStorage(const KeyTy &key)
319 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
320 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
321 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
322 format(std::get<6>(key)) {}
323
324 Type elementType;
325 Dim dim : getNumBits<Dim>();
326 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
327 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
328 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
329 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
330 ImageFormat format : getNumBits<ImageFormat>();
331 };
332
333 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)334 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
335 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
336 value) {
337 return Base::get(std::get<0>(value).getContext(), value);
338 }
339
getElementType() const340 Type ImageType::getElementType() const { return getImpl()->elementType; }
341
getDim() const342 Dim ImageType::getDim() const { return getImpl()->dim; }
343
getDepthInfo() const344 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
345
getArrayedInfo() const346 ImageArrayedInfo ImageType::getArrayedInfo() const {
347 return getImpl()->arrayedInfo;
348 }
349
getSamplingInfo() const350 ImageSamplingInfo ImageType::getSamplingInfo() const {
351 return getImpl()->samplingInfo;
352 }
353
getSamplerUseInfo() const354 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
355 return getImpl()->samplerUseInfo;
356 }
357
getImageFormat() const358 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
359
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)360 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
361 Optional<StorageClass>) {
362 // Image types do not require extra extensions thus far.
363 }
364
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)365 void ImageType::getCapabilities(
366 SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
367 if (auto dimCaps = spirv::getCapabilities(getDim()))
368 capabilities.push_back(*dimCaps);
369
370 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
371 capabilities.push_back(*fmtCaps);
372 }
373
374 //===----------------------------------------------------------------------===//
375 // PointerType
376 //===----------------------------------------------------------------------===//
377
378 struct spirv::detail::PointerTypeStorage : public TypeStorage {
379 // (Type, StorageClass) as the key: Type stored in this struct, and
380 // StorageClass stored as TypeStorage's subclass data.
381 using KeyTy = std::pair<Type, StorageClass>;
382
constructspirv::detail::PointerTypeStorage383 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
384 const KeyTy &key) {
385 return new (allocator.allocate<PointerTypeStorage>())
386 PointerTypeStorage(key);
387 }
388
operator ==spirv::detail::PointerTypeStorage389 bool operator==(const KeyTy &key) const {
390 return key == KeyTy(pointeeType, storageClass);
391 }
392
PointerTypeStoragespirv::detail::PointerTypeStorage393 PointerTypeStorage(const KeyTy &key)
394 : pointeeType(key.first), storageClass(key.second) {}
395
396 Type pointeeType;
397 StorageClass storageClass;
398 };
399
get(Type pointeeType,StorageClass storageClass)400 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
401 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
402 }
403
getPointeeType() const404 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
405
getStorageClass() const406 StorageClass PointerType::getStorageClass() const {
407 return getImpl()->storageClass;
408 }
409
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)410 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
411 Optional<StorageClass> storage) {
412 // Use this pointer type's storage class because this pointer indicates we are
413 // using the pointee type in that specific storage class.
414 getPointeeType().cast<SPIRVType>().getExtensions(extensions,
415 getStorageClass());
416
417 if (auto scExts = spirv::getExtensions(getStorageClass()))
418 extensions.push_back(*scExts);
419 }
420
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)421 void PointerType::getCapabilities(
422 SPIRVType::CapabilityArrayRefVector &capabilities,
423 Optional<StorageClass> storage) {
424 // Use this pointer type's storage class because this pointer indicates we are
425 // using the pointee type in that specific storage class.
426 getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
427 getStorageClass());
428
429 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
430 capabilities.push_back(*scCaps);
431 }
432
433 //===----------------------------------------------------------------------===//
434 // RuntimeArrayType
435 //===----------------------------------------------------------------------===//
436
437 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
438 using KeyTy = std::pair<Type, unsigned>;
439
constructspirv::detail::RuntimeArrayTypeStorage440 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
441 const KeyTy &key) {
442 return new (allocator.allocate<RuntimeArrayTypeStorage>())
443 RuntimeArrayTypeStorage(key);
444 }
445
operator ==spirv::detail::RuntimeArrayTypeStorage446 bool operator==(const KeyTy &key) const {
447 return key == KeyTy(elementType, stride);
448 }
449
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage450 RuntimeArrayTypeStorage(const KeyTy &key)
451 : elementType(key.first), stride(key.second) {}
452
453 Type elementType;
454 unsigned stride;
455 };
456
get(Type elementType)457 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
458 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
459 }
460
get(Type elementType,unsigned stride)461 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
462 return Base::get(elementType.getContext(), elementType, stride);
463 }
464
getElementType() const465 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
466
getArrayStride() const467 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
468
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)469 void RuntimeArrayType::getExtensions(
470 SPIRVType::ExtensionArrayRefVector &extensions,
471 Optional<StorageClass> storage) {
472 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
473 }
474
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)475 void RuntimeArrayType::getCapabilities(
476 SPIRVType::CapabilityArrayRefVector &capabilities,
477 Optional<StorageClass> storage) {
478 {
479 static const Capability caps[] = {Capability::Shader};
480 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
481 capabilities.push_back(ref);
482 }
483 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
484 }
485
486 //===----------------------------------------------------------------------===//
487 // ScalarType
488 //===----------------------------------------------------------------------===//
489
classof(Type type)490 bool ScalarType::classof(Type type) {
491 if (auto floatType = type.dyn_cast<FloatType>()) {
492 return isValid(floatType);
493 }
494 if (auto intType = type.dyn_cast<IntegerType>()) {
495 return isValid(intType);
496 }
497 return false;
498 }
499
isValid(FloatType type)500 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
501
isValid(IntegerType type)502 bool ScalarType::isValid(IntegerType type) {
503 switch (type.getWidth()) {
504 case 1:
505 case 8:
506 case 16:
507 case 32:
508 case 64:
509 return true;
510 default:
511 return false;
512 }
513 }
514
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)515 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516 Optional<StorageClass> storage) {
517 // 8- or 16-bit integer/floating-point numbers will require extra extensions
518 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
519 // SPV_KHR_8bit_storage for more details.
520 if (!storage)
521 return;
522
523 switch (*storage) {
524 case StorageClass::PushConstant:
525 case StorageClass::StorageBuffer:
526 case StorageClass::Uniform:
527 if (getIntOrFloatBitWidth() == 8) {
528 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
529 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
530 extensions.push_back(ref);
531 }
532 LLVM_FALLTHROUGH;
533 case StorageClass::Input:
534 case StorageClass::Output:
535 if (getIntOrFloatBitWidth() == 16) {
536 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
537 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
538 extensions.push_back(ref);
539 }
540 break;
541 default:
542 break;
543 }
544 }
545
/// Appends the capabilities needed to use this scalar type, optionally within
/// a specific storage class.
///
/// Works in two stages:
///  1. If `storage` is PushConstant, StorageBuffer or Uniform and the bitwidth
///     is 8 or 16, push that class's dedicated storage capability and return.
///     Input/Output do the same for 16-bit values only.
///  2. Otherwise push the generic width capability: Int8/Int16/Int64 for
///     integers, Float16/Float64 for floats. i1, i32 and f32 need none.
void ScalarType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  unsigned bitwidth = getIntOrFloatBitWidth();

  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
  // SPV_KHR_8bit_storage for more details.

// Expands to a `case` for interface storage class `storage`: pushes `cap8` for
// 8-bit or `cap16` for 16-bit scalars and returns from the function; any other
// bitwidth breaks out of the switch and continues to the generic stage below.
#define STORAGE_CASE(storage, cap8, cap16)                                     \
  case StorageClass::storage: {                                                \
    if (bitwidth == 8) {                                                       \
      static const Capability caps[] = {Capability::cap8};                     \
      ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
      capabilities.push_back(ref);                                             \
      return;                                                                  \
    }                                                                          \
    if (bitwidth == 16) {                                                      \
      static const Capability caps[] = {Capability::cap16};                    \
      ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
      capabilities.push_back(ref);                                             \
      return;                                                                  \
    }                                                                          \
    /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
    /* storage classes. Fall through to the next section. */                   \
  } break

  // This part only handles the cases where special bitwidths appearing in
  // interface storage classes.
  if (storage) {
    switch (*storage) {
      STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
      STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
                   StorageBuffer16BitAccess);
      STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
                   StorageUniform16);
    // Input/Output only have a dedicated capability for 16-bit values; 8-bit
    // values here fall through to the generic Int8 capability below.
    case StorageClass::Input:
    case StorageClass::Output: {
      if (bitwidth == 16) {
        static const Capability caps[] = {Capability::StorageInputOutput16};
        ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
        capabilities.push_back(ref);
        return;
      }
      break;
    }
    default:
      break;
    }
  }
#undef STORAGE_CASE

  // For other non-interface storage classes, require a different set of
  // capabilities for special bitwidths.

// Expands to a `case width:` that pushes Capability::<type><width>
// (e.g. Int8, Float64) and then breaks.
#define WIDTH_CASE(type, width)                                                \
  case width: {                                                                \
    static const Capability caps[] = {Capability::type##width};                \
    ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));                \
    capabilities.push_back(ref);                                               \
  } break

  if (auto intType = dyn_cast<IntegerType>()) {
    switch (bitwidth) {
      WIDTH_CASE(Int, 8);
      WIDTH_CASE(Int, 16);
      WIDTH_CASE(Int, 64);
    // i1 and i32 need no extra capability.
    case 1:
    case 32:
      break;
    default:
      llvm_unreachable("invalid bitwidth to getCapabilities");
    }
  } else {
    assert(isa<FloatType>());
    switch (bitwidth) {
      WIDTH_CASE(Float, 16);
      WIDTH_CASE(Float, 64);
    // f32 needs no extra capability.
    case 32:
      break;
    default:
      llvm_unreachable("invalid bitwidth to getCapabilities");
    }
  }

#undef WIDTH_CASE
}
633
getSizeInBytes()634 Optional<int64_t> ScalarType::getSizeInBytes() {
635 auto bitWidth = getIntOrFloatBitWidth();
636 // According to the SPIR-V spec:
637 // "There is no physical size or bit pattern defined for values with boolean
638 // type. If they are stored (in conjunction with OpVariable), they can only
639 // be used with logical addressing operations, not physical, and only with
640 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
641 // Private, Function, Input, and Output."
642 if (bitWidth == 1)
643 return llvm::None;
644 return bitWidth / 8;
645 }
646
647 //===----------------------------------------------------------------------===//
648 // SPIRVType
649 //===----------------------------------------------------------------------===//
650
classof(Type type)651 bool SPIRVType::classof(Type type) {
652 // Allow SPIR-V dialect types
653 if (llvm::isa<SPIRVDialect>(type.getDialect()))
654 return true;
655 if (type.isa<ScalarType>())
656 return true;
657 if (auto vectorType = type.dyn_cast<VectorType>())
658 return CompositeType::isValid(vectorType);
659 return false;
660 }
661
isScalarOrVector()662 bool SPIRVType::isScalarOrVector() {
663 return isIntOrFloat() || isa<VectorType>();
664 }
665
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)666 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
667 Optional<StorageClass> storage) {
668 if (auto scalarType = dyn_cast<ScalarType>()) {
669 scalarType.getExtensions(extensions, storage);
670 } else if (auto compositeType = dyn_cast<CompositeType>()) {
671 compositeType.getExtensions(extensions, storage);
672 } else if (auto imageType = dyn_cast<ImageType>()) {
673 imageType.getExtensions(extensions, storage);
674 } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
675 sampledImageType.getExtensions(extensions, storage);
676 } else if (auto matrixType = dyn_cast<MatrixType>()) {
677 matrixType.getExtensions(extensions, storage);
678 } else if (auto ptrType = dyn_cast<PointerType>()) {
679 ptrType.getExtensions(extensions, storage);
680 } else {
681 llvm_unreachable("invalid SPIR-V Type to getExtensions");
682 }
683 }
684
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)685 void SPIRVType::getCapabilities(
686 SPIRVType::CapabilityArrayRefVector &capabilities,
687 Optional<StorageClass> storage) {
688 if (auto scalarType = dyn_cast<ScalarType>()) {
689 scalarType.getCapabilities(capabilities, storage);
690 } else if (auto compositeType = dyn_cast<CompositeType>()) {
691 compositeType.getCapabilities(capabilities, storage);
692 } else if (auto imageType = dyn_cast<ImageType>()) {
693 imageType.getCapabilities(capabilities, storage);
694 } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
695 sampledImageType.getCapabilities(capabilities, storage);
696 } else if (auto matrixType = dyn_cast<MatrixType>()) {
697 matrixType.getCapabilities(capabilities, storage);
698 } else if (auto ptrType = dyn_cast<PointerType>()) {
699 ptrType.getCapabilities(capabilities, storage);
700 } else {
701 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
702 }
703 }
704
getSizeInBytes()705 Optional<int64_t> SPIRVType::getSizeInBytes() {
706 if (auto scalarType = dyn_cast<ScalarType>())
707 return scalarType.getSizeInBytes();
708 if (auto compositeType = dyn_cast<CompositeType>())
709 return compositeType.getSizeInBytes();
710 return llvm::None;
711 }
712
713 //===----------------------------------------------------------------------===//
714 // SampledImageType
715 //===----------------------------------------------------------------------===//
716 struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
717 using KeyTy = Type;
718
SampledImageTypeStoragespirv::detail::SampledImageTypeStorage719 SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
720
operator ==spirv::detail::SampledImageTypeStorage721 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
722
constructspirv::detail::SampledImageTypeStorage723 static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
724 const KeyTy &key) {
725 return new (allocator.allocate<SampledImageTypeStorage>())
726 SampledImageTypeStorage(key);
727 }
728
729 Type imageType;
730 };
731
get(Type imageType)732 SampledImageType SampledImageType::get(Type imageType) {
733 return Base::get(imageType.getContext(), imageType);
734 }
735
736 SampledImageType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type imageType)737 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
738 Type imageType) {
739 return Base::getChecked(emitError, imageType.getContext(), imageType);
740 }
741
getImageType() const742 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
743
744 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type imageType)745 SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
746 Type imageType) {
747 if (!imageType.isa<ImageType>())
748 return emitError() << "expected image type";
749
750 return success();
751 }
752
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)753 void SampledImageType::getExtensions(
754 SPIRVType::ExtensionArrayRefVector &extensions,
755 Optional<StorageClass> storage) {
756 getImageType().cast<ImageType>().getExtensions(extensions, storage);
757 }
758
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)759 void SampledImageType::getCapabilities(
760 SPIRVType::CapabilityArrayRefVector &capabilities,
761 Optional<StorageClass> storage) {
762 getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
763 }
764
765 //===----------------------------------------------------------------------===//
766 // StructType
767 //===----------------------------------------------------------------------===//
768
769 /// Type storage for SPIR-V structure types:
770 ///
771 /// Structures are uniqued using:
772 /// - for identified structs:
773 /// - a string identifier;
774 /// - for literal structs:
775 /// - a list of member types;
776 /// - a list of member offset info;
777 /// - a list of member decoration info.
778 ///
779 /// Identified structures only have a mutable component consisting of:
780 /// - a list of member types;
781 /// - a list of member offset info;
782 /// - a list of member decoration info.
783 struct spirv::detail::StructTypeStorage : public TypeStorage {
784 /// Construct a storage object for an identified struct type. A struct type
785 /// associated with such storage must call StructType::trySetBody(...) later
786 /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage787 StructTypeStorage(StringRef identifier)
788 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
789 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
790 identifier(identifier) {}
791
792 /// Construct a storage object for a literal struct type. A struct type
793 /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage794 StructTypeStorage(
795 unsigned numMembers, Type const *memberTypes,
796 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
797 StructType::MemberDecorationInfo const *memberDecorationsInfo)
798 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
799 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
800 memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
801
802 /// A storage key is divided into 2 parts:
803 /// - for identified structs:
804 /// - a StringRef representing the struct identifier;
805 /// - for literal structs:
806 /// - an ArrayRef<Type> for member types;
807 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
808 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
809 /// info.
810 ///
811 /// An identified struct type is uniqued only by the first part (field 0)
812 /// of the key.
813 ///
814 /// A literal struct type is uniqued only by the second part (fields 1, 2, and
815 /// 3) of the key. The identifier field (field 0) must be empty.
816 using KeyTy =
817 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
818 ArrayRef<StructType::MemberDecorationInfo>>;
819
820 /// For identified structs, return true if the given key contains the same
821 /// identifier.
822 ///
823 /// For literal structs, return true if the given key contains a matching list
824 /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage825 bool operator==(const KeyTy &key) const {
826 if (isIdentified()) {
827 // Identified types are uniqued by their identifier.
828 return getIdentifier() == std::get<0>(key);
829 }
830
831 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
832 getMemberDecorationsInfo());
833 }
834
835 /// If the given key contains a non-empty identifier, this method constructs
836 /// an identified struct and leaves the rest of the struct type data to be set
837 /// through a later call to StructType::trySetBody(...).
838 ///
839 /// If, on the other hand, the key contains an empty identifier, a literal
840 /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage841 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
842 const KeyTy &key) {
843 StringRef keyIdentifier = std::get<0>(key);
844
845 if (!keyIdentifier.empty()) {
846 StringRef identifier = allocator.copyInto(keyIdentifier);
847
848 // Identified StructType body/members will be set through trySetBody(...)
849 // later.
850 return new (allocator.allocate<StructTypeStorage>())
851 StructTypeStorage(identifier);
852 }
853
854 ArrayRef<Type> keyTypes = std::get<1>(key);
855
856 // Copy the member type and layout information into the bump pointer
857 const Type *typesList = nullptr;
858 if (!keyTypes.empty()) {
859 typesList = allocator.copyInto(keyTypes).data();
860 }
861
862 const StructType::OffsetInfo *offsetInfoList = nullptr;
863 if (!std::get<2>(key).empty()) {
864 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
865 assert(keyOffsetInfo.size() == keyTypes.size() &&
866 "size of offset information must be same as the size of number of "
867 "elements");
868 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
869 }
870
871 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
872 unsigned numMemberDecorations = 0;
873 if (!std::get<3>(key).empty()) {
874 auto keyMemberDecorations = std::get<3>(key);
875 numMemberDecorations = keyMemberDecorations.size();
876 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
877 }
878
879 return new (allocator.allocate<StructTypeStorage>())
880 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
881 numMemberDecorations, memberDecorationList);
882 }
883
getMemberTypesspirv::detail::StructTypeStorage884 ArrayRef<Type> getMemberTypes() const {
885 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
886 }
887
getOffsetInfospirv::detail::StructTypeStorage888 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
889 if (offsetInfo) {
890 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
891 }
892 return {};
893 }
894
getMemberDecorationsInfospirv::detail::StructTypeStorage895 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
896 if (memberDecorationsInfo) {
897 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
898 numMemberDecorations);
899 }
900 return {};
901 }
902
getIdentifierspirv::detail::StructTypeStorage903 StringRef getIdentifier() const { return identifier; }
904
isIdentifiedspirv::detail::StructTypeStorage905 bool isIdentified() const { return !identifier.empty(); }
906
907 /// Sets the struct type content for identified structs. Calling this method
908 /// is only valid for identified structs.
909 ///
910 /// Fails under the following conditions:
911 /// - If called for a literal struct;
912 /// - If called for an identified struct whose body was set before (through a
913 /// call to this method) but with different contents from the passed
914 /// arguments.
mutatespirv::detail::StructTypeStorage915 LogicalResult mutate(
916 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
917 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
918 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
919 if (!isIdentified())
920 return failure();
921
922 if (memberTypesAndIsBodySet.getInt() &&
923 (getMemberTypes() != structMemberTypes ||
924 getOffsetInfo() != structOffsetInfo ||
925 getMemberDecorationsInfo() != structMemberDecorationInfo))
926 return failure();
927
928 memberTypesAndIsBodySet.setInt(true);
929 numMembers = structMemberTypes.size();
930
931 // Copy the member type and layout information into the bump pointer.
932 if (!structMemberTypes.empty())
933 memberTypesAndIsBodySet.setPointer(
934 allocator.copyInto(structMemberTypes).data());
935
936 if (!structOffsetInfo.empty()) {
937 assert(structOffsetInfo.size() == structMemberTypes.size() &&
938 "size of offset information must be same as the size of number of "
939 "elements");
940 offsetInfo = allocator.copyInto(structOffsetInfo).data();
941 }
942
943 if (!structMemberDecorationInfo.empty()) {
944 numMemberDecorations = structMemberDecorationInfo.size();
945 memberDecorationsInfo =
946 allocator.copyInto(structMemberDecorationInfo).data();
947 }
948
949 return success();
950 }
951
952 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
953 StructType::OffsetInfo const *offsetInfo;
954 unsigned numMembers;
955 unsigned numMemberDecorations;
956 StructType::MemberDecorationInfo const *memberDecorationsInfo;
957 StringRef identifier;
958 };
959
960 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)961 StructType::get(ArrayRef<Type> memberTypes,
962 ArrayRef<StructType::OffsetInfo> offsetInfo,
963 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
964 assert(!memberTypes.empty() && "Struct needs at least one member type");
965 // Sort the decorations.
966 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
967 memberDecorations.begin(), memberDecorations.end());
968 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
969 return Base::get(memberTypes.vec().front().getContext(),
970 /*identifier=*/StringRef(), memberTypes, offsetInfo,
971 sortedDecorations);
972 }
973
getIdentified(MLIRContext * context,StringRef identifier)974 StructType StructType::getIdentified(MLIRContext *context,
975 StringRef identifier) {
976 assert(!identifier.empty() &&
977 "StructType identifier must be non-empty string");
978
979 return Base::get(context, identifier, ArrayRef<Type>(),
980 ArrayRef<StructType::OffsetInfo>(),
981 ArrayRef<StructType::MemberDecorationInfo>());
982 }
983
getEmpty(MLIRContext * context,StringRef identifier)984 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
985 StructType newStructType = Base::get(
986 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
987 ArrayRef<StructType::MemberDecorationInfo>());
988 // Set an empty body in case this is a identified struct.
989 if (newStructType.isIdentified() &&
990 failed(newStructType.trySetBody(
991 ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
992 ArrayRef<StructType::MemberDecorationInfo>())))
993 return StructType();
994
995 return newStructType;
996 }
997
getIdentifier() const998 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
999
isIdentified() const1000 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1001
getNumElements() const1002 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1003
getElementType(unsigned index) const1004 Type StructType::getElementType(unsigned index) const {
1005 assert(getNumElements() > index && "member index out of range");
1006 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1007 }
1008
getElementTypes() const1009 StructType::ElementTypeRange StructType::getElementTypes() const {
1010 return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1011 getNumElements());
1012 }
1013
hasOffset() const1014 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1015
getMemberOffset(unsigned index) const1016 uint64_t StructType::getMemberOffset(unsigned index) const {
1017 assert(getNumElements() > index && "member index out of range");
1018 return getImpl()->offsetInfo[index];
1019 }
1020
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1021 void StructType::getMemberDecorations(
1022 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1023 const {
1024 memberDecorations.clear();
1025 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1026 memberDecorations.append(implMemberDecorations.begin(),
1027 implMemberDecorations.end());
1028 }
1029
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1030 void StructType::getMemberDecorations(
1031 unsigned index,
1032 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1033 assert(getNumElements() > index && "member index out of range");
1034 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1035 decorationsInfo.clear();
1036 for (const auto &memberDecoration : memberDecorations) {
1037 if (memberDecoration.memberIndex == index) {
1038 decorationsInfo.push_back(memberDecoration);
1039 }
1040 if (memberDecoration.memberIndex > index) {
1041 // Early exit since the decorations are stored sorted.
1042 return;
1043 }
1044 }
1045 }
1046
1047 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1048 StructType::trySetBody(ArrayRef<Type> memberTypes,
1049 ArrayRef<OffsetInfo> offsetInfo,
1050 ArrayRef<MemberDecorationInfo> memberDecorations) {
1051 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1052 }
1053
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1054 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1055 Optional<StorageClass> storage) {
1056 for (Type elementType : getElementTypes())
1057 elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1058 }
1059
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1060 void StructType::getCapabilities(
1061 SPIRVType::CapabilityArrayRefVector &capabilities,
1062 Optional<StorageClass> storage) {
1063 for (Type elementType : getElementTypes())
1064 elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1065 }
1066
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1067 llvm::hash_code spirv::hash_value(
1068 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1069 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1070 memberDecorationInfo.decoration);
1071 }
1072
1073 //===----------------------------------------------------------------------===//
1074 // MatrixType
1075 //===----------------------------------------------------------------------===//
1076
1077 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1078 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1079 : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1080
1081 using KeyTy = std::tuple<Type, uint32_t>;
1082
constructspirv::detail::MatrixTypeStorage1083 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1084 const KeyTy &key) {
1085
1086 // Initialize the memory using placement new.
1087 return new (allocator.allocate<MatrixTypeStorage>())
1088 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1089 }
1090
operator ==spirv::detail::MatrixTypeStorage1091 bool operator==(const KeyTy &key) const {
1092 return key == KeyTy(columnType, columnCount);
1093 }
1094
1095 Type columnType;
1096 const uint32_t columnCount;
1097 };
1098
get(Type columnType,uint32_t columnCount)1099 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1100 return Base::get(columnType.getContext(), columnType, columnCount);
1101 }
1102
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1103 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1104 Type columnType, uint32_t columnCount) {
1105 return Base::getChecked(emitError, columnType.getContext(), columnType,
1106 columnCount);
1107 }
1108
verify(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1109 LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
1110 Type columnType, uint32_t columnCount) {
1111 if (columnCount < 2 || columnCount > 4)
1112 return emitError() << "matrix can have 2, 3, or 4 columns only";
1113
1114 if (!isValidColumnType(columnType))
1115 return emitError() << "matrix columns must be vectors of floats";
1116
1117 /// The underlying vectors (columns) must be of size 2, 3, or 4
1118 ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1119 if (columnShape.size() != 1)
1120 return emitError() << "matrix columns must be 1D vectors";
1121
1122 if (columnShape[0] < 2 || columnShape[0] > 4)
1123 return emitError() << "matrix columns must be of size 2, 3, or 4";
1124
1125 return success();
1126 }
1127
1128 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1129 bool MatrixType::isValidColumnType(Type columnType) {
1130 if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1131 if (vectorType.getElementType().isa<FloatType>())
1132 return true;
1133 }
1134 return false;
1135 }
1136
getColumnType() const1137 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1138
getElementType() const1139 Type MatrixType::getElementType() const {
1140 return getImpl()->columnType.cast<VectorType>().getElementType();
1141 }
1142
getNumColumns() const1143 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1144
getNumRows() const1145 unsigned MatrixType::getNumRows() const {
1146 return getImpl()->columnType.cast<VectorType>().getShape()[0];
1147 }
1148
getNumElements() const1149 unsigned MatrixType::getNumElements() const {
1150 return (getImpl()->columnCount) * getNumRows();
1151 }
1152
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1153 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1154 Optional<StorageClass> storage) {
1155 getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1156 }
1157
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1158 void MatrixType::getCapabilities(
1159 SPIRVType::CapabilityArrayRefVector &capabilities,
1160 Optional<StorageClass> storage) {
1161 {
1162 static const Capability caps[] = {Capability::Matrix};
1163 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1164 capabilities.push_back(ref);
1165 }
1166 // Add any capabilities associated with the underlying vectors (i.e., columns)
1167 getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1168 }
1169
1170 //===----------------------------------------------------------------------===//
1171 // SPIR-V Dialect
1172 //===----------------------------------------------------------------------===//
1173
registerTypes()1174 void SPIRVDialect::registerTypes() {
1175 addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
1176 PointerType, RuntimeArrayType, SampledImageType, StructType>();
1177 }
1178