1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
// Tablegen Type Definitions
26 //===----------------------------------------------------------------------===//
27 
28 #define GET_TYPEDEF_CLASSES
29 #include "mlir/IR/BuiltinTypes.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
// ComplexType
33 //===----------------------------------------------------------------------===//
34 
/// Verify the construction of a complex type.
36 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
37                                   Type elementType) {
38   if (!elementType.isIntOrFloat())
39     return emitError() << "invalid element type for complex";
40   return success();
41 }
42 
43 //===----------------------------------------------------------------------===//
// IntegerType
45 //===----------------------------------------------------------------------===//
46 
// A static constexpr member must have an out-of-line definition (until C++17,
// where inline variables make it unnecessary).
48 constexpr unsigned IntegerType::kMaxWidth;
49 
50 /// Verify the construction of an integer type.
51 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
52                                   unsigned width,
53                                   SignednessSemantics signedness) {
54   if (width > IntegerType::kMaxWidth) {
55     return emitError() << "integer bitwidth is limited to "
56                        << IntegerType::kMaxWidth << " bits";
57   }
58   return success();
59 }
60 
61 unsigned IntegerType::getWidth() const { return getImpl()->width; }
62 
63 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
64   return getImpl()->signedness;
65 }
66 
67 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
68   if (!scale)
69     return IntegerType();
70   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
71 }
72 
73 //===----------------------------------------------------------------------===//
// FloatType
75 //===----------------------------------------------------------------------===//
76 
77 unsigned FloatType::getWidth() {
78   if (isa<Float16Type, BFloat16Type>())
79     return 16;
80   if (isa<Float32Type>())
81     return 32;
82   if (isa<Float64Type>())
83     return 64;
84   if (isa<Float80Type>())
85     return 80;
86   if (isa<Float128Type>())
87     return 128;
88   llvm_unreachable("unexpected float type");
89 }
90 
91 /// Returns the floating semantics for the given type.
92 const llvm::fltSemantics &FloatType::getFloatSemantics() {
93   if (isa<BFloat16Type>())
94     return APFloat::BFloat();
95   if (isa<Float16Type>())
96     return APFloat::IEEEhalf();
97   if (isa<Float32Type>())
98     return APFloat::IEEEsingle();
99   if (isa<Float64Type>())
100     return APFloat::IEEEdouble();
101   if (isa<Float80Type>())
102     return APFloat::x87DoubleExtended();
103   if (isa<Float128Type>())
104     return APFloat::IEEEquad();
105   llvm_unreachable("non-floating point type used");
106 }
107 
108 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
109   if (!scale)
110     return FloatType();
111   MLIRContext *ctx = getContext();
112   if (isF16() || isBF16()) {
113     if (scale == 2)
114       return FloatType::getF32(ctx);
115     if (scale == 4)
116       return FloatType::getF64(ctx);
117   }
118   if (isF32())
119     if (scale == 2)
120       return FloatType::getF64(ctx);
121   return FloatType();
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // FunctionType
126 //===----------------------------------------------------------------------===//
127 
128 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
129 
130 ArrayRef<Type> FunctionType::getInputs() const {
131   return getImpl()->getInputs();
132 }
133 
134 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
135 
136 ArrayRef<Type> FunctionType::getResults() const {
137   return getImpl()->getResults();
138 }
139 
140 /// Helper to call a callback once on each index in the range
141 /// [0, `totalIndices`), *except* for the indices given in `indices`.
142 /// `indices` is allowed to have duplicates and can be in any order.
143 inline void iterateIndicesExcept(unsigned totalIndices,
144                                  ArrayRef<unsigned> indices,
145                                  function_ref<void(unsigned)> callback) {
146   llvm::BitVector skipIndices(totalIndices);
147   for (unsigned i : indices)
148     skipIndices.set(i);
149 
150   for (unsigned i = 0; i < totalIndices; ++i)
151     if (!skipIndices.test(i))
152       callback(i);
153 }
154 
155 /// Returns a new function type without the specified arguments and results.
156 FunctionType
157 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
158                                        ArrayRef<unsigned> resultIndices) {
159   ArrayRef<Type> newInputTypes = getInputs();
160   SmallVector<Type, 4> newInputTypesBuffer;
161   if (!argIndices.empty()) {
162     unsigned originalNumArgs = getNumInputs();
163     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
164       newInputTypesBuffer.emplace_back(getInput(i));
165     });
166     newInputTypes = newInputTypesBuffer;
167   }
168 
169   ArrayRef<Type> newResultTypes = getResults();
170   SmallVector<Type, 4> newResultTypesBuffer;
171   if (!resultIndices.empty()) {
172     unsigned originalNumResults = getNumResults();
173     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
174       newResultTypesBuffer.emplace_back(getResult(i));
175     });
176     newResultTypes = newResultTypesBuffer;
177   }
178 
179   return get(getContext(), newInputTypes, newResultTypes);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // OpaqueType
184 //===----------------------------------------------------------------------===//
185 
186 /// Verify the construction of an opaque type.
187 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
188                                  Identifier dialect, StringRef typeData) {
189   if (!Dialect::isValidNamespace(dialect.strref()))
190     return emitError() << "invalid dialect namespace '" << dialect << "'";
191   return success();
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // ShapedType
196 //===----------------------------------------------------------------------===//
197 constexpr int64_t ShapedType::kDynamicSize;
198 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
199 
200 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
201   if (auto other = dyn_cast<MemRefType>()) {
202     MemRefType::Builder b(other);
203     b.setShape(shape);
204     b.setElementType(elementType);
205     return b;
206   }
207 
208   if (auto other = dyn_cast<UnrankedMemRefType>()) {
209     MemRefType::Builder b(shape, elementType);
210     b.setMemorySpace(other.getMemorySpaceAsInt());
211     return b;
212   }
213 
214   if (isa<TensorType>())
215     return RankedTensorType::get(shape, elementType);
216 
217   if (isa<VectorType>())
218     return VectorType::get(shape, elementType);
219 
220   llvm_unreachable("Unhandled ShapedType clone case");
221 }
222 
223 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
224   if (auto other = dyn_cast<MemRefType>()) {
225     MemRefType::Builder b(other);
226     b.setShape(shape);
227     return b;
228   }
229 
230   if (auto other = dyn_cast<UnrankedMemRefType>()) {
231     MemRefType::Builder b(shape, other.getElementType());
233     b.setMemorySpace(other.getMemorySpaceAsInt());
234     return b;
235   }
236 
237   if (isa<TensorType>())
238     return RankedTensorType::get(shape, getElementType());
239 
240   if (isa<VectorType>())
241     return VectorType::get(shape, getElementType());
242 
243   llvm_unreachable("Unhandled ShapedType clone case");
244 }
245 
246 ShapedType ShapedType::clone(Type elementType) {
247   if (auto other = dyn_cast<MemRefType>()) {
248     MemRefType::Builder b(other);
249     b.setElementType(elementType);
250     return b;
251   }
252 
253   if (auto other = dyn_cast<UnrankedMemRefType>()) {
254     return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt());
255   }
256 
257   if (isa<TensorType>()) {
258     if (hasRank())
259       return RankedTensorType::get(getShape(), elementType);
260     return UnrankedTensorType::get(elementType);
261   }
262 
263   if (isa<VectorType>())
264     return VectorType::get(getShape(), elementType);
265 
266   llvm_unreachable("Unhandled ShapedType clone hit");
267 }
268 
269 Type ShapedType::getElementType() const {
270   return TypeSwitch<Type, Type>(*this)
271       .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
272             UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
273 }
274 
275 unsigned ShapedType::getElementTypeBitWidth() const {
276   return getElementType().getIntOrFloatBitWidth();
277 }
278 
279 int64_t ShapedType::getNumElements() const {
280   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
281   auto shape = getShape();
282   int64_t num = 1;
283   for (auto dim : shape) {
284     num *= dim;
285     assert(num >= 0 && "integer overflow in element count computation");
286   }
287   return num;
288 }
289 
290 int64_t ShapedType::getRank() const {
291   assert(hasRank() && "cannot query rank of unranked shaped type");
292   return getShape().size();
293 }
294 
295 bool ShapedType::hasRank() const {
296   return !isa<UnrankedMemRefType, UnrankedTensorType>();
297 }
298 
299 int64_t ShapedType::getDimSize(unsigned idx) const {
300   assert(idx < getRank() && "invalid index for shaped type");
301   return getShape()[idx];
302 }
303 
304 bool ShapedType::isDynamicDim(unsigned idx) const {
305   assert(idx < getRank() && "invalid index for shaped type");
306   return isDynamic(getShape()[idx]);
307 }
308 
309 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
310   assert(index < getRank() && "invalid index");
311   assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
312   return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
313 }
314 
/// Get the number of bits required to store a value of the given shaped
/// type. Compute the value recursively since tensors are allowed to have
/// vectors as elements.
318 int64_t ShapedType::getSizeInBits() const {
319   assert(hasStaticShape() &&
320          "cannot get the bit size of an aggregate with a dynamic shape");
321 
322   auto elementType = getElementType();
323   if (elementType.isIntOrFloat())
324     return elementType.getIntOrFloatBitWidth() * getNumElements();
325 
326   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
327     elementType = complexType.getElementType();
328     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
329   }
330 
331   // Tensors can have vectors and other tensors as elements, other shaped types
332   // cannot.
333   assert(isa<TensorType>() && "unsupported element type");
334   assert((elementType.isa<VectorType, TensorType>()) &&
335          "unsupported tensor element type");
336   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
337 }
338 
339 ArrayRef<int64_t> ShapedType::getShape() const {
340   if (auto vectorType = dyn_cast<VectorType>())
341     return vectorType.getShape();
342   if (auto tensorType = dyn_cast<RankedTensorType>())
343     return tensorType.getShape();
344   return cast<MemRefType>().getShape();
345 }
346 
347 int64_t ShapedType::getNumDynamicDims() const {
348   return llvm::count_if(getShape(), isDynamic);
349 }
350 
351 bool ShapedType::hasStaticShape() const {
352   return hasRank() && llvm::none_of(getShape(), isDynamic);
353 }
354 
355 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
356   return hasStaticShape() && getShape() == shape;
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // VectorType
361 //===----------------------------------------------------------------------===//
362 
363 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
364                                  ArrayRef<int64_t> shape, Type elementType) {
365   if (shape.empty())
366     return emitError() << "vector types must have at least one dimension";
367 
368   if (!isValidElementType(elementType))
369     return emitError() << "vector elements must be int or float type";
370 
371   if (any_of(shape, [](int64_t i) { return i <= 0; }))
372     return emitError() << "vector types must have positive constant sizes";
373 
374   return success();
375 }
376 
377 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
378   if (!scale)
379     return VectorType();
380   if (auto et = getElementType().dyn_cast<IntegerType>())
381     if (auto scaledEt = et.scaleElementBitwidth(scale))
382       return VectorType::get(getShape(), scaledEt);
383   if (auto et = getElementType().dyn_cast<FloatType>())
384     if (auto scaledEt = et.scaleElementBitwidth(scale))
385       return VectorType::get(getShape(), scaledEt);
386   return VectorType();
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // TensorType
391 //===----------------------------------------------------------------------===//
392 
393 // Check if "elementType" can be an element type of a tensor.
394 static LogicalResult
395 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
396                        Type elementType) {
397   if (!TensorType::isValidElementType(elementType))
398     return emitError() << "invalid tensor element type: " << elementType;
399   return success();
400 }
401 
/// Return true if the specified element type is valid for use in a tensor.
403 bool TensorType::isValidElementType(Type type) {
404   // Note: Non standard/builtin types are allowed to exist within tensor
405   // types. Dialects are expected to verify that tensor types have a valid
406   // element type within that dialect.
407   return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
408                   IndexType>() ||
409          !type.getDialect().getNamespace().empty();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // RankedTensorType
414 //===----------------------------------------------------------------------===//
415 
416 LogicalResult
417 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
418                          ArrayRef<int64_t> shape, Type elementType) {
419   for (int64_t s : shape)
420     if (s < -1)
421       return emitError() << "invalid tensor dimension size";
422   return checkTensorElementType(emitError, elementType);
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // UnrankedTensorType
427 //===----------------------------------------------------------------------===//
428 
429 LogicalResult
430 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
431                            Type elementType) {
432   return checkTensorElementType(emitError, elementType);
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // BaseMemRefType
437 //===----------------------------------------------------------------------===//
438 
439 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
440   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
441     return rankedMemRefTy.getMemorySpaceAsInt();
442   return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // MemRefType
447 //===----------------------------------------------------------------------===//
448 
449 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
450                                  ArrayRef<int64_t> shape, Type elementType,
451                                  ArrayRef<AffineMap> affineMapComposition,
452                                  unsigned memorySpace) {
453   if (!BaseMemRefType::isValidElementType(elementType))
454     return emitError() << "invalid memref element type";
455 
  // Negative sizes are not allowed except for `-1`, which means dynamic size.
457   for (int64_t s : shape)
458     if (s < -1)
459       return emitError() << "invalid memref size";
460 
461   // Check that the structure of the composition is valid, i.e. that each
462   // subsequent affine map has as many inputs as the previous map has results.
463   // Take the dimensionality of the MemRef for the first map.
464   size_t dim = shape.size();
465   for (auto it : llvm::enumerate(affineMapComposition)) {
466     AffineMap map = it.value();
467     if (map.getNumDims() == dim) {
468       dim = map.getNumResults();
469       continue;
470     }
471     return emitError() << "memref affine map dimension mismatch between "
472                        << (it.index() == 0 ? Twine("memref rank")
473                                            : "affine map " + Twine(it.index()))
474                        << " and affine map" << it.index() + 1 << ": " << dim
475                        << " != " << map.getNumDims();
476   }
477   return success();
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // UnrankedMemRefType
482 //===----------------------------------------------------------------------===//
483 
484 LogicalResult
485 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
486                            Type elementType, unsigned memorySpace) {
487   if (!BaseMemRefType::isValidElementType(elementType))
488     return emitError() << "invalid memref element type";
489   return success();
490 }
491 
// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the expression's contribution into the
// existing `strides` entry or `offset`.
494 static void extractStridesFromTerm(AffineExpr e,
495                                    AffineExpr multiplicativeFactor,
496                                    MutableArrayRef<AffineExpr> strides,
497                                    AffineExpr &offset) {
498   if (auto dim = e.dyn_cast<AffineDimExpr>())
499     strides[dim.getPosition()] =
500         strides[dim.getPosition()] + multiplicativeFactor;
501   else
502     offset = offset + e * multiplicativeFactor;
503 }
504 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expression for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// that order, so that indexing into the result is intuitive.
509 static LogicalResult extractStrides(AffineExpr e,
510                                     AffineExpr multiplicativeFactor,
511                                     MutableArrayRef<AffineExpr> strides,
512                                     AffineExpr &offset) {
513   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
514   if (!bin) {
515     extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
516     return success();
517   }
518 
519   if (bin.getKind() == AffineExprKind::CeilDiv ||
520       bin.getKind() == AffineExprKind::FloorDiv ||
521       bin.getKind() == AffineExprKind::Mod)
522     return failure();
523 
524   if (bin.getKind() == AffineExprKind::Mul) {
525     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
526     if (dim) {
527       strides[dim.getPosition()] =
528           strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
529       return success();
530     }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one side may contain a `dim`; otherwise the expression would not
    // be affine in the first place.
535     if (bin.getLHS().isSymbolicOrConstant())
536       return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
537                             strides, offset);
538     return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
539                           strides, offset);
540   }
541 
542   if (bin.getKind() == AffineExprKind::Add) {
543     auto res1 =
544         extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
545     auto res2 =
546         extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
547     return success(succeeded(res1) && succeeded(res2));
548   }
549 
550   llvm_unreachable("unexpected binary operation");
551 }
552 
553 LogicalResult mlir::getStridesAndOffset(MemRefType t,
554                                         SmallVectorImpl<AffineExpr> &strides,
555                                         AffineExpr &offset) {
556   auto affineMaps = t.getAffineMaps();
557   // For now strides are only computed on a single affine map with a single
558   // result (i.e. the closed subset of linearization maps that are compatible
559   // with striding semantics).
560   // TODO: support more forms on a per-need basis.
561   if (affineMaps.size() > 1)
562     return failure();
563   if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
564     return failure();
565 
566   auto zero = getAffineConstantExpr(0, t.getContext());
567   auto one = getAffineConstantExpr(1, t.getContext());
568   offset = zero;
569   strides.assign(t.getRank(), zero);
570 
571   AffineMap m;
572   if (!affineMaps.empty()) {
573     m = affineMaps.front();
574     assert(!m.isIdentity() && "unexpected identity map");
575   }
576 
577   // Canonical case for empty map.
578   if (!m) {
579     // 0-D corner case, offset is already 0.
580     if (t.getRank() == 0)
581       return success();
582     auto stridedExpr =
583         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
584     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
585       return success();
586     assert(false && "unexpected failure: extract strides in canonical layout");
587   }
588 
589   // Non-canonical case requires more work.
590   auto stridedExpr =
591       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
592   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
593     offset = AffineExpr();
594     strides.clear();
595     return failure();
596   }
597 
598   // Simplify results to allow folding to constants and simple checks.
599   unsigned numDims = m.getNumDims();
600   unsigned numSymbols = m.getNumSymbols();
601   offset = simplifyAffineExpr(offset, numDims, numSymbols);
602   for (auto &stride : strides)
603     stride = simplifyAffineExpr(stride, numDims, numSymbols);
604 
  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
