1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11
12 #include "BuiltinAttributeInterfaces.h"
13 #include "SubElementInterfaces.h"
14
15 namespace llvm {
16 class BitVector;
17 struct fltSemantics;
18 } // namespace llvm
19
20 //===----------------------------------------------------------------------===//
21 // Tablegen Interface Declarations
22 //===----------------------------------------------------------------------===//
23
24 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
25
26 namespace mlir {
27 class AffineExpr;
28 class AffineMap;
29 class FloatType;
30 class IndexType;
31 class IntegerType;
32 class StringAttr;
33 class TypeRange;
34
35 //===----------------------------------------------------------------------===//
36 // FloatType
37 //===----------------------------------------------------------------------===//
38
39 class FloatType : public Type {
40 public:
41 using Type::Type;
42
43 // Convenience factories.
44 static FloatType getBF16(MLIRContext *ctx);
45 static FloatType getF16(MLIRContext *ctx);
46 static FloatType getF32(MLIRContext *ctx);
47 static FloatType getF64(MLIRContext *ctx);
48 static FloatType getF80(MLIRContext *ctx);
49 static FloatType getF128(MLIRContext *ctx);
50
51 /// Methods for support type inquiry through isa, cast, and dyn_cast.
52 static bool classof(Type type);
53
54 /// Return the bitwidth of this float type.
55 unsigned getWidth();
56
57 /// Return the width of the mantissa of this type.
58 unsigned getFPMantissaWidth();
59
60 /// Get or create a new FloatType with bitwidth scaled by `scale`.
61 /// Return null if the scaled element type cannot be represented.
62 FloatType scaleElementBitwidth(unsigned scale);
63
64 /// Return the floating semantics of this float type.
65 const llvm::fltSemantics &getFloatSemantics();
66 };
67
68 //===----------------------------------------------------------------------===//
69 // TensorType
70 //===----------------------------------------------------------------------===//
71
72 /// Tensor types represent multi-dimensional arrays, and have two variants:
73 /// RankedTensorType and UnrankedTensorType.
74 /// Note: This class attaches the ShapedType trait to act as a mixin to
75 /// provide many useful utility functions. This inheritance has no effect
76 /// on derived tensor types.
77 class TensorType : public Type, public ShapedType::Trait<TensorType> {
78 public:
79 using Type::Type;
80
81 /// Returns the element type of this tensor type.
82 Type getElementType() const;
83
84 /// Returns if this type is ranked, i.e. it has a known number of dimensions.
85 bool hasRank() const;
86
87 /// Returns the shape of this tensor type.
88 ArrayRef<int64_t> getShape() const;
89
90 /// Clone this type with the given shape and element type. If the
91 /// provided shape is `None`, the current shape of the type is used.
92 TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
93 Type elementType) const;
94
95 /// Return true if the specified element type is ok in a tensor.
96 static bool isValidElementType(Type type);
97
98 /// Methods for support type inquiry through isa, cast, and dyn_cast.
99 static bool classof(Type type);
100
101 /// Allow implicit conversion to ShapedType.
ShapedType()102 operator ShapedType() const { return cast<ShapedType>(); }
103 };
104
105 //===----------------------------------------------------------------------===//
106 // BaseMemRefType
107 //===----------------------------------------------------------------------===//
108
109 /// This class provides a shared interface for ranked and unranked memref types.
110 /// Note: This class attaches the ShapedType trait to act as a mixin to
111 /// provide many useful utility functions. This inheritance has no effect
112 /// on derived memref types.
113 class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
114 public:
115 using Type::Type;
116
117 /// Returns the element type of this memref type.
118 Type getElementType() const;
119
120 /// Returns if this type is ranked, i.e. it has a known number of dimensions.
121 bool hasRank() const;
122
123 /// Returns the shape of this memref type.
124 ArrayRef<int64_t> getShape() const;
125
126 /// Clone this type with the given shape and element type. If the
127 /// provided shape is `None`, the current shape of the type is used.
128 BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
129 Type elementType) const;
130
131 /// Return true if the specified element type is ok in a memref.
132 static bool isValidElementType(Type type);
133
134 /// Methods for support type inquiry through isa, cast, and dyn_cast.
135 static bool classof(Type type);
136
137 /// Returns the memory space in which data referred to by this memref resides.
138 Attribute getMemorySpace() const;
139
140 /// [deprecated] Returns the memory space in old raw integer representation.
141 /// New `Attribute getMemorySpace()` method should be used instead.
142 unsigned getMemorySpaceAsInt() const;
143
144 /// Allow implicit conversion to ShapedType.
ShapedType()145 operator ShapedType() const { return cast<ShapedType>(); }
146 };
147
148 } // namespace mlir
149
150 //===----------------------------------------------------------------------===//
151 // Tablegen Type Declarations
152 //===----------------------------------------------------------------------===//
153
154 #define GET_TYPEDEF_CLASSES
155 #include "mlir/IR/BuiltinTypes.h.inc"
156
157 namespace mlir {
158
159 //===----------------------------------------------------------------------===//
160 // MemRefType
161 //===----------------------------------------------------------------------===//
162
163 /// This is a builder type that keeps local references to arguments. Arguments
164 /// that are passed into the builder must outlive the builder.
165 class MemRefType::Builder {
166 public:
167 // Build from another MemRefType.
Builder(MemRefType other)168 explicit Builder(MemRefType other)
169 : shape(other.getShape()), elementType(other.getElementType()),
170 layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}
171
172 // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)173 Builder(ArrayRef<int64_t> shape, Type elementType)
174 : shape(shape), elementType(elementType) {}
175
setShape(ArrayRef<int64_t> newShape)176 Builder &setShape(ArrayRef<int64_t> newShape) {
177 shape = newShape;
178 return *this;
179 }
180
setElementType(Type newElementType)181 Builder &setElementType(Type newElementType) {
182 elementType = newElementType;
183 return *this;
184 }
185
setLayout(MemRefLayoutAttrInterface newLayout)186 Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
187 layout = newLayout;
188 return *this;
189 }
190
setMemorySpace(Attribute newMemorySpace)191 Builder &setMemorySpace(Attribute newMemorySpace) {
192 memorySpace = newMemorySpace;
193 return *this;
194 }
195
196 // [deprecated] `setMemorySpace(Attribute)` should be used instead.
197 Builder &setMemorySpace(unsigned newMemorySpace);
198
MemRefType()199 operator MemRefType() {
200 return MemRefType::get(shape, elementType, layout, memorySpace);
201 }
202
203 private:
204 ArrayRef<int64_t> shape;
205 Type elementType;
206 MemRefLayoutAttrInterface layout;
207 Attribute memorySpace;
208 };
209
210 //===----------------------------------------------------------------------===//
211 // RankedTensorType
212 //===----------------------------------------------------------------------===//
213
214 /// This is a builder type that keeps local references to arguments. Arguments
215 /// that are passed into the builder must outlive the builder.
216 class RankedTensorType::Builder {
217 public:
218 /// Build from another RankedTensorType.
Builder(RankedTensorType other)219 explicit Builder(RankedTensorType other)
220 : shape(other.getShape()), elementType(other.getElementType()),
221 encoding(other.getEncoding()) {}
222
223 /// Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType,Attribute encoding)224 Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
225 : shape(shape), elementType(elementType), encoding(encoding) {}
226
setShape(ArrayRef<int64_t> newShape)227 Builder &setShape(ArrayRef<int64_t> newShape) {
228 shape = newShape;
229 return *this;
230 }
231
setElementType(Type newElementType)232 Builder &setElementType(Type newElementType) {
233 elementType = newElementType;
234 return *this;
235 }
236
setEncoding(Attribute newEncoding)237 Builder &setEncoding(Attribute newEncoding) {
238 encoding = newEncoding;
239 return *this;
240 }
241
242 /// Erase a dim from shape @pos.
dropDim(unsigned pos)243 Builder &dropDim(unsigned pos) {
244 assert(pos < shape.size() && "overflow");
245 if (storage.empty())
246 storage.append(shape.begin(), shape.end());
247 storage.erase(storage.begin() + pos);
248 shape = {storage.data(), storage.size()};
249 return *this;
250 }
251
252 /// Insert a val into shape @pos.
insertDim(int64_t val,unsigned pos)253 Builder &insertDim(int64_t val, unsigned pos) {
254 assert(pos <= shape.size() && "overflow");
255 if (storage.empty())
256 storage.append(shape.begin(), shape.end());
257 storage.insert(storage.begin() + pos, val);
258 shape = {storage.data(), storage.size()};
259 return *this;
260 }
261
RankedTensorType()262 operator RankedTensorType() {
263 return RankedTensorType::get(shape, elementType, encoding);
264 }
265
266 private:
267 ArrayRef<int64_t> shape;
268 // Owning shape data for copy-on-write operations.
269 SmallVector<int64_t> storage;
270 Type elementType;
271 Attribute encoding;
272 };
273
274 //===----------------------------------------------------------------------===//
275 // VectorType
276 //===----------------------------------------------------------------------===//
277
278 /// This is a builder type that keeps local references to arguments. Arguments
279 /// that are passed into the builder must outlive the builder.
280 class VectorType::Builder {
281 public:
282 /// Build from another VectorType.
Builder(VectorType other)283 explicit Builder(VectorType other)
284 : shape(other.getShape()), elementType(other.getElementType()),
285 numScalableDims(other.getNumScalableDims()) {}
286
287 /// Build from scratch.
288 Builder(ArrayRef<int64_t> shape, Type elementType,
289 unsigned numScalableDims = 0)
shape(shape)290 : shape(shape), elementType(elementType),
291 numScalableDims(numScalableDims) {}
292
293 Builder &setShape(ArrayRef<int64_t> newShape,
294 unsigned newNumScalableDims = 0) {
295 numScalableDims = newNumScalableDims;
296 shape = newShape;
297 return *this;
298 }
299
setElementType(Type newElementType)300 Builder &setElementType(Type newElementType) {
301 elementType = newElementType;
302 return *this;
303 }
304
305 /// Erase a dim from shape @pos.
dropDim(unsigned pos)306 Builder &dropDim(unsigned pos) {
307 assert(pos < shape.size() && "overflow");
308 if (pos >= shape.size() - numScalableDims)
309 numScalableDims--;
310 if (storage.empty())
311 storage.append(shape.begin(), shape.end());
312 storage.erase(storage.begin() + pos);
313 shape = {storage.data(), storage.size()};
314 return *this;
315 }
316
317 /// In the particular case where the vector has a single dimension that we
318 /// drop, return the scalar element type.
319 // TODO: unify once we have a VectorType that supports 0-D.
Type()320 operator Type() {
321 if (shape.empty())
322 return elementType;
323 return VectorType::get(shape, elementType, numScalableDims);
324 }
325
326 private:
327 ArrayRef<int64_t> shape;
328 // Owning shape data for copy-on-write operations.
329 SmallVector<int64_t> storage;
330 Type elementType;
331 unsigned numScalableDims;
332 };
333
334 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
335 /// `originalShape` with some `1` entries erased, return the set of indices
336 /// that specifies which of the entries of `originalShape` are dropped to obtain
337 /// `reducedShape`. The returned mask can be applied as a projection to
338 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
339 /// which dimensions must be kept when e.g. compute MemRef strides under
340 /// rank-reducing operations. Return None if reducedShape cannot be obtained
341 /// by dropping only `1` entries in `originalShape`.
342 llvm::Optional<llvm::SmallDenseSet<unsigned>>
343 computeRankReductionMask(ArrayRef<int64_t> originalShape,
344 ArrayRef<int64_t> reducedShape);
345
/// Outcome codes reported by the verifiers of slice insert/extract style ops.
/// `Success` means the check passed; every other enumerator is an error code.
enum class SliceVerificationResult {
  Success = 0,
  RankTooLarge = 1,
  SizeMismatch = 2,
  ElemTypeMismatch = 3,
  // Error codes for ops with a memory space and a layout annotation.
  MemSpaceMismatch = 4,
  LayoutMismatch = 5,
};
357
358 /// Check if `originalType` can be rank reduced to `candidateReducedType` type
359 /// by dropping some dimensions with static size `1`.
360 /// Return `SliceVerificationResult::Success` on success or an appropriate error
361 /// code.
362 SliceVerificationResult isRankReducedType(ShapedType originalType,
363 ShapedType candidateReducedType);
364
365 //===----------------------------------------------------------------------===//
366 // Deferred Method Definitions
367 //===----------------------------------------------------------------------===//
368
classof(Type type)369 inline bool BaseMemRefType::classof(Type type) {
370 return type.isa<MemRefType, UnrankedMemRefType>();
371 }
372
isValidElementType(Type type)373 inline bool BaseMemRefType::isValidElementType(Type type) {
374 return type.isIntOrIndexOrFloat() ||
375 type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
376 type.isa<MemRefElementTypeInterface>();
377 }
378
classof(Type type)379 inline bool FloatType::classof(Type type) {
380 return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
381 Float80Type, Float128Type>();
382 }
383
getBF16(MLIRContext * ctx)384 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
385 return BFloat16Type::get(ctx);
386 }
387
getF16(MLIRContext * ctx)388 inline FloatType FloatType::getF16(MLIRContext *ctx) {
389 return Float16Type::get(ctx);
390 }
391
getF32(MLIRContext * ctx)392 inline FloatType FloatType::getF32(MLIRContext *ctx) {
393 return Float32Type::get(ctx);
394 }
395
getF64(MLIRContext * ctx)396 inline FloatType FloatType::getF64(MLIRContext *ctx) {
397 return Float64Type::get(ctx);
398 }
399
getF80(MLIRContext * ctx)400 inline FloatType FloatType::getF80(MLIRContext *ctx) {
401 return Float80Type::get(ctx);
402 }
403
getF128(MLIRContext * ctx)404 inline FloatType FloatType::getF128(MLIRContext *ctx) {
405 return Float128Type::get(ctx);
406 }
407
classof(Type type)408 inline bool TensorType::classof(Type type) {
409 return type.isa<RankedTensorType, UnrankedTensorType>();
410 }
411
412 //===----------------------------------------------------------------------===//
413 // Type Utilities
414 //===----------------------------------------------------------------------===//
415
416 /// Returns the strides of the MemRef if the layout map is in strided form.
417 /// MemRefs with a layout map in strided form include:
418 /// 1. empty or identity layout map, in which case the stride information is
419 /// the canonical form computed from sizes;
420 /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
421 /// where K and ki's are constants or symbols.
422 ///
423 /// A stride specification is a list of integer values that are either static
424 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
425 /// distance in the number of elements between successive entries along a
426 /// particular dimension.
427 ///
428 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
429 /// non-contiguous memory region of `42` by `16` `f32` elements in which the
430 /// distance between two consecutive elements along the outer dimension is `1`
431 /// and the distance between two consecutive elements along the inner dimension
432 /// is `64`.
433 ///
434 /// The convention is that the strides for dimensions d0, .. dn appear in
435 /// order to make indexing intuitive into the result.
436 LogicalResult getStridesAndOffset(MemRefType t,
437 SmallVectorImpl<int64_t> &strides,
438 int64_t &offset);
439 LogicalResult getStridesAndOffset(MemRefType t,
440 SmallVectorImpl<AffineExpr> &strides,
441 AffineExpr &offset);
442
443 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
444 /// represents a dynamic value), return the single result AffineMap which
445 /// represents the linearized strided layout map. Dimensions correspond to the
446 /// offset followed by the strides in order. Symbols are inserted for each
447 /// dynamic dimension in order. A stride cannot take value `0`.
448 ///
449 /// Examples:
450 /// =========
451 ///
452 /// 1. For offset: 0 strides: ?, ?, 1 return
453 /// (i, j, k)[M, N]->(M * i + N * j + k)
454 ///
455 /// 2. For offset: 3 strides: 32, ?, 16 return
456 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
457 ///
458 /// 3. For offset: ? strides: ?, ?, ? return
459 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
460 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
461 MLIRContext *context);
462
463 /// Return a version of `t` with identity layout if it can be determined
464 /// statically that the layout is the canonical contiguous strided layout.
465 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
466 /// `t` with simplified layout.
467 MemRefType canonicalizeStridedLayout(MemRefType t);
468
469 /// Return a version of `t` with a layout that has all dynamic offset and
470 /// strides. This is used to erase the static layout.
471 MemRefType eraseStridedLayout(MemRefType t);
472
473 /// Given MemRef `sizes` that are either static or dynamic, returns the
474 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
475 /// once a dynamic dimension is encountered, all canonical strides become
476 /// dynamic and need to be encoded with a different symbol.
477 /// For canonical strides expressions, the offset is always 0 and and fastest
478 /// varying stride is always `1`.
479 ///
480 /// Examples:
481 /// - memref<3x4x5xf32> has canonical stride expression
482 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
483 /// - memref<3x?x5xf32> has canonical stride expression
484 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
485 /// - memref<3x4x?xf32> has canonical stride expression
486 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
487 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
488 ArrayRef<AffineExpr> exprs,
489 MLIRContext *context);
490
491 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
492 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
493 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
494 MLIRContext *context);
495
496 /// Return true if the layout for `t` is compatible with strided semantics.
497 bool isStrided(MemRefType t);
498
499 /// Return the layout map in strided linear layout AffineMap form.
500 /// Return null if the layout is not compatible with a strided layout.
501 AffineMap getStridedLinearLayoutMap(MemRefType t);
502
503 /// Helper determining if a memref is static-shape and contiguous-row-major
504 /// layout, while still allowing for an arbitrary offset (any static or
505 /// dynamic value).
506 bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
507
508 } // namespace mlir
509
510 #endif // MLIR_IR_BUILTINTYPES_H
511