1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 /// Tablegen Type Definitions
25 //===----------------------------------------------------------------------===//
26 
27 #define GET_TYPEDEF_CLASSES
28 #include "mlir/IR/BuiltinTypes.cpp.inc"
29 
30 //===----------------------------------------------------------------------===//
31 /// ComplexType
32 //===----------------------------------------------------------------------===//
33 
/// Verify the construction of a complex type: the element type must be a
/// builtin integer or float type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}
41 
42 //===----------------------------------------------------------------------===//
43 // Integer Type
44 //===----------------------------------------------------------------------===//
45 
// static constexpr members need an out-of-line definition prior to C++17
// (where static constexpr data members became implicitly inline).
constexpr unsigned IntegerType::kMaxWidth;
48 
49 /// Verify the construction of an integer type.
50 LogicalResult
51 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
52                                           SignednessSemantics signedness) {
53   if (width > IntegerType::kMaxWidth) {
54     return emitError(loc) << "integer bitwidth is limited to "
55                           << IntegerType::kMaxWidth << " bits";
56   }
57   return success();
58 }
59 
/// Return the bitwidth of this integer type.
unsigned IntegerType::getWidth() const { return getImpl()->width; }

/// Return the signedness semantics (signless, signed, or unsigned) of this
/// integer type.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
65 
66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
67   if (!scale)
68     return IntegerType();
69   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // Float Type
74 //===----------------------------------------------------------------------===//
75 
76 unsigned FloatType::getWidth() {
77   if (isa<Float16Type, BFloat16Type>())
78     return 16;
79   if (isa<Float32Type>())
80     return 32;
81   if (isa<Float64Type>())
82     return 64;
83   llvm_unreachable("unexpected float type");
84 }
85 
86 /// Returns the floating semantics for the given type.
87 const llvm::fltSemantics &FloatType::getFloatSemantics() {
88   if (isa<BFloat16Type>())
89     return APFloat::BFloat();
90   if (isa<Float16Type>())
91     return APFloat::IEEEhalf();
92   if (isa<Float32Type>())
93     return APFloat::IEEEsingle();
94   if (isa<Float64Type>())
95     return APFloat::IEEEdouble();
96   llvm_unreachable("non-floating point type used");
97 }
98 
99 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
100   if (!scale)
101     return FloatType();
102   MLIRContext *ctx = getContext();
103   if (isF16() || isBF16()) {
104     if (scale == 2)
105       return FloatType::getF32(ctx);
106     if (scale == 4)
107       return FloatType::getF64(ctx);
108   }
109   if (isF32())
110     if (scale == 2)
111       return FloatType::getF64(ctx);
112   return FloatType();
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // FunctionType
117 //===----------------------------------------------------------------------===//
118 
/// Returns the number of inputs of this function type.
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

/// Returns the list of input types of this function type.
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

/// Returns the number of results of this function type.
unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

/// Returns the list of result types of this function type.
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
130 
131 /// Helper to call a callback once on each index in the range
132 /// [0, `totalIndices`), *except* for the indices given in `indices`.
133 /// `indices` is allowed to have duplicates and can be in any order.
134 inline void iterateIndicesExcept(unsigned totalIndices,
135                                  ArrayRef<unsigned> indices,
136                                  function_ref<void(unsigned)> callback) {
137   llvm::BitVector skipIndices(totalIndices);
138   for (unsigned i : indices)
139     skipIndices.set(i);
140 
141   for (unsigned i = 0; i < totalIndices; ++i)
142     if (!skipIndices.test(i))
143       callback(i);
144 }
145 
146 /// Returns a new function type without the specified arguments and results.
147 FunctionType
148 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
149                                        ArrayRef<unsigned> resultIndices) {
150   ArrayRef<Type> newInputTypes = getInputs();
151   SmallVector<Type, 4> newInputTypesBuffer;
152   if (!argIndices.empty()) {
153     unsigned originalNumArgs = getNumInputs();
154     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
155       newInputTypesBuffer.emplace_back(getInput(i));
156     });
157     newInputTypes = newInputTypesBuffer;
158   }
159 
160   ArrayRef<Type> newResultTypes = getResults();
161   SmallVector<Type, 4> newResultTypesBuffer;
162   if (!resultIndices.empty()) {
163     unsigned originalNumResults = getNumResults();
164     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
165       newResultTypesBuffer.emplace_back(getResult(i));
166     });
167     newResultTypes = newResultTypesBuffer;
168   }
169 
170   return get(getContext(), newInputTypes, newResultTypes);
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // OpaqueType
175 //===----------------------------------------------------------------------===//
176 
177 /// Verify the construction of an opaque type.
178 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
179                                                        Identifier dialect,
180                                                        StringRef typeData) {
181   if (!Dialect::isValidNamespace(dialect.strref()))
182     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
183   return success();
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // ShapedType
188 //===----------------------------------------------------------------------===//
// static constexpr members need an out-of-line definition prior to C++17.
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

/// Returns the element type of this shaped type.
Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

/// Returns the bitwidth of the element type; only valid when the element type
/// is an integer or float.
unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}
199 
200 int64_t ShapedType::getNumElements() const {
201   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
202   auto shape = getShape();
203   int64_t num = 1;
204   for (auto dim : shape)
205     num *= dim;
206   return num;
207 }
208 
/// Returns the rank, i.e. the number of dimensions of the shape.
int64_t ShapedType::getRank() const { return getShape().size(); }