611   if (llvm::any_of(strides, [](AffineExpr e) {
612         return e == getAffineConstantExpr(0, e.getContext());
613       })) {
614     offset = AffineExpr();
615     strides.clear();
616     return failure();
617   }
618 
619   return success();
620 }
621 
622 LogicalResult mlir::getStridesAndOffset(MemRefType t,
623                                         SmallVectorImpl<int64_t> &strides,
624                                         int64_t &offset) {
625   AffineExpr offsetExpr;
626   SmallVector<AffineExpr, 4> strideExprs;
627   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
628     return failure();
629   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
630     offset = cst.getValue();
631   else
632     offset = ShapedType::kDynamicStrideOrOffset;
633   for (auto e : strideExprs) {
634     if (auto c = e.dyn_cast<AffineConstantExpr>())
635       strides.push_back(c.getValue());
636     else
637       strides.push_back(ShapedType::kDynamicStrideOrOffset);
638   }
639   return success();
640 }
641 
642 //===----------------------------------------------------------------------===//
// TupleType
644 //===----------------------------------------------------------------------===//
645 
/// Return the element types for this tuple.
647 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
648 
/// Accumulate the types contained in this tuple and in tuples nested within
/// it. Note that this only flattens nested tuples, not any other container
/// type, e.g. tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
653 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
654   for (Type type : getTypes()) {
655     if (auto nestedTuple = type.dyn_cast<TupleType>())
656       nestedTuple.getFlattenedTypes(types);
657     else
658       types.push_back(type);
659   }
660 }
661 
662 /// Return the number of element types.
663 size_t TupleType::size() const { return getImpl()->size(); }
664 
665 //===----------------------------------------------------------------------===//
666 // Type Utilities
667 //===----------------------------------------------------------------------===//
668 
669 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
670                                            int64_t offset,
671                                            MLIRContext *context) {
672   AffineExpr expr;
673   unsigned nSymbols = 0;
674 
675   // AffineExpr for offset.
676   // Static case.
677   if (offset != MemRefType::getDynamicStrideOrOffset()) {
678     auto cst = getAffineConstantExpr(offset, context);
679     expr = cst;
680   } else {
681     // Dynamic case, new symbol for the offset.
682     auto sym = getAffineSymbolExpr(nSymbols++, context);
683     expr = sym;
684   }
685 
686   // AffineExpr for strides.
687   for (auto en : llvm::enumerate(strides)) {
688     auto dim = en.index();
689     auto stride = en.value();
690     assert(stride != 0 && "Invalid stride specification");
691     auto d = getAffineDimExpr(dim, context);
692     AffineExpr mult;
693     // Static case.
694     if (stride != MemRefType::getDynamicStrideOrOffset())
695       mult = getAffineConstantExpr(stride, context);
696     else
697       // Dynamic case, new symbol for each new stride.
698       mult = getAffineSymbolExpr(nSymbols++, context);
699     expr = expr + d * mult;
700   }
701 
702   return AffineMap::get(strides.size(), nSymbols, expr);
703 }
704 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise simplify `t`'s layout expression and return a copy of `t` with
/// the simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
710 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
711   auto affineMaps = t.getAffineMaps();
712   // Already in canonical form.
713   if (affineMaps.empty())
714     return t;
715 
716   // Can't reduce to canonical identity form, return in canonical form.
717   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
718     return t;
719 
720   // Corner-case for 0-D affine maps.
721   auto m = affineMaps[0];
722   if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
723     if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
724       if (cst.getValue() == 0)
725         return MemRefType::Builder(t).setAffineMaps({});
726     return t;
727   }
728 
  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return `t`.
