1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/TypeSupport.h"
21 #include "mlir/IR/Types.h"
22 
23 #include <tuple>
24 
25 namespace mlir {
26 namespace spirv {
27 
/// Storage classes backing the parametric SPIR-V types declared below; each
/// one is used as the storage parameter of a `Type::TypeBase` instantiation.
/// Only forward declarations are visible in this header.
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
} // namespace detail
39 
40 // Base SPIR-V type for providing availability queries.
41 class SPIRVType : public Type {
42 public:
43   using Type::Type;
44 
45   static bool classof(Type type);
46 
47   bool isScalarOrVector();
48 
49   /// The extension requirements for each type are following the
50   /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
51   /// convention.
52   using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>;
53 
54   /// Appends to `extensions` the extensions needed for this type to appear in
55   /// the given `storage` class. This method does not guarantee the uniqueness
56   /// of extensions; the same extension may be appended multiple times.
57   void getExtensions(ExtensionArrayRefVector &extensions,
58                      Optional<StorageClass> storage = llvm::None);
59 
60   /// The capability requirements for each type are following the
61   /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
62   /// convention.
63   using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>;
64 
65   /// Appends to `capabilities` the capabilities needed for this type to appear
66   /// in the given `storage` class. This method does not guarantee the
67   /// uniqueness of capabilities; the same capability may be appended multiple
68   /// times.
69   void getCapabilities(CapabilityArrayRefVector &capabilities,
70                        Optional<StorageClass> storage = llvm::None);
71 
72   /// Returns the size in bytes for each type. If no size can be calculated,
73   /// returns `llvm::None`. Note that if the type has explicit layout, it is
74   /// also taken into account in calculation.
75   Optional<int64_t> getSizeInBytes();
76 };
77 
78 // SPIR-V scalar type: bool type, integer type, floating point type.
79 class ScalarType : public SPIRVType {
80 public:
81   using SPIRVType::SPIRVType;
82 
83   static bool classof(Type type);
84 
85   /// Returns true if the given integer type is valid for the SPIR-V dialect.
86   static bool isValid(FloatType);
87   /// Returns true if the given float type is valid for the SPIR-V dialect.
88   static bool isValid(IntegerType);
89 
90   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
91                      Optional<StorageClass> storage = llvm::None);
92   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
93                        Optional<StorageClass> storage = llvm::None);
94 
95   Optional<int64_t> getSizeInBytes();
96 };
97 
98 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
99 class CompositeType : public SPIRVType {
100 public:
101   using SPIRVType::SPIRVType;
102 
103   static bool classof(Type type);
104 
105   /// Returns true if the given vector type is valid for the SPIR-V dialect.
106   static bool isValid(VectorType);
107 
108   /// Return the number of elements of the type. This should only be called if
109   /// hasCompileTimeKnownNumElements is true.
110   unsigned getNumElements() const;
111 
112   Type getElementType(unsigned) const;
113 
114   /// Return true if the number of elements is known at compile time and is not
115   /// implementation dependent.
116   bool hasCompileTimeKnownNumElements() const;
117 
118   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
119                      Optional<StorageClass> storage = llvm::None);
120   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
121                        Optional<StorageClass> storage = llvm::None);
122 
123   Optional<int64_t> getSizeInBytes();
124 };
125 
126 // SPIR-V array type
127 class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
128                                         detail::ArrayTypeStorage> {
129 public:
130   using Base::Base;
131 
132   static ArrayType get(Type elementType, unsigned elementCount);
133 
134   /// Returns an array type with the given stride in bytes.
135   static ArrayType get(Type elementType, unsigned elementCount,
136                        unsigned stride);
137 
138   unsigned getNumElements() const;
139 
140   Type getElementType() const;
141 
142   /// Returns the array stride in bytes. 0 means no stride decorated on this
143   /// type.
144   unsigned getArrayStride() const;
145 
146   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
147                      Optional<StorageClass> storage = llvm::None);
148   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
149                        Optional<StorageClass> storage = llvm::None);
150 
151   /// Returns the array size in bytes. Since array type may have an explicit
152   /// stride declaration (in bytes), we also include it in the calculation.
153   Optional<int64_t> getSizeInBytes();
154 };
155 
156 // SPIR-V image type
157 class ImageType
158     : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
159 public:
160   using Base::Base;
161 
162   static ImageType
163   get(Type elementType, Dim dim,
164       ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
165       ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
166       ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
167       ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
168       ImageFormat format = ImageFormat::Unknown) {
169     return ImageType::get(
170         std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
171                    ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
172             elementType, dim, depth, arrayed, samplingInfo, samplerUse,
173             format));
174   }
175 
176   static ImageType
177       get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
178                      ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
179 
180   Type getElementType() const;
181   Dim getDim() const;
182   ImageDepthInfo getDepthInfo() const;
183   ImageArrayedInfo getArrayedInfo() const;
184   ImageSamplingInfo getSamplingInfo() const;
185   ImageSamplerUseInfo getSamplerUseInfo() const;
186   ImageFormat getImageFormat() const;
187   // TODO: Add support for Access qualifier
188 
189   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
190                      Optional<StorageClass> storage = llvm::None);
191   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
192                        Optional<StorageClass> storage = llvm::None);
193 };
194 
195 // SPIR-V pointer type
196 class PointerType : public Type::TypeBase<PointerType, SPIRVType,
197                                           detail::PointerTypeStorage> {
198 public:
199   using Base::Base;
200 
201   static PointerType get(Type pointeeType, StorageClass storageClass);
202 
203   Type getPointeeType() const;
204 
205   StorageClass getStorageClass() const;
206 
207   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
208                      Optional<StorageClass> storage = llvm::None);
209   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
210                        Optional<StorageClass> storage = llvm::None);
211 };
212 
213 // SPIR-V run-time array type
214 class RuntimeArrayType
215     : public Type::TypeBase<RuntimeArrayType, SPIRVType,
216                             detail::RuntimeArrayTypeStorage> {
217 public:
218   using Base::Base;
219 
220   static RuntimeArrayType get(Type elementType);
221 
222   /// Returns a runtime array type with the given stride in bytes.
223   static RuntimeArrayType get(Type elementType, unsigned stride);
224 
225   Type getElementType() const;
226 
227   /// Returns the array stride in bytes. 0 means no stride decorated on this
228   /// type.
229   unsigned getArrayStride() const;
230 
231   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
232                      Optional<StorageClass> storage = llvm::None);
233   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
234                        Optional<StorageClass> storage = llvm::None);
235 };
236 
237 // SPIR-V sampled image type
238 class SampledImageType
239     : public Type::TypeBase<SampledImageType, SPIRVType,
240                             detail::SampledImageTypeStorage> {
241 public:
242   using Base::Base;
243 
244   static SampledImageType get(Type imageType);
245 
246   static SampledImageType
247   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
248 
249   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
250                               Type imageType);
251 
252   Type getImageType() const;
253 
254   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
255                      Optional<spirv::StorageClass> storage = llvm::None);
256   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
257                        Optional<spirv::StorageClass> storage = llvm::None);
258 };
259 
260 /// SPIR-V struct type. Two kinds of struct types are supported:
261 /// - Literal: a literal struct type is uniqued by its fields (types + offset
262 /// info + decoration info).
263 /// - Identified: an indentified struct type is uniqued by its string identifier
264 /// (name). This is useful in representing recursive structs. For example, the
265 /// following C struct:
266 ///
267 /// struct A {
268 ///   A* next;
269 /// };
270 ///
271 /// would be represented in MLIR as:
272 ///
273 /// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
274 ///
275 /// In the above, expressing recursive struct types is accomplished by giving a
276 /// recursive struct a unique identified and using that identifier in the struct
277 /// definition for recursive references.
278 class StructType
279     : public Type::TypeBase<StructType, CompositeType,
280                             detail::StructTypeStorage, TypeTrait::IsMutable> {
281 public:
282   using Base::Base;
283 
284   // Type for specifying the offset of the struct members
285   using OffsetInfo = uint32_t;
286 
287   // Type for specifying the decoration(s) on struct members
288   struct MemberDecorationInfo {
289     uint32_t memberIndex : 31;
290     uint32_t hasValue : 1;
291     Decoration decoration;
292     uint32_t decorationValue;
293 
MemberDecorationInfoMemberDecorationInfo294     MemberDecorationInfo(uint32_t index, uint32_t hasValue,
295                          Decoration decoration, uint32_t decorationValue)
296         : memberIndex(index), hasValue(hasValue), decoration(decoration),
297           decorationValue(decorationValue) {}
298 
299     bool operator==(const MemberDecorationInfo &other) const {
300       return (this->memberIndex == other.memberIndex) &&
301              (this->decoration == other.decoration) &&
302              (this->decorationValue == other.decorationValue);
303     }
304 
305     bool operator<(const MemberDecorationInfo &other) const {
306       return this->memberIndex < other.memberIndex ||
307              (this->memberIndex == other.memberIndex &&
308               static_cast<uint32_t>(this->decoration) <
309                   static_cast<uint32_t>(other.decoration));
310     }
311   };
312 
313   /// Construct a literal StructType with at least one member.
314   static StructType get(ArrayRef<Type> memberTypes,
315                         ArrayRef<OffsetInfo> offsetInfo = {},
316                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
317 
318   /// Construct an identified StructType. This creates a StructType whose body
319   /// (member types, offset info, and decorations) is not set yet. A call to
320   /// StructType::trySetBody(...) must follow when the StructType contents are
321   /// available (e.g. parsed or deserialized).
322   ///
323   /// Note: If another thread creates (or had already created) a struct with the
324   /// same identifier, that struct will be returned as a result.
325   static StructType getIdentified(MLIRContext *context, StringRef identifier);
326 
327   /// Construct a (possibly identified) StructType with no members.
328   ///
329   /// Note: this method might fail in a multi-threaded setup if another thread
330   /// created an identified struct with the same identifier but with different
331   /// contents before returning. In which case, an empty (default-constructed)
332   /// StructType is returned.
333   static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
334 
335   /// For literal structs, return an empty string.
336   /// For identified structs, return the struct's identifier.
337   StringRef getIdentifier() const;
338 
339   /// Returns true if the StructType is identified.
340   bool isIdentified() const;
341 
342   unsigned getNumElements() const;
343 
344   Type getElementType(unsigned) const;
345 
346   /// Range class for element types.
347   class ElementTypeRange
348       : public ::llvm::detail::indexed_accessor_range_base<
349             ElementTypeRange, const Type *, Type, Type, Type> {
350   private:
351     using RangeBaseT::RangeBaseT;
352 
353     /// See `llvm::detail::indexed_accessor_range_base` for details.
offset_base(const Type * object,ptrdiff_t index)354     static const Type *offset_base(const Type *object, ptrdiff_t index) {
355       return object + index;
356     }
357     /// See `llvm::detail::indexed_accessor_range_base` for details.
dereference_iterator(const Type * object,ptrdiff_t index)358     static Type dereference_iterator(const Type *object, ptrdiff_t index) {
359       return object[index];
360     }
361 
362     /// Allow base class access to `offset_base` and `dereference_iterator`.
363     friend RangeBaseT;
364   };
365 
366   ElementTypeRange getElementTypes() const;
367 
368   bool hasOffset() const;
369 
370   uint64_t getMemberOffset(unsigned) const;
371 
372   // Returns in `memberDecorations` the Decorations (apart from Offset)
373   // associated with all members of the StructType.
374   void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
375                                 &memberDecorations) const;
376 
377   // Returns in `decorationsInfo` all the Decorations (apart from Offset)
378   // associated with the `i`-th member of the StructType.
379   void getMemberDecorations(
380       unsigned i,
381       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
382 
383   /// Sets the contents of an incomplete identified StructType. This method must
384   /// be called only for identified StructTypes and it must be called only once
385   /// per instance. Otherwise, failure() is returned.
386   LogicalResult
387   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
388              ArrayRef<MemberDecorationInfo> memberDecorations = {});
389 
390   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
391                      Optional<StorageClass> storage = llvm::None);
392   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
393                        Optional<StorageClass> storage = llvm::None);
394 };
395 
/// Hashes a struct member decoration. Only the declaration is visible in this
/// header.
/// NOTE(review): presumably used when uniquing literal StructTypes by their
/// decorations. Keep it consistent with MemberDecorationInfo::operator==:
/// values that compare equal must hash equal.
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
398 
399 // SPIR-V cooperative matrix type
400 class CooperativeMatrixNVType
401     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
402                             detail::CooperativeMatrixTypeStorage> {
403 public:
404   using Base::Base;
405 
406   static CooperativeMatrixNVType get(Type elementType, Scope scope,
407                                      unsigned rows, unsigned columns);
408   Type getElementType() const;
409 
410   /// Return the scope of the cooperative matrix.
411   Scope getScope() const;
412   /// return the number of rows of the matrix.
413   unsigned getRows() const;
414   /// return the number of columns of the matrix.
415   unsigned getColumns() const;
416 
417   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
418                      Optional<StorageClass> storage = llvm::None);
419   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
420                        Optional<StorageClass> storage = llvm::None);
421 };
422 
423 // SPIR-V matrix type
424 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
425                                          detail::MatrixTypeStorage> {
426 public:
427   using Base::Base;
428 
429   static MatrixType get(Type columnType, uint32_t columnCount);
430 
431   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
432                                Type columnType, uint32_t columnCount);
433 
434   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
435                               Type columnType, uint32_t columnCount);
436 
437   /// Returns true if the matrix elements are vectors of float elements.
438   static bool isValidColumnType(Type columnType);
439 
440   Type getColumnType() const;
441 
442   /// Returns the number of rows.
443   unsigned getNumRows() const;
444 
445   /// Returns the number of columns.
446   unsigned getNumColumns() const;
447 
448   /// Returns total number of elements (rows*columns).
449   unsigned getNumElements() const;
450 
451   /// Returns the elements' type (i.e, single element type).
452   Type getElementType() const;
453 
454   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
455                      Optional<StorageClass> storage = llvm::None);
456   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
457                        Optional<StorageClass> storage = llvm::None);
458 };
459 
460 } // namespace spirv
461 } // namespace mlir
462 
463 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
464