/// Returns true unless this is one of the unranked shaped types.
bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

/// Returns the size of dimension `idx`; the result may be kDynamicSize.
int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

/// Returns true if dimension `idx` has a dynamic size.
bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

/// Returns the position of the dynamic dimension `index` among the dynamic
/// dimensions only, i.e. the count of dynamic dimensions preceding it.
/// `index` itself must refer to a dynamic dimension.
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
230 
231 /// Get the number of bits require to store a value of the given shaped type.
232 /// Compute the value recursively since tensors are allowed to have vectors as
233 /// elements.
234 int64_t ShapedType::getSizeInBits() const {
235   assert(hasStaticShape() &&
236          "cannot get the bit size of an aggregate with a dynamic shape");
237 
238   auto elementType = getElementType();
239   if (elementType.isIntOrFloat())
240     return elementType.getIntOrFloatBitWidth() * getNumElements();
241 
242   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
243     elementType = complexType.getElementType();
244     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
245   }
246 
247   // Tensors can have vectors and other tensors as elements, other shaped types
248   // cannot.
249   assert(isa<TensorType>() && "unsupported element type");
250   assert((elementType.isa<VectorType, TensorType>()) &&
251          "unsupported tensor element type");
252   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
253 }
254 
/// Returns the shape of this shaped type by dispatching to the concrete
/// ranked type; only valid on ranked types.
ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

/// Returns the number of dimensions with a dynamic size.
int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

/// Returns true if this type is ranked and every dimension has a static size.
bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

/// Returns true if this type has exactly the static shape `shape`.
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}
274 
275 //===----------------------------------------------------------------------===//
276 // VectorType
277 //===----------------------------------------------------------------------===//
278 
/// Builds a vector type with the given shape and element type. Asserts on
/// invalid arguments; use getChecked for graceful failure.
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