732   if (t.getShape().empty())
733     return t;
734 
735   // If the canonical strided layout for the sizes of `t` is equal to the
736   // simplified layout of `t` we can just return an empty layout. Otherwise,
737   // just simplify the existing layout.
738   AffineExpr expr =
739       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
740   auto simplifiedLayoutExpr =
741       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
742   if (expr != simplifiedLayoutExpr)
743     return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
744         m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
745   return MemRefType::Builder(t).setAffineMaps({});
746 }
747 
748 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
749                                                 ArrayRef<AffineExpr> exprs,
750                                                 MLIRContext *context) {
751   assert(!sizes.empty() && !exprs.empty() &&
752          "expected non-empty sizes and exprs");
753 
754   // Size 0 corner case is useful for canonicalizations.
755   if (llvm::is_contained(sizes, 0))
756     return getAffineConstantExpr(0, context);
757 
758   auto maps = AffineMap::inferFromExprList(exprs);
759   assert(!maps.empty() && "Expected one non-empty map");
760   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
761 
762   AffineExpr expr;
763   bool dynamicPoisonBit = false;
764   int64_t runningSize = 1;
765   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
766     int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
768     if (size == 0)
769       continue;
770     AffineExpr dimExpr = std::get<0>(en);
771     AffineExpr stride = dynamicPoisonBit
772                             ? getAffineSymbolExpr(nSymbols++, context)
773                             : getAffineConstantExpr(runningSize, context);
774     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
775     if (size > 0) {
776       runningSize *= size;
777       assert(runningSize > 0 && "integer overflow in size computation");
778     } else {
779       dynamicPoisonBit = true;
780     }
781   }
782   return simplifyAffineExpr(expr, numDims, nSymbols);
783 }
784 
785 /// Return a version of `t` with a layout that has all dynamic offset and
786 /// strides. This is used to erase the static layout.
787 MemRefType mlir::eraseStridedLayout(MemRefType t) {
788   auto val = ShapedType::kDynamicStrideOrOffset;
789   return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
790       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
791 }
792 
793 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
794                                                 MLIRContext *context) {
795   SmallVector<AffineExpr, 4> exprs;
796   exprs.reserve(sizes.size());
797   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
798     exprs.push_back(getAffineDimExpr(dim, context));
799   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
800 }
801 
802 /// Return true if the layout for `t` is compatible with strided semantics.
803 bool mlir::isStrided(MemRefType t) {
804   int64_t offset;
805   SmallVector<int64_t, 4> strides;
806   auto res = getStridesAndOffset(t, strides, offset);
807   return succeeded(res);
808 }
809 
810 /// Return the layout map in strided linear layout AffineMap form.
811 /// Return null if the layout is not compatible with a strided layout.
812 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
813   int64_t offset;
814   SmallVector<int64_t, 4> strides;
815   if (failed(getStridesAndOffset(t, strides, offset)))
816     return AffineMap();
817   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
818 }
819