//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// Until C++17 (inline variables), static constexpr members need a definition.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

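// For example (an illustrative sketch; `ctx` stands for any MLIRContext and
// is not defined in this file):
//
//   IntegerType i8 = IntegerType::get(ctx, 8);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // 8 * 4 = 32 bits.
//
// A zero scale returns the null type instead.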
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating-point semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

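// Only a few widening combinations are supported: f16/bf16 scale to f32 (x2)
// or f64 (x4), and f32 scales to f64 (x2); any other combination (e.g. f64
// scaled by 2) returns the null type.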
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
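/// For example, `totalIndices` = 5 with `indices` = {1, 3} invokes the
/// callback on indices 0, 2, and 4.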
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
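/// For example, erasing argument index 0 and result index 1 from
/// `(i32, f32) -> (i1, i64)` yields `(f32) -> (i1)`.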
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "mlir-opt";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setShape(shape);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
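/// For example, `tensor<2x3xvector<4xf32>>` takes 2 * 3 * (4 * 32) = 768
/// bits.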
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError() << "vector elements must be int/index/float type";

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError() << "vector types must have positive constant sizes";

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is valid within a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  // TODO: verify contents of encoding attribute.
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when, e.g., computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
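/// For example, `originalShape` = (1, 4, 1, 8) with `reducedShape` = (4, 8)
/// produces the mask {0, 2}, while `reducedShape` = (4, 4) produces None.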
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if no match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<AffineMap> affineMapComposition,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means a dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  size_t dim = shape.size();
  for (auto it : llvm::enumerate(affineMapComposition)) {
    AffineMap map = it.value();
    if (map.getNumDims() == dim) {
      dim = map.getNumResults();
      continue;
    }
    return emitError() << "memref affine map dimension mismatch between "
                       << (it.index() == 0 ? Twine("memref rank")
                                           : "affine map " + Twine(it.index()))
595                        << " and affine map" << it.index() + 1 << ": " << dim
596                        << " != " << map.getNumDims();
597   }
598 
599   if (!isSupportedMemorySpace(memorySpace)) {
600     return emitError() << "unsupported memory space Attribute";
601   }
602 
603   return success();
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // UnrankedMemRefType
608 //===----------------------------------------------------------------------===//
609 
610 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
611   return detail::getMemorySpaceAsInt(getMemorySpace());
612 }
613 
614 LogicalResult
615 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
616                            Type elementType, Attribute memorySpace) {
617   if (!BaseMemRefType::isValidElementType(elementType))
618     return emitError() << "invalid memref element type";
619 
620   if (!isSupportedMemorySpace(memorySpace))
621     return emitError() << "unsupported memory space Attribute";
622 
623   return success();
624 }
625 
// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// that order, which makes indexing into the result intuitive.
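/// For example, `e` = `d0 * 42 + d1 + 7` with a unit `multiplicativeFactor`
/// yields strides (42, 1) and offset 7; `d0 floordiv 2` makes extraction fail.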
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

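/// Computes the strides and offset of `t`'s layout as affine expressions. For
/// example, `memref<3x4xf32>` with the default layout yields constant strides
/// (4, 1) and offset 0, whereas a non-strided layout such as
/// `(d0) -> (d0 floordiv 2)` makes this fail.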
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();

  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
    return failure();

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.back();
    for (size_t i = affineMaps.size() - 1; i > 0; --i)
      m = m.compose(affineMaps[i - 1]);
    assert(!m.isIdentity() && "unexpected identity map");
  }

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

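/// Produces a strided linear layout map from constant `strides` and `offset`,
/// introducing one symbol per dynamic value. For example, strides (128, 1)
/// with offset 4 yield `(d0, d1) -> (d0 * 128 + d1 + 4)`, while a dynamic
/// offset yields `(d0, d1)[s0] -> (d0 * 128 + d1 + s0)`.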
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
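/// For example, `memref<4x5xf32, (d0, d1) -> (d0 * 5 + d1)>` canonicalizes to
/// a plain `memref<4x5xf32>` because its layout matches the canonical
/// contiguous strided form.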
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  auto m = affineMaps[0];
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setAffineMaps({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a one-element memref
  // whose offset needs to remain; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

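/// Computes the canonical contiguous strided layout expression for `sizes`
/// over the given `exprs`. For example, sizes (3, 4, 5) over (d0, d1, d2)
/// yield `d0 * 20 + d1 * 5 + d2`; a dynamic size (-1) makes every stride to
/// its left symbolic.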
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}