//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// static constexpr must have a definition (until C++17 and inline variables).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

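// Scaling multiplies the bitwidth while preserving signedness; a scale of 0
// yields a null type. For example (assuming an MLIRContext `ctx`):
//   IntegerType::get(ctx, 8).scaleElementBitwidth(4); // -> i32
//   IntegerType::get(ctx, 8).scaleElementBitwidth(0); // -> IntegerType()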
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

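// Only widening scales with an exact builtin target are supported: f16 and
// bf16 scale to f32 (x2) or f64 (x4), and f32 scales to f64 (x2). Any other
// combination, including a scale of 0, returns a null FloatType.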
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
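/// For example, with `totalIndices` = 5 and `indices` = {3, 1, 3}, the
/// callback is invoked on 0, 2, and 4, in increasing order.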
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
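/// For example, dropping argument index 1 and result index 0 from
/// `(i32, f32, i64) -> (f32, i1)` yields `(i32, i64) -> (i1)`.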
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "mlir-opt";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setShape(shape);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
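/// For example, `vector<4xf32>` takes 4 * 32 = 128 bits and
/// `tensor<2x3xvector<4xf32>>` takes 2 * 3 * 128 = 768 bits.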
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError() << "vector elements must be int/index/float type";

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError() << "vector types must have positive constant sizes";

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if reducedShape cannot be obtained
/// by dropping only `1` entries in `originalShape`.
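/// For example, with `originalShape` = [1, 4, 1, 8] and `reducedShape` =
/// [4, 8], the returned mask is {0, 2}.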
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if no match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<AffineMap> affineMapComposition,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  size_t dim = shape.size();
  for (auto it : llvm::enumerate(affineMapComposition)) {
    AffineMap map = it.value();
    if (map.getNumDims() == dim) {
      dim = map.getNumResults();
      continue;
    }
    return emitError() << "memref affine map dimension mismatch between "
                       << (it.index() == 0 ? Twine("memref rank")
                                           : "affine map " + Twine(it.index()))
                       << " and affine map " << it.index() + 1 << ": " << dim
                       << " != " << map.getNumDims();
  }

  if (!isSupportedMemorySpace(memorySpace)) {
    return emitError() << "unsupported memory space Attribute";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, ..., dn appear in
/// order, which makes indexing into the result intuitive.
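/// For example, given `e` = d0 * 4 + d1 + 2 and a unit
/// `multiplicativeFactor`, zero-initialized `strides` become [4, 1] and
/// `offset` becomes 2.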
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

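// For example, memref<3x4xf32> (empty layout) yields strides [4, 1] and
// offset 0, and memref<3x4xf32, (d0, d1) -> (d0 * 4 + d1 + 5)> yields
// strides [4, 1] and offset 5.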
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();

  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
    return failure();

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.back();
    for (size_t i = affineMaps.size() - 1; i > 0; --i)
      m = m.compose(affineMaps[i - 1]);
    assert(!m.isIdentity() && "unexpected identity map");
  }

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

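// For example, strides [?, 4, 1] (where ? denotes a dynamic stride) with a
// static offset of 2 produce the single-result map
// `(d0, d1, d2)[s0] -> (2 + d0 * s0 + d1 * 4 + d2 * 1)`; each dynamic value
// gets a fresh symbol.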
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
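/// For example, `memref<4x5xf32, affine_map<(d0, d1) -> (d0 * 5 + d1)>>`
/// canonicalizes to `memref<4x5xf32>`, since its layout is the canonical
/// contiguous strided layout for a 4x5 shape.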
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  auto m = affineMaps[0];
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setAffineMaps({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

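// For example, with sizes [3, 4, 5] and exprs (d0, d1, d2), the result is
// d0 * 20 + d1 * 5 + d2. Once a dynamic size is encountered, the strides of
// all more-major dimensions become fresh symbols.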
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
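/// For example, `memref<4x5xf32>` becomes a memref with the same shape and a
/// layout equivalent to
/// `affine_map<(d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2)>`.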
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
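/// For example, `memref<3x4xf32>` produces a map equivalent to
/// `(d0, d1) -> (d0 * 4 + d1)`.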
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}