/// As `get`, but emits diagnostics at `location` and returns a null type on
/// verification failure.
VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
                                  Type elementType) {
  return Base::getChecked(location, shape, elementType);
}
287 
288 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
289                                                        ArrayRef<int64_t> shape,
290                                                        Type elementType) {
291   if (shape.empty())
292     return emitError(loc, "vector types must have at least one dimension");
293 
294   if (!isValidElementType(elementType))
295     return emitError(loc, "vector elements must be int or float type");
296 
297   if (any_of(shape, [](int64_t i) { return i <= 0; }))
298     return emitError(loc, "vector types must have positive constant sizes");
299 
300   return success();
301 }
302 
/// Returns the shape of this vector type.
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
304 
305 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
306   if (!scale)
307     return VectorType();
308   if (auto et = getElementType().dyn_cast<IntegerType>())
309     if (auto scaledEt = et.scaleElementBitwidth(scale))
310       return VectorType::get(getShape(), scaledEt);
311   if (auto et = getElementType().dyn_cast<FloatType>())
312     if (auto scaledEt = et.scaleElementBitwidth(scale))
313       return VectorType::get(getShape(), scaledEt);
314   return VectorType();
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // TensorType
319 //===----------------------------------------------------------------------===//
320 
// Check if "elementType" can be an element type of a tensor. Emits an error
// at `location` and returns failure if not.
static LogicalResult checkTensorElementType(Location location,
                                            Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type: ") << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect. Any type from a named dialect is
  // therefore accepted here.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}
339 
340 //===----------------------------------------------------------------------===//
341 // RankedTensorType
342 //===----------------------------------------------------------------------===//
343 
/// Builds a ranked tensor type with the given shape and element type. Asserts
/// on invalid arguments; use getChecked for graceful failure.
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

/// As `get`, but emits diagnostics at `location` and returns a null type on
/// verification failure.
RankedTensorType RankedTensorType::getChecked(Location location,
                                              ArrayRef<int64_t> shape,
                                              Type elementType) {
  return Base::getChecked(location, shape, elementType);
}
354 
355 LogicalResult RankedTensorType::verifyConstructionInvariants(
356     Location loc, ArrayRef<int64_t> shape, Type elementType) {
357   for (int64_t s : shape) {
358     if (s < -1)
359       return emitError(loc, "invalid tensor dimension size");
360   }
361   return checkTensorElementType(loc, elementType);
362 }
363 
/// Returns the shape of this ranked tensor type.
ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}
367 
368 //===----------------------------------------------------------------------===//
369 // UnrankedTensorType
370 //===----------------------------------------------------------------------===//
371 
/// Builds an unranked tensor type with the given element type.
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

/// As `get`, but emits diagnostics at `location` and returns a null type on
/// verification failure.
UnrankedTensorType UnrankedTensorType::getChecked(Location location,
                                                  Type elementType) {
  return Base::getChecked(location, elementType);
}

/// Verify the construction of an unranked tensor type: only the element type
/// needs checking.
LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}
386 
387 //===----------------------------------------------------------------------===//
388 // BaseMemRefType
389 //===----------------------------------------------------------------------===//
390 
/// Returns the memory space of this memref type, as stored on the type's
/// implementation.
unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}
394 
395 //===----------------------------------------------------------------------===//
396 // MemRefType
397 //===----------------------------------------------------------------------===//
398 
399 /// Get or create a new MemRefType based on shape, element type, affine
400 /// map composition, and memory space.  Assumes the arguments define a
401 /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
402 /// construction failures.
403 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
404                            ArrayRef<AffineMap> affineMapComposition,
405                            unsigned memorySpace) {
406   auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
407                         /*location=*/llvm::None);
408   assert(result && "Failed to construct instance of MemRefType.");
409   return result;
410 }
411 
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
                                  Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace) {
  // Passing a location makes getImpl emit diagnostics instead of asserting.
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}
425 
426 /// Get or create a new MemRefType defined by the arguments.  If the resulting
427 /// type would be ill-formed, return nullptr.  If the location is provided,
428 /// emit detailed error messages.  To emit errors when the location is unknown,
429 /// pass in an instance of UnknownLoc.
430 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
431                                ArrayRef<AffineMap> affineMapComposition,
432                                unsigned memorySpace,
433                                Optional<Location> location) {
434   auto *context = elementType.getContext();
435 
436   if (!BaseMemRefType::isValidElementType(elementType))
437     return emitOptionalError(location, "invalid memref element type"),
438            MemRefType();
439 
440   for (int64_t s : shape) {
441     // Negative sizes are not allowed except for `-1` that means dynamic size.
442     if (s < -1)
443       return emitOptionalError(location, "invalid memref size"), MemRefType();
444   }
445 
446   // Check that the structure of the composition is valid, i.e. that each
447   // subsequent affine map has as many inputs as the previous map has results.
448   // Take the dimensionality of the MemRef for the first map.
449   auto dim = shape.size();
450   unsigned i = 0;
451   for (const auto &affineMap : affineMapComposition) {
452     if (affineMap.getNumDims() != dim) {
453       if (location)
454         emitError(*location)
455             << "memref affine map dimension mismatch between "
456             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
457             << " and affine map" << i + 1 << ": " << dim
458             << " != " << affineMap.getNumDims();
459       return nullptr;
460     }
461 
462     dim = affineMap.getNumResults();
463     ++i;
464   }
465 
466   // Drop identity maps from the composition.
467   // This may lead to the composition becoming empty, which is interpreted as an
468   // implicit identity.
469   SmallVector<AffineMap, 2> cleanedAffineMapComposition;
470   for (const auto &map : affineMapComposition) {
471     if (map.isIdentity())
472       continue;
473     cleanedAffineMapComposition.push_back(map);
474   }
475 
476   return Base::get(context, shape, elementType, cleanedAffineMapComposition,
477                    memorySpace);
478 }
479 
/// Returns the shape of this memref type.
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

