1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11 
12 #include "BuiltinAttributeInterfaces.h"
13 #include "SubElementInterfaces.h"
14 
15 namespace llvm {
16 class BitVector;
17 struct fltSemantics;
18 } // namespace llvm
19 
20 //===----------------------------------------------------------------------===//
21 // Tablegen Interface Declarations
22 //===----------------------------------------------------------------------===//
23 
24 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
25 
26 namespace mlir {
27 class AffineExpr;
28 class AffineMap;
29 class FloatType;
30 class IndexType;
31 class IntegerType;
32 class StringAttr;
33 class TypeRange;
34 
35 //===----------------------------------------------------------------------===//
36 // FloatType
37 //===----------------------------------------------------------------------===//
38 
39 class FloatType : public Type {
40 public:
41   using Type::Type;
42 
43   // Convenience factories.
44   static FloatType getBF16(MLIRContext *ctx);
45   static FloatType getF16(MLIRContext *ctx);
46   static FloatType getF32(MLIRContext *ctx);
47   static FloatType getF64(MLIRContext *ctx);
48   static FloatType getF80(MLIRContext *ctx);
49   static FloatType getF128(MLIRContext *ctx);
50 
51   /// Methods for support type inquiry through isa, cast, and dyn_cast.
52   static bool classof(Type type);
53 
54   /// Return the bitwidth of this float type.
55   unsigned getWidth();
56 
57   /// Return the width of the mantissa of this type.
58   unsigned getFPMantissaWidth();
59 
60   /// Get or create a new FloatType with bitwidth scaled by `scale`.
61   /// Return null if the scaled element type cannot be represented.
62   FloatType scaleElementBitwidth(unsigned scale);
63 
64   /// Return the floating semantics of this float type.
65   const llvm::fltSemantics &getFloatSemantics();
66 };
67 
68 //===----------------------------------------------------------------------===//
69 // TensorType
70 //===----------------------------------------------------------------------===//
71 
72 /// Tensor types represent multi-dimensional arrays, and have two variants:
73 /// RankedTensorType and UnrankedTensorType.
74 /// Note: This class attaches the ShapedType trait to act as a mixin to
75 ///       provide many useful utility functions. This inheritance has no effect
76 ///       on derived tensor types.
77 class TensorType : public Type, public ShapedType::Trait<TensorType> {
78 public:
79   using Type::Type;
80 
81   /// Returns the element type of this tensor type.
82   Type getElementType() const;
83 
84   /// Returns if this type is ranked, i.e. it has a known number of dimensions.
85   bool hasRank() const;
86 
87   /// Returns the shape of this tensor type.
88   ArrayRef<int64_t> getShape() const;
89 
90   /// Clone this type with the given shape and element type. If the
91   /// provided shape is `None`, the current shape of the type is used.
92   TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
93                        Type elementType) const;
94 
95   /// Return true if the specified element type is ok in a tensor.
96   static bool isValidElementType(Type type);
97 
98   /// Methods for support type inquiry through isa, cast, and dyn_cast.
99   static bool classof(Type type);
100 
101   /// Allow implicit conversion to ShapedType.
ShapedType()102   operator ShapedType() const { return cast<ShapedType>(); }
103 };
104 
105 //===----------------------------------------------------------------------===//
106 // BaseMemRefType
107 //===----------------------------------------------------------------------===//
108 
109 /// This class provides a shared interface for ranked and unranked memref types.
110 /// Note: This class attaches the ShapedType trait to act as a mixin to
111 ///       provide many useful utility functions. This inheritance has no effect
112 ///       on derived memref types.
113 class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
114 public:
115   using Type::Type;
116 
117   /// Returns the element type of this memref type.
118   Type getElementType() const;
119 
120   /// Returns if this type is ranked, i.e. it has a known number of dimensions.
121   bool hasRank() const;
122 
123   /// Returns the shape of this memref type.
124   ArrayRef<int64_t> getShape() const;
125 
126   /// Clone this type with the given shape and element type. If the
127   /// provided shape is `None`, the current shape of the type is used.
128   BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
129                            Type elementType) const;
130 
131   /// Return true if the specified element type is ok in a memref.
132   static bool isValidElementType(Type type);
133 
134   /// Methods for support type inquiry through isa, cast, and dyn_cast.
135   static bool classof(Type type);
136 
137   /// Returns the memory space in which data referred to by this memref resides.
138   Attribute getMemorySpace() const;
139 
140   /// [deprecated] Returns the memory space in old raw integer representation.
141   /// New `Attribute getMemorySpace()` method should be used instead.
142   unsigned getMemorySpaceAsInt() const;
143 
144   /// Allow implicit conversion to ShapedType.
ShapedType()145   operator ShapedType() const { return cast<ShapedType>(); }
146 };
147 
148 } // namespace mlir
149 
150 //===----------------------------------------------------------------------===//
151 // Tablegen Type Declarations
152 //===----------------------------------------------------------------------===//
153 
154 #define GET_TYPEDEF_CLASSES
155 #include "mlir/IR/BuiltinTypes.h.inc"
156 
157 namespace mlir {
158 
159 //===----------------------------------------------------------------------===//
160 // MemRefType
161 //===----------------------------------------------------------------------===//
162 
163 /// This is a builder type that keeps local references to arguments. Arguments
164 /// that are passed into the builder must outlive the builder.
165 class MemRefType::Builder {
166 public:
167   // Build from another MemRefType.
Builder(MemRefType other)168   explicit Builder(MemRefType other)
169       : shape(other.getShape()), elementType(other.getElementType()),
170         layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
171 
172   // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)173   Builder(ArrayRef<int64_t> shape, Type elementType)
174       : shape(shape), elementType(elementType) {}
175 
setShape(ArrayRef<int64_t> newShape)176   Builder &setShape(ArrayRef<int64_t> newShape) {
177     shape = newShape;
178     return *this;
179   }
180 
setElementType(Type newElementType)181   Builder &setElementType(Type newElementType) {
182     elementType = newElementType;
183     return *this;
184   }
185 
setLayout(MemRefLayoutAttrInterface newLayout)186   Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
187     layout = newLayout;
188     return *this;
189   }
190 
setMemorySpace(Attribute newMemorySpace)191   Builder &setMemorySpace(Attribute newMemorySpace) {
192     memorySpace = newMemorySpace;
193     return *this;
194   }
195 
196   // [deprecated] `setMemorySpace(Attribute)` should be used instead.
197   Builder &setMemorySpace(unsigned newMemorySpace);
198 
MemRefType()199   operator MemRefType() {
200     return MemRefType::get(shape, elementType, layout, memorySpace);
201   }
202 
203 private:
204   ArrayRef<int64_t> shape;
205   Type elementType;
206   MemRefLayoutAttrInterface layout;
207   Attribute memorySpace;
208 };
209 
210 //===----------------------------------------------------------------------===//
211 // RankedTensorType
212 //===----------------------------------------------------------------------===//
213 
214 /// This is a builder type that keeps local references to arguments. Arguments
215 /// that are passed into the builder must outlive the builder.
216 class RankedTensorType::Builder {
217 public:
218   /// Build from another RankedTensorType.
Builder(RankedTensorType other)219   explicit Builder(RankedTensorType other)
220       : shape(other.getShape()), elementType(other.getElementType()),
221         encoding(other.getEncoding()) {}
222 
223   /// Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType,Attribute encoding)224   Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
225       : shape(shape), elementType(elementType), encoding(encoding) {}
226 
setShape(ArrayRef<int64_t> newShape)227   Builder &setShape(ArrayRef<int64_t> newShape) {
228     shape = newShape;
229     return *this;
230   }
231 
setElementType(Type newElementType)232   Builder &setElementType(Type newElementType) {
233     elementType = newElementType;
234     return *this;
235   }
236 
setEncoding(Attribute newEncoding)237   Builder &setEncoding(Attribute newEncoding) {
238     encoding = newEncoding;
239     return *this;
240   }
241 
242   /// Erase a dim from shape @pos.
dropDim(unsigned pos)243   Builder &dropDim(unsigned pos) {
244     assert(pos < shape.size() && "overflow");
245     if (storage.empty())
246       storage.append(shape.begin(), shape.end());
247     storage.erase(storage.begin() + pos);
248     shape = {storage.data(), storage.size()};
249     return *this;
250   }
251 
252   /// Insert a val into shape @pos.
insertDim(int64_t val,unsigned pos)253   Builder &insertDim(int64_t val, unsigned pos) {
254     assert(pos <= shape.size() && "overflow");
255     if (storage.empty())
256       storage.append(shape.begin(), shape.end());
257     storage.insert(storage.begin() + pos, val);
258     shape = {storage.data(), storage.size()};
259     return *this;
260   }
261 
RankedTensorType()262   operator RankedTensorType() {
263     return RankedTensorType::get(shape, elementType, encoding);
264   }
265 
266 private:
267   ArrayRef<int64_t> shape;
268   // Owning shape data for copy-on-write operations.
269   SmallVector<int64_t> storage;
270   Type elementType;
271   Attribute encoding;
272 };
273 
274 //===----------------------------------------------------------------------===//
275 // VectorType
276 //===----------------------------------------------------------------------===//
277 
278 /// This is a builder type that keeps local references to arguments. Arguments
279 /// that are passed into the builder must outlive the builder.
280 class VectorType::Builder {
281 public:
282   /// Build from another VectorType.
Builder(VectorType other)283   explicit Builder(VectorType other)
284       : shape(other.getShape()), elementType(other.getElementType()),
285         numScalableDims(other.getNumScalableDims()) {}
286 
287   /// Build from scratch.
288   Builder(ArrayRef<int64_t> shape, Type elementType,
289           unsigned numScalableDims = 0)
shape(shape)290       : shape(shape), elementType(elementType),
291         numScalableDims(numScalableDims) {}
292 
293   Builder &setShape(ArrayRef<int64_t> newShape,
294                     unsigned newNumScalableDims = 0) {
295     numScalableDims = newNumScalableDims;
296     shape = newShape;
297     return *this;
298   }
299 
setElementType(Type newElementType)300   Builder &setElementType(Type newElementType) {
301     elementType = newElementType;
302     return *this;
303   }
304 
305   /// Erase a dim from shape @pos.
dropDim(unsigned pos)306   Builder &dropDim(unsigned pos) {
307     assert(pos < shape.size() && "overflow");
308     if (pos >= shape.size() - numScalableDims)
309       numScalableDims--;
310     if (storage.empty())
311       storage.append(shape.begin(), shape.end());
312     storage.erase(storage.begin() + pos);
313     shape = {storage.data(), storage.size()};
314     return *this;
315   }
316 
317   /// In the particular case where the vector has a single dimension that we
318   /// drop, return the scalar element type.
319   // TODO: unify once we have a VectorType that supports 0-D.
Type()320   operator Type() {
321     if (shape.empty())
322       return elementType;
323     return VectorType::get(shape, elementType, numScalableDims);
324   }
325 
326 private:
327   ArrayRef<int64_t> shape;
328   // Owning shape data for copy-on-write operations.
329   SmallVector<int64_t> storage;
330   Type elementType;
331   unsigned numScalableDims;
332 };
333 
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when, e.g., computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape);
345 
/// Result codes produced when verifying the types of slice insert/extract
/// style ops (see `isRankReducedType`). `Success` means verification passed;
/// every other enumerator names the kind of mismatch detected.
enum class SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  // Error codes for ops whose types carry a memory space and a layout
  // annotation.
  MemSpaceMismatch,
  LayoutMismatch
};
357 
/// Check if `originalType` can be rank reduced to `candidateReducedType` type
/// by dropping some dimensions with static size `1`.
/// Return `SliceVerificationResult::Success` on success, or otherwise the
/// error code describing why the candidate does not match.
SliceVerificationResult isRankReducedType(ShapedType originalType,
                                          ShapedType candidateReducedType);
