1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::spirv;
22 
23 //===----------------------------------------------------------------------===//
24 // ArrayType
25 //===----------------------------------------------------------------------===//
26 
27 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
28   using KeyTy = std::tuple<Type, unsigned, unsigned>;
29 
constructspirv::detail::ArrayTypeStorage30   static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
31                                      const KeyTy &key) {
32     return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
33   }
34 
operator ==spirv::detail::ArrayTypeStorage35   bool operator==(const KeyTy &key) const {
36     return key == KeyTy(elementType, elementCount, stride);
37   }
38 
ArrayTypeStoragespirv::detail::ArrayTypeStorage39   ArrayTypeStorage(const KeyTy &key)
40       : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
41         stride(std::get<2>(key)) {}
42 
43   Type elementType;
44   unsigned elementCount;
45   unsigned stride;
46 };
47 
get(Type elementType,unsigned elementCount)48 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
49   assert(elementCount && "ArrayType needs at least one element");
50   return Base::get(elementType.getContext(), elementType, elementCount,
51                    /*stride=*/0);
52 }
53 
get(Type elementType,unsigned elementCount,unsigned stride)54 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
55                          unsigned stride) {
56   assert(elementCount && "ArrayType needs at least one element");
57   return Base::get(elementType.getContext(), elementType, elementCount, stride);
58 }
59 
getNumElements() const60 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
61 
getElementType() const62 Type ArrayType::getElementType() const { return getImpl()->elementType; }
63 
getArrayStride() const64 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
65 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)66 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
67                               Optional<StorageClass> storage) {
68   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
69 }
70 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)71 void ArrayType::getCapabilities(
72     SPIRVType::CapabilityArrayRefVector &capabilities,
73     Optional<StorageClass> storage) {
74   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
75 }
76 
getSizeInBytes()77 Optional<int64_t> ArrayType::getSizeInBytes() {
78   auto elementType = getElementType().cast<SPIRVType>();
79   Optional<int64_t> size = elementType.getSizeInBytes();
80   if (!size)
81     return llvm::None;
82   return (*size + getArrayStride()) * getNumElements();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // CompositeType
87 //===----------------------------------------------------------------------===//
88 
classof(Type type)89 bool CompositeType::classof(Type type) {
90   if (auto vectorType = type.dyn_cast<VectorType>())
91     return isValid(vectorType);
92   return type
93       .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
94            spirv::RuntimeArrayType, spirv::StructType>();
95 }
96 
isValid(VectorType type)97 bool CompositeType::isValid(VectorType type) {
98   switch (type.getNumElements()) {
99   case 2:
100   case 3:
101   case 4:
102   case 8:
103   case 16:
104     break;
105   default:
106     return false;
107   }
108   return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
109 }
110 
getElementType(unsigned index) const111 Type CompositeType::getElementType(unsigned index) const {
112   return TypeSwitch<Type, Type>(*this)
113       .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
114           [](auto type) { return type.getElementType(); })
115       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
116       .Case<StructType>(
117           [index](StructType type) { return type.getElementType(index); })
118       .Default(
119           [](Type) -> Type { llvm_unreachable("invalid composite type"); });
120 }
121 
getNumElements() const122 unsigned CompositeType::getNumElements() const {
123   if (auto arrayType = dyn_cast<ArrayType>())
124     return arrayType.getNumElements();
125   if (auto matrixType = dyn_cast<MatrixType>())
126     return matrixType.getNumColumns();
127   if (auto structType = dyn_cast<StructType>())
128     return structType.getNumElements();
129   if (auto vectorType = dyn_cast<VectorType>())
130     return vectorType.getNumElements();
131   if (isa<CooperativeMatrixNVType>()) {
132     llvm_unreachable(
133         "invalid to query number of elements of spirv::CooperativeMatrix type");
134   }
135   if (isa<RuntimeArrayType>()) {
136     llvm_unreachable(
137         "invalid to query number of elements of spirv::RuntimeArray type");
138   }
139   llvm_unreachable("invalid composite type");
140 }
141 
hasCompileTimeKnownNumElements() const142 bool CompositeType::hasCompileTimeKnownNumElements() const {
143   return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
144 }
145 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)146 void CompositeType::getExtensions(
147     SPIRVType::ExtensionArrayRefVector &extensions,
148     Optional<StorageClass> storage) {
149   TypeSwitch<Type>(*this)
150       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
151             StructType>(
152           [&](auto type) { type.getExtensions(extensions, storage); })
153       .Case<VectorType>([&](VectorType type) {
154         return type.getElementType().cast<ScalarType>().getExtensions(
155             extensions, storage);
156       })
157       .Default([](Type) { llvm_unreachable("invalid composite type"); });
158 }
159 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)160 void CompositeType::getCapabilities(
161     SPIRVType::CapabilityArrayRefVector &capabilities,
162     Optional<StorageClass> storage) {
163   TypeSwitch<Type>(*this)
164       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
165             StructType>(
166           [&](auto type) { type.getCapabilities(capabilities, storage); })
167       .Case<VectorType>([&](VectorType type) {
168         auto vecSize = getNumElements();
169         if (vecSize == 8 || vecSize == 16) {
170           static const Capability caps[] = {Capability::Vector16};
171           ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
172           capabilities.push_back(ref);
173         }
174         return type.getElementType().cast<ScalarType>().getCapabilities(
175             capabilities, storage);
176       })
177       .Default([](Type) { llvm_unreachable("invalid composite type"); });
178 }
179 
getSizeInBytes()180 Optional<int64_t> CompositeType::getSizeInBytes() {
181   if (auto arrayType = dyn_cast<ArrayType>())
182     return arrayType.getSizeInBytes();
183   if (auto structType = dyn_cast<StructType>())
184     return structType.getSizeInBytes();
185   if (auto vectorType = dyn_cast<VectorType>()) {
186     Optional<int64_t> elementSize =
187         vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
188     if (!elementSize)
189       return llvm::None;
190     return *elementSize * vectorType.getNumElements();
191   }
192   return llvm::None;
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // CooperativeMatrixType
197 //===----------------------------------------------------------------------===//
198 
199 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
200   using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
201 
202   static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage203   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
204     return new (allocator.allocate<CooperativeMatrixTypeStorage>())
205         CooperativeMatrixTypeStorage(key);
206   }
207 
operator ==spirv::detail::CooperativeMatrixTypeStorage208   bool operator==(const KeyTy &key) const {
209     return key == KeyTy(elementType, scope, rows, columns);
210   }
211 
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage212   CooperativeMatrixTypeStorage(const KeyTy &key)
213       : elementType(std::get<0>(key)), rows(std::get<2>(key)),
214         columns(std::get<3>(key)), scope(std::get<1>(key)) {}
215 
216   Type elementType;
217   unsigned rows;
218   unsigned columns;
219   Scope scope;
220 };
221 
get(Type elementType,Scope scope,unsigned rows,unsigned columns)222 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
223                                                      Scope scope, unsigned rows,
224                                                      unsigned columns) {
225   return Base::get(elementType.getContext(), elementType, scope, rows, columns);
226 }
227 
getElementType() const228 Type CooperativeMatrixNVType::getElementType() const {
229   return getImpl()->elementType;
230 }
231 
getScope() const232 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
233 
getRows() const234 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
235 
getColumns() const236 unsigned CooperativeMatrixNVType::getColumns() const {
237   return getImpl()->columns;
238 }
239 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)240 void CooperativeMatrixNVType::getExtensions(
241     SPIRVType::ExtensionArrayRefVector &extensions,
242     Optional<StorageClass> storage) {
243   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
244   static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
245   ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
246   extensions.push_back(ref);
247 }
248 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)249 void CooperativeMatrixNVType::getCapabilities(
250     SPIRVType::CapabilityArrayRefVector &capabilities,
251     Optional<StorageClass> storage) {
252   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
253   static const Capability caps[] = {Capability::CooperativeMatrixNV};
254   ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
255   capabilities.push_back(ref);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ImageType
260 //===----------------------------------------------------------------------===//
261 
262 template <typename T>
getNumBits()263 static constexpr unsigned getNumBits() {
264   return 0;
265 }
266 template <>
getNumBits()267 constexpr unsigned getNumBits<Dim>() {
268   static_assert((1 << 3) > getMaxEnumValForDim(),
269                 "Not enough bits to encode Dim value");
270   return 3;
271 }
272 template <>
getNumBits()273 constexpr unsigned getNumBits<ImageDepthInfo>() {
274   static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
275                 "Not enough bits to encode ImageDepthInfo value");
276   return 2;
277 }
278 template <>
getNumBits()279 constexpr unsigned getNumBits<ImageArrayedInfo>() {
280   static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
281                 "Not enough bits to encode ImageArrayedInfo value");
282   return 1;
283 }
284 template <>
getNumBits()285 constexpr unsigned getNumBits<ImageSamplingInfo>() {
286   static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
287                 "Not enough bits to encode ImageSamplingInfo value");
288   return 1;
289 }
290 template <>
getNumBits()291 constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
292   static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
293                 "Not enough bits to encode ImageSamplerUseInfo value");
294   return 2;
295 }
296 template <>
getNumBits()297 constexpr unsigned getNumBits<ImageFormat>() {
298   static_assert((1 << 6) > getMaxEnumValForImageFormat(),
299                 "Not enough bits to encode ImageFormat value");
300   return 6;
301 }
302 
303 struct spirv::detail::ImageTypeStorage : public TypeStorage {
304 public:
305   using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
306                            ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
307 
constructspirv::detail::ImageTypeStorage308   static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
309                                      const KeyTy &key) {
310     return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
311   }
312 
operator ==spirv::detail::ImageTypeStorage313   bool operator==(const KeyTy &key) const {
314     return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
315                         samplerUseInfo, format);
316   }
317 
ImageTypeStoragespirv::detail::ImageTypeStorage318   ImageTypeStorage(const KeyTy &key)
319       : elementType(std::get<0>(key)), dim(std::get<1>(key)),
320         depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
321         samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
322         format(std::get<6>(key)) {}
323 
324   Type elementType;
325   Dim dim : getNumBits<Dim>();
326   ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
327   ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
328   ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
329   ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
330   ImageFormat format : getNumBits<ImageFormat>();
331 };
332 
333 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)334 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
335                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
336                    value) {
337   return Base::get(std::get<0>(value).getContext(), value);
338 }
339 
getElementType() const340 Type ImageType::getElementType() const { return getImpl()->elementType; }
341 
getDim() const342 Dim ImageType::getDim() const { return getImpl()->dim; }
343 
getDepthInfo() const344 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
345 
getArrayedInfo() const346 ImageArrayedInfo ImageType::getArrayedInfo() const {
347   return getImpl()->arrayedInfo;
348 }
349 
getSamplingInfo() const350 ImageSamplingInfo ImageType::getSamplingInfo() const {
351   return getImpl()->samplingInfo;
352 }
353 
getSamplerUseInfo() const354 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
355   return getImpl()->samplerUseInfo;
356 }
357 
getImageFormat() const358 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
359 
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)360 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
361                               Optional<StorageClass>) {
362   // Image types do not require extra extensions thus far.
363 }
364 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)365 void ImageType::getCapabilities(
366     SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
367   if (auto dimCaps = spirv::getCapabilities(getDim()))
368     capabilities.push_back(*dimCaps);
369 
370   if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
371     capabilities.push_back(*fmtCaps);
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // PointerType
376 //===----------------------------------------------------------------------===//
377 
378 struct spirv::detail::PointerTypeStorage : public TypeStorage {
379   // (Type, StorageClass) as the key: Type stored in this struct, and
380   // StorageClass stored as TypeStorage's subclass data.
381   using KeyTy = std::pair<Type, StorageClass>;
382 
constructspirv::detail::PointerTypeStorage383   static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
384                                        const KeyTy &key) {
385     return new (allocator.allocate<PointerTypeStorage>())
386         PointerTypeStorage(key);
387   }
388 
operator ==spirv::detail::PointerTypeStorage389   bool operator==(const KeyTy &key) const {
390     return key == KeyTy(pointeeType, storageClass);
391   }
392 
PointerTypeStoragespirv::detail::PointerTypeStorage393   PointerTypeStorage(const KeyTy &key)
394       : pointeeType(key.first), storageClass(key.second) {}
395 
396   Type pointeeType;
397   StorageClass storageClass;
398 };
399 
get(Type pointeeType,StorageClass storageClass)400 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
401   return Base::get(pointeeType.getContext(), pointeeType, storageClass);
402 }
403 
getPointeeType() const404 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
405 
getStorageClass() const406 StorageClass PointerType::getStorageClass() const {
407   return getImpl()->storageClass;
408 }
409 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)410 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
411                                 Optional<StorageClass> storage) {
412   // Use this pointer type's storage class because this pointer indicates we are
413   // using the pointee type in that specific storage class.
414   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
415                                                    getStorageClass());
416 
417   if (auto scExts = spirv::getExtensions(getStorageClass()))
418     extensions.push_back(*scExts);
419 }
420 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)421 void PointerType::getCapabilities(
422     SPIRVType::CapabilityArrayRefVector &capabilities,
423     Optional<StorageClass> storage) {
424   // Use this pointer type's storage class because this pointer indicates we are
425   // using the pointee type in that specific storage class.
426   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
427                                                      getStorageClass());
428 
429   if (auto scCaps = spirv::getCapabilities(getStorageClass()))
430     capabilities.push_back(*scCaps);
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // RuntimeArrayType
435 //===----------------------------------------------------------------------===//
436 
437 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
438   using KeyTy = std::pair<Type, unsigned>;
439 
constructspirv::detail::RuntimeArrayTypeStorage440   static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
441                                             const KeyTy &key) {
442     return new (allocator.allocate<RuntimeArrayTypeStorage>())
443         RuntimeArrayTypeStorage(key);
444   }
445 
operator ==spirv::detail::RuntimeArrayTypeStorage446   bool operator==(const KeyTy &key) const {
447     return key == KeyTy(elementType, stride);
448   }
449 
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage450   RuntimeArrayTypeStorage(const KeyTy &key)
451       : elementType(key.first), stride(key.second) {}
452 
453   Type elementType;
454   unsigned stride;
455 };
456 
get(Type elementType)457 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
458   return Base::get(elementType.getContext(), elementType, /*stride=*/0);
459 }
460 
get(Type elementType,unsigned stride)461 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
462   return Base::get(elementType.getContext(), elementType, stride);
463 }
464 
getElementType() const465 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
466 
getArrayStride() const467 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
468 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)469 void RuntimeArrayType::getExtensions(
470     SPIRVType::ExtensionArrayRefVector &extensions,
471     Optional<StorageClass> storage) {
472   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
473 }
474 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)475 void RuntimeArrayType::getCapabilities(
476     SPIRVType::CapabilityArrayRefVector &capabilities,
477     Optional<StorageClass> storage) {
478   {
479     static const Capability caps[] = {Capability::Shader};
480     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
481     capabilities.push_back(ref);
482   }
483   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // ScalarType
488 //===----------------------------------------------------------------------===//
489 
classof(Type type)490 bool ScalarType::classof(Type type) {
491   if (auto floatType = type.dyn_cast<FloatType>()) {
492     return isValid(floatType);
493   }
494   if (auto intType = type.dyn_cast<IntegerType>()) {
495     return isValid(intType);
496   }
497   return false;
498 }
499 
isValid(FloatType type)500 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
501 
isValid(IntegerType type)502 bool ScalarType::isValid(IntegerType type) {
503   switch (type.getWidth()) {
504   case 1:
505   case 8:
506   case 16:
507   case 32:
508   case 64:
509     return true;
510   default:
511     return false;
512   }
513 }
514 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)515 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516                                Optional<StorageClass> storage) {
517   // 8- or 16-bit integer/floating-point numbers will require extra extensions
518   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
519   // SPV_KHR_8bit_storage for more details.
520   if (!storage)
521     return;
522 
523   switch (*storage) {
524   case StorageClass::PushConstant:
525   case StorageClass::StorageBuffer:
526   case StorageClass::Uniform:
527     if (getIntOrFloatBitWidth() == 8) {
528       static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
529       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
530       extensions.push_back(ref);
531     }
532     LLVM_FALLTHROUGH;
533   case StorageClass::Input:
534   case StorageClass::Output:
535     if (getIntOrFloatBitWidth() == 16) {
536       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
537       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
538       extensions.push_back(ref);
539     }
540     break;
541   default:
542     break;
543   }
544 }
545 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)546 void ScalarType::getCapabilities(
547     SPIRVType::CapabilityArrayRefVector &capabilities,
548     Optional<StorageClass> storage) {
549   unsigned bitwidth = getIntOrFloatBitWidth();
550 
551   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
552   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
553   // SPV_KHR_8bit_storage for more details.
554 
555 #define STORAGE_CASE(storage, cap8, cap16)                                     \
556   case StorageClass::storage: {                                                \
557     if (bitwidth == 8) {                                                       \
558       static const Capability caps[] = {Capability::cap8};                     \
559       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
560       capabilities.push_back(ref);                                             \
561       return;                                                                  \
562     }                                                                          \
563     if (bitwidth == 16) {                                                      \
564       static const Capability caps[] = {Capability::cap16};                    \
565       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
566       capabilities.push_back(ref);                                             \
567       return;                                                                  \
568     }                                                                          \
569     /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
570     /* storage classes. Fall through to the next section. */                   \
571   } break
572 
573   // This part only handles the cases where special bitwidths appearing in
574   // interface storage classes.
575   if (storage) {
576     switch (*storage) {
577       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
578       STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
579                    StorageBuffer16BitAccess);
580       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
581                    StorageUniform16);
582     case StorageClass::Input:
583     case StorageClass::Output: {
584       if (bitwidth == 16) {
585         static const Capability caps[] = {Capability::StorageInputOutput16};
586         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
587         capabilities.push_back(ref);
588         return;
589       }
590       break;
591     }
592     default:
593       break;
594     }
595   }
596 #undef STORAGE_CASE
597 
598   // For other non-interface storage classes, require a different set of
599   // capabilities for special bitwidths.
600 
601 #define WIDTH_CASE(type, width)                                                \
602   case width: {                                                                \
603     static const Capability caps[] = {Capability::type##width};                \
604     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));                \
605     capabilities.push_back(ref);                                               \
606   } break
607 
608   if (auto intType = dyn_cast<IntegerType>()) {
609     switch (bitwidth) {
610       WIDTH_CASE(Int, 8);
611       WIDTH_CASE(Int, 16);
612       WIDTH_CASE(Int, 64);
613     case 1:
614     case 32:
615       break;
616     default:
617       llvm_unreachable("invalid bitwidth to getCapabilities");
618     }
619   } else {
620     assert(isa<FloatType>());
621     switch (bitwidth) {
622       WIDTH_CASE(Float, 16);
623       WIDTH_CASE(Float, 64);
624     case 32:
625       break;
626     default:
627       llvm_unreachable("invalid bitwidth to getCapabilities");
628     }
629   }
630 
631 #undef WIDTH_CASE
632 }
633 
getSizeInBytes()634 Optional<int64_t> ScalarType::getSizeInBytes() {
635   auto bitWidth = getIntOrFloatBitWidth();
636   // According to the SPIR-V spec:
637   // "There is no physical size or bit pattern defined for values with boolean
638   // type. If they are stored (in conjunction with OpVariable), they can only
639   // be used with logical addressing operations, not physical, and only with
640   // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
641   // Private, Function, Input, and Output."
642   if (bitWidth == 1)
643     return llvm::None;
644   return bitWidth / 8;
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // SPIRVType
649 //===----------------------------------------------------------------------===//
650 
classof(Type type)651 bool SPIRVType::classof(Type type) {
652   // Allow SPIR-V dialect types
653   if (llvm::isa<SPIRVDialect>(type.getDialect()))
654     return true;
655   if (type.isa<ScalarType>())
656     return true;
657   if (auto vectorType = type.dyn_cast<VectorType>())
658     return CompositeType::isValid(vectorType);
659   return false;
660 }
661 
isScalarOrVector()662 bool SPIRVType::isScalarOrVector() {
663   return isIntOrFloat() || isa<VectorType>();
664 }
665 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)666 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
667                               Optional<StorageClass> storage) {
668   if (auto scalarType = dyn_cast<ScalarType>()) {
669     scalarType.getExtensions(extensions, storage);
670   } else if (auto compositeType = dyn_cast<CompositeType>()) {
671     compositeType.getExtensions(extensions, storage);
672   } else if (auto imageType = dyn_cast<ImageType>()) {
673     imageType.getExtensions(extensions, storage);
674   } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
675     sampledImageType.getExtensions(extensions, storage);
676   } else if (auto matrixType = dyn_cast<MatrixType>()) {
677     matrixType.getExtensions(extensions, storage);
678   } else if (auto ptrType = dyn_cast<PointerType>()) {
679     ptrType.getExtensions(extensions, storage);
680   } else {
681     llvm_unreachable("invalid SPIR-V Type to getExtensions");
682   }
683 }
684 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)685 void SPIRVType::getCapabilities(
686     SPIRVType::CapabilityArrayRefVector &capabilities,
687     Optional<StorageClass> storage) {
688   if (auto scalarType = dyn_cast<ScalarType>()) {
689     scalarType.getCapabilities(capabilities, storage);
690   } else if (auto compositeType = dyn_cast<CompositeType>()) {
691     compositeType.getCapabilities(capabilities, storage);
692   } else if (auto imageType = dyn_cast<ImageType>()) {
693     imageType.getCapabilities(capabilities, storage);
694   } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
695     sampledImageType.getCapabilities(capabilities, storage);
696   } else if (auto matrixType = dyn_cast<MatrixType>()) {
697     matrixType.getCapabilities(capabilities, storage);
698   } else if (auto ptrType = dyn_cast<PointerType>()) {
699     ptrType.getCapabilities(capabilities, storage);
700   } else {
701     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
702   }
703 }
704 
getSizeInBytes()705 Optional<int64_t> SPIRVType::getSizeInBytes() {
706   if (auto scalarType = dyn_cast<ScalarType>())
707     return scalarType.getSizeInBytes();
708   if (auto compositeType = dyn_cast<CompositeType>())
709     return compositeType.getSizeInBytes();
710   return llvm::None;
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // SampledImageType
715 //===----------------------------------------------------------------------===//
716 struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
717   using KeyTy = Type;
718 
SampledImageTypeStoragespirv::detail::SampledImageTypeStorage719   SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
720 
operator ==spirv::detail::SampledImageTypeStorage721   bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
722 
constructspirv::detail::SampledImageTypeStorage723   static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
724                                             const KeyTy &key) {
725     return new (allocator.allocate<SampledImageTypeStorage>())
726         SampledImageTypeStorage(key);
727   }
728 
729   Type imageType;
730 };
731 
get(Type imageType)732 SampledImageType SampledImageType::get(Type imageType) {
733   return Base::get(imageType.getContext(), imageType);
734 }
735 
736 SampledImageType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type imageType)737 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
738                              Type imageType) {
739   return Base::getChecked(emitError, imageType.getContext(), imageType);
740 }
741 
getImageType() const742 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
743 
744 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type imageType)745 SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
746                          Type imageType) {
747   if (!imageType.isa<ImageType>())
748     return emitError() << "expected image type";
749 
750   return success();
751 }
752 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)753 void SampledImageType::getExtensions(
754     SPIRVType::ExtensionArrayRefVector &extensions,
755     Optional<StorageClass> storage) {
756   getImageType().cast<ImageType>().getExtensions(extensions, storage);
757 }
758 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)759 void SampledImageType::getCapabilities(
760     SPIRVType::CapabilityArrayRefVector &capabilities,
761     Optional<StorageClass> storage) {
762   getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // StructType
767 //===----------------------------------------------------------------------===//
768 
769 /// Type storage for SPIR-V structure types:
770 ///
771 /// Structures are uniqued using:
772 /// - for identified structs:
773 ///   - a string identifier;
774 /// - for literal structs:
775 ///   - a list of member types;
776 ///   - a list of member offset info;
777 ///   - a list of member decoration info.
778 ///
779 /// Identified structures only have a mutable component consisting of:
780 /// - a list of member types;
781 /// - a list of member offset info;
782 /// - a list of member decoration info.
783 struct spirv::detail::StructTypeStorage : public TypeStorage {
784   /// Construct a storage object for an identified struct type. A struct type
785   /// associated with such storage must call StructType::trySetBody(...) later
786   /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage787   StructTypeStorage(StringRef identifier)
788       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
789         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
790         identifier(identifier) {}
791 
792   /// Construct a storage object for a literal struct type. A struct type
793   /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage794   StructTypeStorage(
795       unsigned numMembers, Type const *memberTypes,
796       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
797       StructType::MemberDecorationInfo const *memberDecorationsInfo)
798       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
799         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
800         memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
801 
802   /// A storage key is divided into 2 parts:
803   /// - for identified structs:
804   ///   - a StringRef representing the struct identifier;
805   /// - for literal structs:
806   ///   - an ArrayRef<Type> for member types;
807   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
808   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
809   ///     info.
810   ///
811   /// An identified struct type is uniqued only by the first part (field 0)
812   /// of the key.
813   ///
814   /// A literal struct type is uniqued only by the second part (fields 1, 2, and
815   /// 3) of the key. The identifier field (field 0) must be empty.
816   using KeyTy =
817       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
818                  ArrayRef<StructType::MemberDecorationInfo>>;
819 
820   /// For identified structs, return true if the given key contains the same
821   /// identifier.
822   ///
823   /// For literal structs, return true if the given key contains a matching list
824   /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage825   bool operator==(const KeyTy &key) const {
826     if (isIdentified()) {
827       // Identified types are uniqued by their identifier.
828       return getIdentifier() == std::get<0>(key);
829     }
830 
831     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
832                         getMemberDecorationsInfo());
833   }
834 
835   /// If the given key contains a non-empty identifier, this method constructs
836   /// an identified struct and leaves the rest of the struct type data to be set
837   /// through a later call to StructType::trySetBody(...).
838   ///
839   /// If, on the other hand, the key contains an empty identifier, a literal
840   /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage841   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
842                                       const KeyTy &key) {
843     StringRef keyIdentifier = std::get<0>(key);
844 
845     if (!keyIdentifier.empty()) {
846       StringRef identifier = allocator.copyInto(keyIdentifier);
847 
848       // Identified StructType body/members will be set through trySetBody(...)
849       // later.
850       return new (allocator.allocate<StructTypeStorage>())
851           StructTypeStorage(identifier);
852     }
853 
854     ArrayRef<Type> keyTypes = std::get<1>(key);
855 
856     // Copy the member type and layout information into the bump pointer
857     const Type *typesList = nullptr;
858     if (!keyTypes.empty()) {
859       typesList = allocator.copyInto(keyTypes).data();
860     }
861 
862     const StructType::OffsetInfo *offsetInfoList = nullptr;
863     if (!std::get<2>(key).empty()) {
864       ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
865       assert(keyOffsetInfo.size() == keyTypes.size() &&
866              "size of offset information must be same as the size of number of "
867              "elements");
868       offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
869     }
870 
871     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
872     unsigned numMemberDecorations = 0;
873     if (!std::get<3>(key).empty()) {
874       auto keyMemberDecorations = std::get<3>(key);
875       numMemberDecorations = keyMemberDecorations.size();
876       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
877     }
878 
879     return new (allocator.allocate<StructTypeStorage>())
880         StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
881                           numMemberDecorations, memberDecorationList);
882   }
883 
getMemberTypesspirv::detail::StructTypeStorage884   ArrayRef<Type> getMemberTypes() const {
885     return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
886   }
887 
getOffsetInfospirv::detail::StructTypeStorage888   ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
889     if (offsetInfo) {
890       return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
891     }
892     return {};
893   }
894 
getMemberDecorationsInfospirv::detail::StructTypeStorage895   ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
896     if (memberDecorationsInfo) {
897       return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
898                                                         numMemberDecorations);
899     }
900     return {};
901   }
902 
getIdentifierspirv::detail::StructTypeStorage903   StringRef getIdentifier() const { return identifier; }
904 
isIdentifiedspirv::detail::StructTypeStorage905   bool isIdentified() const { return !identifier.empty(); }
906 
907   /// Sets the struct type content for identified structs. Calling this method
908   /// is only valid for identified structs.
909   ///
910   /// Fails under the following conditions:
911   /// - If called for a literal struct;
912   /// - If called for an identified struct whose body was set before (through a
913   /// call to this method) but with different contents from the passed
914   /// arguments.
mutatespirv::detail::StructTypeStorage915   LogicalResult mutate(
916       TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
917       ArrayRef<StructType::OffsetInfo> structOffsetInfo,
918       ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
919     if (!isIdentified())
920       return failure();
921 
922     if (memberTypesAndIsBodySet.getInt() &&
923         (getMemberTypes() != structMemberTypes ||
924          getOffsetInfo() != structOffsetInfo ||
925          getMemberDecorationsInfo() != structMemberDecorationInfo))
926       return failure();
927 
928     memberTypesAndIsBodySet.setInt(true);
929     numMembers = structMemberTypes.size();
930 
931     // Copy the member type and layout information into the bump pointer.
932     if (!structMemberTypes.empty())
933       memberTypesAndIsBodySet.setPointer(
934           allocator.copyInto(structMemberTypes).data());
935 
936     if (!structOffsetInfo.empty()) {
937       assert(structOffsetInfo.size() == structMemberTypes.size() &&
938              "size of offset information must be same as the size of number of "
939              "elements");
940       offsetInfo = allocator.copyInto(structOffsetInfo).data();
941     }
942 
943     if (!structMemberDecorationInfo.empty()) {
944       numMemberDecorations = structMemberDecorationInfo.size();
945       memberDecorationsInfo =
946           allocator.copyInto(structMemberDecorationInfo).data();
947     }
948 
949     return success();
950   }
951 
952   llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
953   StructType::OffsetInfo const *offsetInfo;
954   unsigned numMembers;
955   unsigned numMemberDecorations;
956   StructType::MemberDecorationInfo const *memberDecorationsInfo;
957   StringRef identifier;
958 };
959 
960 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)961 StructType::get(ArrayRef<Type> memberTypes,
962                 ArrayRef<StructType::OffsetInfo> offsetInfo,
963                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
964   assert(!memberTypes.empty() && "Struct needs at least one member type");
965   // Sort the decorations.
966   SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
967       memberDecorations.begin(), memberDecorations.end());
968   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
969   return Base::get(memberTypes.vec().front().getContext(),
970                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
971                    sortedDecorations);
972 }
973 
getIdentified(MLIRContext * context,StringRef identifier)974 StructType StructType::getIdentified(MLIRContext *context,
975                                      StringRef identifier) {
976   assert(!identifier.empty() &&
977          "StructType identifier must be non-empty string");
978 
979   return Base::get(context, identifier, ArrayRef<Type>(),
980                    ArrayRef<StructType::OffsetInfo>(),
981                    ArrayRef<StructType::MemberDecorationInfo>());
982 }
983 
getEmpty(MLIRContext * context,StringRef identifier)984 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
985   StructType newStructType = Base::get(
986       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
987       ArrayRef<StructType::MemberDecorationInfo>());
988   // Set an empty body in case this is a identified struct.
989   if (newStructType.isIdentified() &&
990       failed(newStructType.trySetBody(
991           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
992           ArrayRef<StructType::MemberDecorationInfo>())))
993     return StructType();
994 
995   return newStructType;
996 }
997 
getIdentifier() const998 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
999 
isIdentified() const1000 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1001 
getNumElements() const1002 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1003 
getElementType(unsigned index) const1004 Type StructType::getElementType(unsigned index) const {
1005   assert(getNumElements() > index && "member index out of range");
1006   return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1007 }
1008 
getElementTypes() const1009 StructType::ElementTypeRange StructType::getElementTypes() const {
1010   return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1011                           getNumElements());
1012 }
1013 
hasOffset() const1014 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1015 
getMemberOffset(unsigned index) const1016 uint64_t StructType::getMemberOffset(unsigned index) const {
1017   assert(getNumElements() > index && "member index out of range");
1018   return getImpl()->offsetInfo[index];
1019 }
1020 
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1021 void StructType::getMemberDecorations(
1022     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1023     const {
1024   memberDecorations.clear();
1025   auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1026   memberDecorations.append(implMemberDecorations.begin(),
1027                            implMemberDecorations.end());
1028 }
1029 
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1030 void StructType::getMemberDecorations(
1031     unsigned index,
1032     SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1033   assert(getNumElements() > index && "member index out of range");
1034   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1035   decorationsInfo.clear();
1036   for (const auto &memberDecoration : memberDecorations) {
1037     if (memberDecoration.memberIndex == index) {
1038       decorationsInfo.push_back(memberDecoration);
1039     }
1040     if (memberDecoration.memberIndex > index) {
1041       // Early exit since the decorations are stored sorted.
1042       return;
1043     }
1044   }
1045 }
1046 
1047 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1048 StructType::trySetBody(ArrayRef<Type> memberTypes,
1049                        ArrayRef<OffsetInfo> offsetInfo,
1050                        ArrayRef<MemberDecorationInfo> memberDecorations) {
1051   return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1052 }
1053 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1054 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1055                                Optional<StorageClass> storage) {
1056   for (Type elementType : getElementTypes())
1057     elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1058 }
1059 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1060 void StructType::getCapabilities(
1061     SPIRVType::CapabilityArrayRefVector &capabilities,
1062     Optional<StorageClass> storage) {
1063   for (Type elementType : getElementTypes())
1064     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1065 }
1066 
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1067 llvm::hash_code spirv::hash_value(
1068     const StructType::MemberDecorationInfo &memberDecorationInfo) {
1069   return llvm::hash_combine(memberDecorationInfo.memberIndex,
1070                             memberDecorationInfo.decoration);
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // MatrixType
1075 //===----------------------------------------------------------------------===//
1076 
1077 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1078   MatrixTypeStorage(Type columnType, uint32_t columnCount)
1079       : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1080 
1081   using KeyTy = std::tuple<Type, uint32_t>;
1082 
constructspirv::detail::MatrixTypeStorage1083   static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1084                                       const KeyTy &key) {
1085 
1086     // Initialize the memory using placement new.
1087     return new (allocator.allocate<MatrixTypeStorage>())
1088         MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1089   }
1090 
operator ==spirv::detail::MatrixTypeStorage1091   bool operator==(const KeyTy &key) const {
1092     return key == KeyTy(columnType, columnCount);
1093   }
1094 
1095   Type columnType;
1096   const uint32_t columnCount;
1097 };
1098 
get(Type columnType,uint32_t columnCount)1099 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1100   return Base::get(columnType.getContext(), columnType, columnCount);
1101 }
1102 
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1103 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1104                                   Type columnType, uint32_t columnCount) {
1105   return Base::getChecked(emitError, columnType.getContext(), columnType,
1106                           columnCount);
1107 }
1108 
verify(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1109 LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
1110                                  Type columnType, uint32_t columnCount) {
1111   if (columnCount < 2 || columnCount > 4)
1112     return emitError() << "matrix can have 2, 3, or 4 columns only";
1113 
1114   if (!isValidColumnType(columnType))
1115     return emitError() << "matrix columns must be vectors of floats";
1116 
1117   /// The underlying vectors (columns) must be of size 2, 3, or 4
1118   ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1119   if (columnShape.size() != 1)
1120     return emitError() << "matrix columns must be 1D vectors";
1121 
1122   if (columnShape[0] < 2 || columnShape[0] > 4)
1123     return emitError() << "matrix columns must be of size 2, 3, or 4";
1124 
1125   return success();
1126 }
1127 
1128 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1129 bool MatrixType::isValidColumnType(Type columnType) {
1130   if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1131     if (vectorType.getElementType().isa<FloatType>())
1132       return true;
1133   }
1134   return false;
1135 }
1136 
getColumnType() const1137 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1138 
getElementType() const1139 Type MatrixType::getElementType() const {
1140   return getImpl()->columnType.cast<VectorType>().getElementType();
1141 }
1142 
getNumColumns() const1143 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1144 
getNumRows() const1145 unsigned MatrixType::getNumRows() const {
1146   return getImpl()->columnType.cast<VectorType>().getShape()[0];
1147 }
1148 
getNumElements() const1149 unsigned MatrixType::getNumElements() const {
1150   return (getImpl()->columnCount) * getNumRows();
1151 }
1152 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1153 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1154                                Optional<StorageClass> storage) {
1155   getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1156 }
1157 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1158 void MatrixType::getCapabilities(
1159     SPIRVType::CapabilityArrayRefVector &capabilities,
1160     Optional<StorageClass> storage) {
1161   {
1162     static const Capability caps[] = {Capability::Matrix};
1163     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1164     capabilities.push_back(ref);
1165   }
1166   // Add any capabilities associated with the underlying vectors (i.e., columns)
1167   getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // SPIR-V Dialect
1172 //===----------------------------------------------------------------------===//
1173 
registerTypes()1174 void SPIRVDialect::registerTypes() {
1175   addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
1176            PointerType, RuntimeArrayType, SampledImageType, StructType>();
1177 }
1178