/// Returns the composed affine layout maps; an empty list is interpreted as
/// the identity layout.
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}
485 
486 //===----------------------------------------------------------------------===//
487 // UnrankedMemRefType
488 //===----------------------------------------------------------------------===//
489 
/// Builds an unranked memref type with the given element type and memory
/// space.
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), elementType, memorySpace);
}

/// As `get`, but emits diagnostics at `location` and returns a null type on
/// verification failure.
UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
                                                  Type elementType,
                                                  unsigned memorySpace) {
  return Base::getChecked(location, elementType, memorySpace);
}

/// Verify the construction of an unranked memref type: the element type must
/// be a valid memref element type.
LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(loc, "invalid memref element type");
  return success();
}
508 
509 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
510 // i.e. single term). Accumulate the AffineExpr into the existing one.
511 static void extractStridesFromTerm(AffineExpr e,
512                                    AffineExpr multiplicativeFactor,
513                                    MutableArrayRef<AffineExpr> strides,
514                                    AffineExpr &offset) {
515   if (auto dim = e.dyn_cast<AffineDimExpr>())
516     strides[dim.getPosition()] =
517         strides[dim.getPosition()] + multiplicativeFactor;
518   else
519     offset = offset + e * multiplicativeFactor;
520 }
521 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    // Terminal dim/symbol/constant: accumulate it directly.
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  // Division and modulo cannot be expressed as strides.
  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      // Simple `dim * expr` product: `expr` scales the stride of `dim`.
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Addition distributes: extract strides from both sides independently.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
569 
570 LogicalResult mlir::getStridesAndOffset(MemRefType t,
571                                         SmallVectorImpl<AffineExpr> &strides,
572                                         AffineExpr &offset) {
573   auto affineMaps = t.getAffineMaps();
574   // For now strides are only computed on a single affine map with a single
575   // result (i.e. the closed subset of linearization maps that are compatible
576   // with striding semantics).
577   // TODO: support more forms on a per-need basis.
578   if (affineMaps.size() > 1)
579     return failure();
580   if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
581     return failure();
582 
583   auto zero = getAffineConstantExpr(0, t.getContext());
584   auto one = getAffineConstantExpr(1, t.getContext());
585   offset = zero;
586   strides.assign(t.getRank(), zero);
587 
588   AffineMap m;
589   if (!affineMaps.empty()) {
590     m = affineMaps.front();
591     assert(!m.isIdentity() && "unexpected identity map");
592   }
593 
594   // Canonical case for empty map.
595   if (!m) {
596     // 0-D corner case, offset is already 0.
597     if (t.getRank() == 0)
598       return success();
599     auto stridedExpr =
600         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
601     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
602       return success();
603     assert(false && "unexpected failure: extract strides in canonical layout");
604   }
605 
606   // Non-canonical case requires more work.
607   auto stridedExpr =
608       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
609   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
610     offset = AffineExpr();
611     strides.clear();
612     return failure();
613   }
614 
615   // Simplify results to allow folding to constants and simple checks.
616   unsigned numDims = m.getNumDims();
617   unsigned numSymbols = m.getNumSymbols();
618   offset = simplifyAffineExpr(offset, numDims, numSymbols);
619   for (auto &stride : strides)
620     stride = simplifyAffineExpr(stride, numDims, numSymbols);
621 
622   /// In practice, a strided memref must be internally non-aliasing. Test
623   /// against 0 as a proxy.
624   /// TODO: static cases can have more advanced checks.
625   /// TODO: dynamic cases would require a way to compare symbolic
626   /// expressions and would probably need an affine set context propagated
627   /// everywhere.
628   if (llvm::any_of(strides, [](AffineExpr e) {
629         return e == getAffineConstantExpr(0, e.getContext());
630       })) {
631     offset = AffineExpr();
632     strides.clear();
633     return failure();
634   }
635 
636   return success();
637 }
638 
639 LogicalResult mlir::getStridesAndOffset(MemRefType t,
640                                         SmallVectorImpl<int64_t> &strides,
641                                         int64_t &offset) {
642   AffineExpr offsetExpr;
643   SmallVector<AffineExpr, 4> strideExprs;
644   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
645     return failure();
646   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
647     offset = cst.getValue();
648   else
649     offset = ShapedType::kDynamicStrideOrOffset;
650   for (auto e : strideExprs) {
651     if (auto c = e.dyn_cast<AffineConstantExpr>())
652       strides.push_back(c.getValue());
653     else
654       strides.push_back(ShapedType::kDynamicStrideOrOffset);
655   }
656   return success();
657 }
658 
659 //===----------------------------------------------------------------------===//
660 /// TupleType
661 //===----------------------------------------------------------------------===//
662 
/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
665 
666 /// Accumulate the types contained in this tuple and tuples nested within it.
667 /// Note that this only flattens nested tuples, not any other container type,
668 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
669 /// (i32, tensor<i32>, f32, i64)
670 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
671   for (Type type : getTypes()) {
672     if (auto nestedTuple = type.dyn_cast<TupleType>())
673       nestedTuple.getFlattenedTypes(types);
674     else
675       types.push_back(type);
676   }
677 }
678 
/// Return the number of element types contained in this tuple.
size_t TupleType::size() const { return getImpl()->size(); }
681 
682 //===----------------------------------------------------------------------===//
683 // Type Utilities
684 //===----------------------------------------------------------------------===//
685 
686 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
687                                            int64_t offset,
688                                            MLIRContext *context) {
689   AffineExpr expr;
690   unsigned nSymbols = 0;
691 
692   // AffineExpr for offset.
693   // Static case.
694   if (offset != MemRefType::getDynamicStrideOrOffset()) {
695     auto cst = getAffineConstantExpr(offset, context);
696     expr = cst;
697   } else {
698     // Dynamic case, new symbol for the offset.
699     auto sym = getAffineSymbolExpr(nSymbols++, context);
700     expr = sym;
701   }
702 
703   // AffineExpr for strides.
704   for (auto en : llvm::enumerate(strides)) {
705     auto dim = en.index();
706     auto stride = en.value();
707     assert(stride != 0 && "Invalid stride specification");
708     auto d = getAffineDimExpr(dim, context);
709     AffineExpr mult;
710     // Static case.
711     if (stride != MemRefType::getDynamicStrideOrOffset())
712       mult = getAffineConstantExpr(stride, context);
713     else
714       // Dynamic case, new symbol for each new stride.
715       mult = getAffineSymbolExpr(nSymbols++, context);
716     expr = expr + d * mult;
717   }
718 
719   return AffineMap::get(strides.size(), nSymbols, expr);
720 }
721 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    // Layouts differ: keep the (simplified) explicit layout map.
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  // Layout matches the canonical contiguous one: drop the map entirely, which
  // is interpreted as the implicit identity layout.
  return MemRefType::Builder(t).setAffineMaps({});
}
750 
/// Builds the canonical contiguous strided layout expression for `sizes`,
/// pairing each size with the corresponding expression in `exprs`.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  // Walk sizes/exprs from the innermost (fastest varying) dimension outward,
  // accumulating the running product of sizes to form each static stride.
  // Once a dynamic size is encountered, all outer strides become symbolic.
  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
782 
783 /// Return a version of `t` with a layout that has all dynamic offset and
784 /// strides. This is used to erase the static layout.
785 MemRefType mlir::eraseStridedLayout(MemRefType t) {
786   auto val = ShapedType::kDynamicStrideOrOffset;
787   return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
788       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
789 }
790 
791 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
792                                                 MLIRContext *context) {
793   SmallVector<AffineExpr, 4> exprs;
794   exprs.reserve(sizes.size());
795   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
796     exprs.push_back(getAffineDimExpr(dim, context));
797   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
798 }
799 
800 /// Return true if the layout for `t` is compatible with strided semantics.
801 bool mlir::isStrided(MemRefType t) {
802   int64_t offset;
803   SmallVector<int64_t, 4> stridesAndOffset;
804   auto res = getStridesAndOffset(t, stridesAndOffset, offset);
805   return succeeded(res);
806 }
807