//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member needs an out-of-line definition prior to C++17
// inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
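
// Example (illustrative sketch; `ctx` is assumed to be a valid MLIRContext*):
// scaling multiplies the bitwidth and preserves signedness, and a scale of 0
// yields a null type, so callers can test the result directly.
//   IntegerType i8 = IntegerType::get(ctx, 8);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // 32-bit, same signedness
//   if (!i8.scaleElementBitwidth(0)) { /* null type: invalid scale */ }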

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating-point semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}
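
// Example (illustrative sketch; `ctx` is assumed to be a valid MLIRContext*):
// only widenings that land on a supported builtin float type succeed; any
// other scale returns a null FloatType.
//   FloatType f16 = FloatType::getF16(ctx);
//   FloatType f32 = f16.scaleElementBitwidth(2); // f32
//   FloatType bad = f16.scaleElementBitwidth(8); // null: no f128 is produced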

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}
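
// Example (illustrative sketch; `fn` is a hypothetical FunctionType): erasing
// argument 1 and result 0 of `(i32, f32, i8) -> (f64, index)` yields
// `(i32, i8) -> (index)`.
//   FunctionType pruned = fn.getWithoutArgsAndResults({1}, {0});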

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";
  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setShape(shape);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
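
// Example (illustrative): for a shape (?, 4, ?), the dynamic dimension at
// overall index 2 is the second dynamic one, so getDynamicDimIndex(2) returns
// 1; this relative index is handy when dynamic sizes are stored densely, e.g.
// as an operand list.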

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
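
// Example (illustrative): the recursion handles aggregate elements, e.g. a
// tensor<4xvector<8xf32>> has 4 elements of 8 * 32 = 256 bits each, so
// getSizeInBits() returns 1024.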

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

VectorType VectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType);
}

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError() << "vector elements must be int or float type";

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError() << "vector types must have positive constant sizes";

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}
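
// Example (illustrative): the shape is preserved and only the element type is
// scaled, so vector<4xf16> scaled by 2 becomes vector<4xf32>; if the element
// type cannot be scaled (see the element types' scaleElementBitwidth above),
// a null VectorType is returned.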

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is valid in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

RankedTensorType
RankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                             ArrayRef<int64_t> shape, Type elementType) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType);
}

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  }
  return checkTensorElementType(emitError, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

UnrankedTensorType
UnrankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               Type elementType) {
  return Base::getChecked(emitError, elementType.getContext(), elementType);
}

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result =
      getImpl(shape, elementType, affineMapComposition, memorySpace, [=] {
        return emitError(UnknownLoc::get(elementType.getContext()));
      });
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 emitError);
}

/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               function_ref<InFlightDiagnostic()> emitError) {
  auto *context = elementType.getContext();

  if (!BaseMemRefType::isValidElementType(elementType))
    return (emitError() << "invalid memref element type", MemRefType());

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1`, which means dynamic
    // size.
    if (s < -1)
      return (emitError() << "invalid memref size", MemRefType());
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      emitError() << "memref affine map dimension mismatch between "
                  << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
                  << " and affine map " << i + 1 << ": " << dim
                  << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as
  // an implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
                   memorySpace);
}
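
// Example (illustrative sketch; `ctx` is assumed to be a valid MLIRContext*):
// because identity maps are dropped, a memref built with an explicit identity
// layout uniques to the same type as one built with no layout at all.
//   auto id = AffineMap::getMultiDimIdentityMap(2, ctx);
//   auto f32 = FloatType::getF32(ctx);
//   assert(MemRefType::get({4, 8}, f32, {id}, 0) ==
//          MemRefType::get({4, 8}, f32, {}, 0));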

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), elementType, memorySpace);
}

UnrankedMemRefType
UnrankedMemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               Type elementType, unsigned memorySpace) {
  return Base::getChecked(emitError, elementType.getContext(), elementType,
                          memorySpace);
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, unsigned memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";
  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing strides
// and offset.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, ..., dn appear in
/// order, to make indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
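
// Example (illustrative): walking d0 * 128 + d1 * 2 + 5 with a multiplicative
// factor of 1 accumulates strides [128, 2] and offset 5, while any expression
// containing mod, floordiv, or ceildiv makes the extraction fail.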

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}
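
// Example (illustrative usage of the int64_t overload): for
// memref<4x8xf32, affine_map<(d0, d1)[s0] -> (d0 * 8 + d1 + s0)>> the strides
// are the constants [8, 1] while the offset is symbolic, so it is reported as
// ShapedType::kDynamicStrideOrOffset. `memrefType` below is a hypothetical
// MemRefType value.
//   SmallVector<int64_t, 4> strides;
//   int64_t offset;
//   if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
//     // strides == {8, 1}, offset == ShapedType::kDynamicStrideOrOffset
//   }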

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}
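
// Example (illustrative): strides [128, 1] with offset 0 produce
// (d0, d1) -> (d0 * 128 + d1); making the offset and the second stride dynamic
// instead yields (d0, d1)[s0, s1] -> (d0 * 128 + d1 * s1 + s0), where s0 is
// the offset symbol and s1 the stride symbol.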

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  auto m = affineMaps[0];
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setAffineMaps({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a one-element memref
  // whose offset needs to remain; just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t`, we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}
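
// Example (illustrative): for memref<4x8xf32, (d0, d1) -> (d0 * 8 + d1)>, the
// layout matches the canonical contiguous strided form computed from the
// shape, so the result is plain memref<4x8xf32> with the implicit identity
// layout.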

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
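
// Example (illustrative): sizes [4, ?, 2] give constant strides 2 and 1 for
// the two innermost dimensions; the dynamic size poisons everything outward,
// so the outermost stride becomes a fresh symbol and the result is
// d0 * s0 + d1 * 2 + d2.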

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}
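
// Example (illustrative): erasing the layout of a 2-D memref installs
// (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0), i.e. a fully dynamic
// offset (s0) and fully dynamic strides (s1, s2).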

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}