364 
365 //===----------------------------------------------------------------------===//
366 // Deferred Method Definitions
367 //===----------------------------------------------------------------------===//
368 
classof(Type type)369 inline bool BaseMemRefType::classof(Type type) {
370   return type.isa<MemRefType, UnrankedMemRefType>();
371 }
372 
isValidElementType(Type type)373 inline bool BaseMemRefType::isValidElementType(Type type) {
374   return type.isIntOrIndexOrFloat() ||
375          type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
376          type.isa<MemRefElementTypeInterface>();
377 }
378 
classof(Type type)379 inline bool FloatType::classof(Type type) {
380   return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
381                   Float80Type, Float128Type>();
382 }
383 
getBF16(MLIRContext * ctx)384 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
385   return BFloat16Type::get(ctx);
386 }
387 
getF16(MLIRContext * ctx)388 inline FloatType FloatType::getF16(MLIRContext *ctx) {
389   return Float16Type::get(ctx);
390 }
391 
getF32(MLIRContext * ctx)392 inline FloatType FloatType::getF32(MLIRContext *ctx) {
393   return Float32Type::get(ctx);
394 }
395 
getF64(MLIRContext * ctx)396 inline FloatType FloatType::getF64(MLIRContext *ctx) {
397   return Float64Type::get(ctx);
398 }
399 
getF80(MLIRContext * ctx)400 inline FloatType FloatType::getF80(MLIRContext *ctx) {
401   return Float80Type::get(ctx);
402 }
403 
getF128(MLIRContext * ctx)404 inline FloatType FloatType::getF128(MLIRContext *ctx) {
405   return Float128Type::get(ctx);
406 }
407 
classof(Type type)408 inline bool TensorType::classof(Type type) {
409   return type.isa<RankedTensorType, UnrankedTensorType>();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // Type Utilities
414 //===----------------------------------------------------------------------===//
415 
/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
///   1. empty or identity layout map, in which case the stride information is
///      the canonical form computed from sizes;
///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
///      where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
/// distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`, i.e. strides are `[64, 1]` and the offset is `0`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
///
/// The first overload reports strides and offset as integers (dynamic values
/// use the sentinel described above); the second reports them as AffineExprs.
/// The returned LogicalResult indicates whether the query succeeded.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<AffineExpr> &strides,
                                  AffineExpr &offset);
