1 //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the types in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 14 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/Location.h" 20 #include "mlir/IR/TypeSupport.h" 21 #include "mlir/IR/Types.h" 22 23 #include <tuple> 24 25 namespace mlir { 26 namespace spirv { 27 28 namespace detail { 29 struct ArrayTypeStorage; 30 struct CooperativeMatrixTypeStorage; 31 struct ImageTypeStorage; 32 struct MatrixTypeStorage; 33 struct PointerTypeStorage; 34 struct RuntimeArrayTypeStorage; 35 struct SampledImageTypeStorage; 36 struct StructTypeStorage; 37 38 } // namespace detail 39 40 // Base SPIR-V type for providing availability queries. 41 class SPIRVType : public Type { 42 public: 43 using Type::Type; 44 45 static bool classof(Type type); 46 47 bool isScalarOrVector(); 48 49 /// The extension requirements for each type are following the 50 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 51 /// convention. 52 using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>; 53 54 /// Appends to `extensions` the extensions needed for this type to appear in 55 /// the given `storage` class. This method does not guarantee the uniqueness 56 /// of extensions; the same extension may be appended multiple times. 57 void getExtensions(ExtensionArrayRefVector &extensions, 58 Optional<StorageClass> storage = llvm::None); 59 60 /// The capability requirements for each type are following the 61 /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) 62 /// convention. 63 using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>; 64 65 /// Appends to `capabilities` the capabilities needed for this type to appear 66 /// in the given `storage` class. This method does not guarantee the 67 /// uniqueness of capabilities; the same capability may be appended multiple 68 /// times. 69 void getCapabilities(CapabilityArrayRefVector &capabilities, 70 Optional<StorageClass> storage = llvm::None); 71 72 /// Returns the size in bytes for each type. If no size can be calculated, 73 /// returns `llvm::None`. Note that if the type has explicit layout, it is 74 /// also taken into account in calculation. 75 Optional<int64_t> getSizeInBytes(); 76 }; 77 78 // SPIR-V scalar type: bool type, integer type, floating point type. 79 class ScalarType : public SPIRVType { 80 public: 81 using SPIRVType::SPIRVType; 82 83 static bool classof(Type type); 84 85 /// Returns true if the given integer type is valid for the SPIR-V dialect. 86 static bool isValid(FloatType); 87 /// Returns true if the given float type is valid for the SPIR-V dialect. 88 static bool isValid(IntegerType); 89 90 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 91 Optional<StorageClass> storage = llvm::None); 92 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 93 Optional<StorageClass> storage = llvm::None); 94 95 Optional<int64_t> getSizeInBytes(); 96 }; 97 98 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. 99 class CompositeType : public SPIRVType { 100 public: 101 using SPIRVType::SPIRVType; 102 103 static bool classof(Type type); 104 105 /// Returns true if the given vector type is valid for the SPIR-V dialect. 106 static bool isValid(VectorType); 107 108 /// Return the number of elements of the type. This should only be called if 109 /// hasCompileTimeKnownNumElements is true. 110 unsigned getNumElements() const; 111 112 Type getElementType(unsigned) const; 113 114 /// Return true if the number of elements is known at compile time and is not 115 /// implementation dependent. 116 bool hasCompileTimeKnownNumElements() const; 117 118 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 119 Optional<StorageClass> storage = llvm::None); 120 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 121 Optional<StorageClass> storage = llvm::None); 122 123 Optional<int64_t> getSizeInBytes(); 124 }; 125 126 // SPIR-V array type 127 class ArrayType : public Type::TypeBase<ArrayType, CompositeType, 128 detail::ArrayTypeStorage> { 129 public: 130 using Base::Base; 131 132 static ArrayType get(Type elementType, unsigned elementCount); 133 134 /// Returns an array type with the given stride in bytes. 135 static ArrayType get(Type elementType, unsigned elementCount, 136 unsigned stride); 137 138 unsigned getNumElements() const; 139 140 Type getElementType() const; 141 142 /// Returns the array stride in bytes. 0 means no stride decorated on this 143 /// type. 144 unsigned getArrayStride() const; 145 146 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 147 Optional<StorageClass> storage = llvm::None); 148 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 149 Optional<StorageClass> storage = llvm::None); 150 151 /// Returns the array size in bytes. Since array type may have an explicit 152 /// stride declaration (in bytes), we also include it in the calculation. 153 Optional<int64_t> getSizeInBytes(); 154 }; 155 156 // SPIR-V image type 157 class ImageType 158 : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> { 159 public: 160 using Base::Base; 161 162 static ImageType 163 get(Type elementType, Dim dim, 164 ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, 165 ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, 166 ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, 167 ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, 168 ImageFormat format = ImageFormat::Unknown) { 169 return ImageType::get( 170 std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 171 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>( 172 elementType, dim, depth, arrayed, samplingInfo, samplerUse, 173 format)); 174 } 175 176 static ImageType 177 get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 178 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>); 179 180 Type getElementType() const; 181 Dim getDim() const; 182 ImageDepthInfo getDepthInfo() const; 183 ImageArrayedInfo getArrayedInfo() const; 184 ImageSamplingInfo getSamplingInfo() const; 185 ImageSamplerUseInfo getSamplerUseInfo() const; 186 ImageFormat getImageFormat() const; 187 // TODO: Add support for Access qualifier 188 189 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 190 Optional<StorageClass> storage = llvm::None); 191 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 192 Optional<StorageClass> storage = llvm::None); 193 }; 194 195 // SPIR-V pointer type 196 class PointerType : public Type::TypeBase<PointerType, SPIRVType, 197 detail::PointerTypeStorage> { 198 public: 199 using Base::Base; 200 201 static PointerType get(Type pointeeType, StorageClass storageClass); 202 203 Type getPointeeType() const; 204 205 StorageClass getStorageClass() const; 206 207 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 208 Optional<StorageClass> storage = llvm::None); 209 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 210 Optional<StorageClass> storage = llvm::None); 211 }; 212 213 // SPIR-V run-time array type 214 class RuntimeArrayType 215 : public Type::TypeBase<RuntimeArrayType, SPIRVType, 216 detail::RuntimeArrayTypeStorage> { 217 public: 218 using Base::Base; 219 220 static RuntimeArrayType get(Type elementType); 221 222 /// Returns a runtime array type with the given stride in bytes. 223 static RuntimeArrayType get(Type elementType, unsigned stride); 224 225 Type getElementType() const; 226 227 /// Returns the array stride in bytes. 0 means no stride decorated on this 228 /// type. 229 unsigned getArrayStride() const; 230 231 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 232 Optional<StorageClass> storage = llvm::None); 233 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 234 Optional<StorageClass> storage = llvm::None); 235 }; 236 237 // SPIR-V sampled image type 238 class SampledImageType 239 : public Type::TypeBase<SampledImageType, SPIRVType, 240 detail::SampledImageTypeStorage> { 241 public: 242 using Base::Base; 243 244 static SampledImageType get(Type imageType); 245 246 static SampledImageType 247 getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType); 248 249 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 250 Type imageType); 251 252 Type getImageType() const; 253 254 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 255 Optional<spirv::StorageClass> storage = llvm::None); 256 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 257 Optional<spirv::StorageClass> storage = llvm::None); 258 }; 259 260 /// SPIR-V struct type. Two kinds of struct types are supported: 261 /// - Literal: a literal struct type is uniqued by its fields (types + offset 262 /// info + decoration info). 263 /// - Identified: an indentified struct type is uniqued by its string identifier 264 /// (name). This is useful in representing recursive structs. For example, the 265 /// following C struct: 266 /// 267 /// struct A { 268 /// A* next; 269 /// }; 270 /// 271 /// would be represented in MLIR as: 272 /// 273 /// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)> 274 /// 275 /// In the above, expressing recursive struct types is accomplished by giving a 276 /// recursive struct a unique identified and using that identifier in the struct 277 /// definition for recursive references. 278 class StructType 279 : public Type::TypeBase<StructType, CompositeType, 280 detail::StructTypeStorage, TypeTrait::IsMutable> { 281 public: 282 using Base::Base; 283 284 // Type for specifying the offset of the struct members 285 using OffsetInfo = uint32_t; 286 287 // Type for specifying the decoration(s) on struct members 288 struct MemberDecorationInfo { 289 uint32_t memberIndex : 31; 290 uint32_t hasValue : 1; 291 Decoration decoration; 292 uint32_t decorationValue; 293 MemberDecorationInfoMemberDecorationInfo294 MemberDecorationInfo(uint32_t index, uint32_t hasValue, 295 Decoration decoration, uint32_t decorationValue) 296 : memberIndex(index), hasValue(hasValue), decoration(decoration), 297 decorationValue(decorationValue) {} 298 299 bool operator==(const MemberDecorationInfo &other) const { 300 return (this->memberIndex == other.memberIndex) && 301 (this->decoration == other.decoration) && 302 (this->decorationValue == other.decorationValue); 303 } 304 305 bool operator<(const MemberDecorationInfo &other) const { 306 return this->memberIndex < other.memberIndex || 307 (this->memberIndex == other.memberIndex && 308 static_cast<uint32_t>(this->decoration) < 309 static_cast<uint32_t>(other.decoration)); 310 } 311 }; 312 313 /// Construct a literal StructType with at least one member. 314 static StructType get(ArrayRef<Type> memberTypes, 315 ArrayRef<OffsetInfo> offsetInfo = {}, 316 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 317 318 /// Construct an identified StructType. This creates a StructType whose body 319 /// (member types, offset info, and decorations) is not set yet. A call to 320 /// StructType::trySetBody(...) must follow when the StructType contents are 321 /// available (e.g. parsed or deserialized). 322 /// 323 /// Note: If another thread creates (or had already created) a struct with the 324 /// same identifier, that struct will be returned as a result. 325 static StructType getIdentified(MLIRContext *context, StringRef identifier); 326 327 /// Construct a (possibly identified) StructType with no members. 328 /// 329 /// Note: this method might fail in a multi-threaded setup if another thread 330 /// created an identified struct with the same identifier but with different 331 /// contents before returning. In which case, an empty (default-constructed) 332 /// StructType is returned. 333 static StructType getEmpty(MLIRContext *context, StringRef identifier = ""); 334 335 /// For literal structs, return an empty string. 336 /// For identified structs, return the struct's identifier. 337 StringRef getIdentifier() const; 338 339 /// Returns true if the StructType is identified. 340 bool isIdentified() const; 341 342 unsigned getNumElements() const; 343 344 Type getElementType(unsigned) const; 345 346 /// Range class for element types. 347 class ElementTypeRange 348 : public ::llvm::detail::indexed_accessor_range_base< 349 ElementTypeRange, const Type *, Type, Type, Type> { 350 private: 351 using RangeBaseT::RangeBaseT; 352 353 /// See `llvm::detail::indexed_accessor_range_base` for details. offset_base(const Type * object,ptrdiff_t index)354 static const Type *offset_base(const Type *object, ptrdiff_t index) { 355 return object + index; 356 } 357 /// See `llvm::detail::indexed_accessor_range_base` for details. dereference_iterator(const Type * object,ptrdiff_t index)358 static Type dereference_iterator(const Type *object, ptrdiff_t index) { 359 return object[index]; 360 } 361 362 /// Allow base class access to `offset_base` and `dereference_iterator`. 363 friend RangeBaseT; 364 }; 365 366 ElementTypeRange getElementTypes() const; 367 368 bool hasOffset() const; 369 370 uint64_t getMemberOffset(unsigned) const; 371 372 // Returns in `memberDecorations` the Decorations (apart from Offset) 373 // associated with all members of the StructType. 374 void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> 375 &memberDecorations) const; 376 377 // Returns in `decorationsInfo` all the Decorations (apart from Offset) 378 // associated with the `i`-th member of the StructType. 379 void getMemberDecorations( 380 unsigned i, 381 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const; 382 383 /// Sets the contents of an incomplete identified StructType. This method must 384 /// be called only for identified StructTypes and it must be called only once 385 /// per instance. Otherwise, failure() is returned. 386 LogicalResult 387 trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, 388 ArrayRef<MemberDecorationInfo> memberDecorations = {}); 389 390 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 391 Optional<StorageClass> storage = llvm::None); 392 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 393 Optional<StorageClass> storage = llvm::None); 394 }; 395 396 llvm::hash_code 397 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); 398 399 // SPIR-V cooperative matrix type 400 class CooperativeMatrixNVType 401 : public Type::TypeBase<CooperativeMatrixNVType, CompositeType, 402 detail::CooperativeMatrixTypeStorage> { 403 public: 404 using Base::Base; 405 406 static CooperativeMatrixNVType get(Type elementType, Scope scope, 407 unsigned rows, unsigned columns); 408 Type getElementType() const; 409 410 /// Return the scope of the cooperative matrix. 411 Scope getScope() const; 412 /// return the number of rows of the matrix. 413 unsigned getRows() const; 414 /// return the number of columns of the matrix. 415 unsigned getColumns() const; 416 417 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 418 Optional<StorageClass> storage = llvm::None); 419 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 420 Optional<StorageClass> storage = llvm::None); 421 }; 422 423 // SPIR-V matrix type 424 class MatrixType : public Type::TypeBase<MatrixType, CompositeType, 425 detail::MatrixTypeStorage> { 426 public: 427 using Base::Base; 428 429 static MatrixType get(Type columnType, uint32_t columnCount); 430 431 static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError, 432 Type columnType, uint32_t columnCount); 433 434 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, 435 Type columnType, uint32_t columnCount); 436 437 /// Returns true if the matrix elements are vectors of float elements. 438 static bool isValidColumnType(Type columnType); 439 440 Type getColumnType() const; 441 442 /// Returns the number of rows. 443 unsigned getNumRows() const; 444 445 /// Returns the number of columns. 446 unsigned getNumColumns() const; 447 448 /// Returns total number of elements (rows*columns). 449 unsigned getNumElements() const; 450 451 /// Returns the elements' type (i.e, single element type). 452 Type getElementType() const; 453 454 void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, 455 Optional<StorageClass> storage = llvm::None); 456 void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, 457 Optional<StorageClass> storage = llvm::None); 458 }; 459 460 } // namespace spirv 461 } // namespace mlir 462 463 #endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_ 464