442 
/// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
/// represents a dynamic value), return the single result AffineMap which
/// represents the linearized strided layout map. The map has one dimension per
/// stride, in order. Symbols are introduced, in order, for a dynamic offset
/// followed by each dynamic stride. A stride cannot take value `0`.
///
/// Examples:
/// =========
///
///   1. For offset: 0 strides: ?, ?, 1 return
///         (i, j, k)[M, N]->(M * i + N * j + k)
///
///   2. For offset: 3 strides: 32, ?, 16 return
///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
///
///   3. For offset: ? strides: ?, ?, ? return
///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                     MLIRContext *context);
462 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with the simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Return a version of `t` with a layout in which the offset and all strides
/// are dynamic. This is used to erase the static layout information.
MemRefType eraseStridedLayout(MemRefType t);
472 
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides of the outer
/// dimensions become dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
///
/// Examples:
///   - memref<3x4x5xf32> has canonical stride expression
///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x?x5xf32> has canonical stride expression
///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x4x?xf32> has canonical stride expression
///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);

/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);
495 
/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);

/// Return the layout map of `t` in strided linear layout AffineMap form.
/// Return a null AffineMap if the layout is not compatible with a strided
/// layout.
AffineMap getStridedLinearLayoutMap(MemRefType t);

/// Helper determining if a memref has a static shape and a contiguous
/// row-major layout, while still allowing for an arbitrary offset (any static
/// or dynamic value).
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
507 
508 } // namespace mlir
509 
510 #endif // MLIR_IR_BUILTINTYPES_H
511