1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
// ComplexType
25 //===----------------------------------------------------------------------===//
26 
27 ComplexType ComplexType::get(Type elementType) {
28   return Base::get(elementType.getContext(), elementType);
29 }
30 
31 ComplexType ComplexType::getChecked(Type elementType, Location location) {
32   return Base::getChecked(location, elementType);
33 }
34 
35 /// Verify the construction of an integer type.
36 LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
37                                                         Type elementType) {
38   if (!elementType.isIntOrFloat())
39     return emitError(loc, "invalid element type for complex");
40   return success();
41 }
42 
43 Type ComplexType::getElementType() { return getImpl()->elementType; }
44 
45 //===----------------------------------------------------------------------===//
46 // Integer Type
47 //===----------------------------------------------------------------------===//
48 
49 // static constexpr must have a definition (until in C++17 and inline variable).
50 constexpr unsigned IntegerType::kMaxWidth;
51 
52 /// Verify the construction of an integer type.
53 LogicalResult
54 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
55                                           SignednessSemantics signedness) {
56   if (width > IntegerType::kMaxWidth) {
57     return emitError(loc) << "integer bitwidth is limited to "
58                           << IntegerType::kMaxWidth << " bits";
59   }
60   return success();
61 }
62 
63 unsigned IntegerType::getWidth() const { return getImpl()->width; }
64 
65 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
66   return getImpl()->signedness;
67 }
68 
69 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
70   if (!scale)
71     return IntegerType();
72   return IntegerType::get(scale * getWidth(), getSignedness(), getContext());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Float Type
77 //===----------------------------------------------------------------------===//
78 
79 unsigned FloatType::getWidth() {
80   if (isa<Float16Type, BFloat16Type>())
81     return 16;
82   if (isa<Float32Type>())
83     return 32;
84   if (isa<Float64Type>())
85     return 64;
86   llvm_unreachable("unexpected float type");
87 }
88 
89 /// Returns the floating semantics for the given type.
90 const llvm::fltSemantics &FloatType::getFloatSemantics() {
91   if (isa<BFloat16Type>())
92     return APFloat::BFloat();
93   if (isa<Float16Type>())
94     return APFloat::IEEEhalf();
95   if (isa<Float32Type>())
96     return APFloat::IEEEsingle();
97   if (isa<Float64Type>())
98     return APFloat::IEEEdouble();
99   llvm_unreachable("non-floating point type used");
100 }
101 
102 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
103   if (!scale)
104     return FloatType();
105   MLIRContext *ctx = getContext();
106   if (isF16() || isBF16()) {
107     if (scale == 2)
108       return FloatType::getF32(ctx);
109     if (scale == 4)
110       return FloatType::getF64(ctx);
111   }
112   if (isF32())
113     if (scale == 2)
114       return FloatType::getF64(ctx);
115   return FloatType();
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // FunctionType
120 //===----------------------------------------------------------------------===//
121 
122 FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
123                                MLIRContext *context) {
124   return Base::get(context, inputs, results);
125 }
126 
127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
128 
129 ArrayRef<Type> FunctionType::getInputs() const {
130   return getImpl()->getInputs();
131 }
132 
133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
134 
135 ArrayRef<Type> FunctionType::getResults() const {
136   return getImpl()->getResults();
137 }
138 
139 /// Helper to call a callback once on each index in the range
140 /// [0, `totalIndices`), *except* for the indices given in `indices`.
141 /// `indices` is allowed to have duplicates and can be in any order.
142 inline void iterateIndicesExcept(unsigned totalIndices,
143                                  ArrayRef<unsigned> indices,
144                                  function_ref<void(unsigned)> callback) {
145   llvm::BitVector skipIndices(totalIndices);
146   for (unsigned i : indices)
147     skipIndices.set(i);
148 
149   for (unsigned i = 0; i < totalIndices; ++i)
150     if (!skipIndices.test(i))
151       callback(i);
152 }
153 
154 /// Returns a new function type without the specified arguments and results.
155 FunctionType
156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
157                                        ArrayRef<unsigned> resultIndices) {
158   ArrayRef<Type> newInputTypes = getInputs();
159   SmallVector<Type, 4> newInputTypesBuffer;
160   if (!argIndices.empty()) {
161     unsigned originalNumArgs = getNumInputs();
162     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
163       newInputTypesBuffer.emplace_back(getInput(i));
164     });
165     newInputTypes = newInputTypesBuffer;
166   }
167 
168   ArrayRef<Type> newResultTypes = getResults();
169   SmallVector<Type, 4> newResultTypesBuffer;
170   if (!resultIndices.empty()) {
171     unsigned originalNumResults = getNumResults();
172     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
173       newResultTypesBuffer.emplace_back(getResult(i));
174     });
175     newResultTypes = newResultTypesBuffer;
176   }
177 
178   return get(newInputTypes, newResultTypes, getContext());
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // OpaqueType
183 //===----------------------------------------------------------------------===//
184 
185 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
186                            MLIRContext *context) {
187   return Base::get(context, dialect, typeData);
188 }
189 
190 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
191                                   MLIRContext *context, Location location) {
192   return Base::getChecked(location, dialect, typeData);
193 }
194 
195 /// Returns the dialect namespace of the opaque type.
196 Identifier OpaqueType::getDialectNamespace() const {
197   return getImpl()->dialectNamespace;
198 }
199 
200 /// Returns the raw type data of the opaque type.
201 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
202 
203 /// Verify the construction of an opaque type.
204 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
205                                                        Identifier dialect,
206                                                        StringRef typeData) {
207   if (!Dialect::isValidNamespace(dialect.strref()))
208     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
209   return success();
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // ShapedType
214 //===----------------------------------------------------------------------===//
215 constexpr int64_t ShapedType::kDynamicSize;
216 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
217 
218 Type ShapedType::getElementType() const {
219   return static_cast<ImplType *>(impl)->elementType;
220 }
221 
222 unsigned ShapedType::getElementTypeBitWidth() const {
223   return getElementType().getIntOrFloatBitWidth();
224 }
225 
226 int64_t ShapedType::getNumElements() const {
227   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
228   auto shape = getShape();
229   int64_t num = 1;
230   for (auto dim : shape)
231     num *= dim;
232   return num;
233 }
234 
235 int64_t ShapedType::getRank() const { return getShape().size(); }
236 
237 bool ShapedType::hasRank() const {
238   return !isa<UnrankedMemRefType, UnrankedTensorType>();
239 }
240 
241 int64_t ShapedType::getDimSize(unsigned idx) const {
242   assert(idx < getRank() && "invalid index for shaped type");
243   return getShape()[idx];
244 }
245 
246 bool ShapedType::isDynamicDim(unsigned idx) const {
247   assert(idx < getRank() && "invalid index for shaped type");
248   return isDynamic(getShape()[idx]);
249 }
250 
251 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
252   assert(index < getRank() && "invalid index");
253   assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
254   return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
255 }
256 
257 /// Get the number of bits require to store a value of the given shaped type.
258 /// Compute the value recursively since tensors are allowed to have vectors as
259 /// elements.
260 int64_t ShapedType::getSizeInBits() const {
261   assert(hasStaticShape() &&
262          "cannot get the bit size of an aggregate with a dynamic shape");
263 
264   auto elementType = getElementType();
265   if (elementType.isIntOrFloat())
266     return elementType.getIntOrFloatBitWidth() * getNumElements();
267 
268   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
269     elementType = complexType.getElementType();
270     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
271   }
272 
273   // Tensors can have vectors and other tensors as elements, other shaped types
274   // cannot.
275   assert(isa<TensorType>() && "unsupported element type");
276   assert((elementType.isa<VectorType, TensorType>()) &&
277          "unsupported tensor element type");
278   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
279 }
280 
281 ArrayRef<int64_t> ShapedType::getShape() const {
282   if (auto vectorType = dyn_cast<VectorType>())
283     return vectorType.getShape();
284   if (auto tensorType = dyn_cast<RankedTensorType>())
285     return tensorType.getShape();
286   return cast<MemRefType>().getShape();
287 }
288 
289 int64_t ShapedType::getNumDynamicDims() const {
290   return llvm::count_if(getShape(), isDynamic);
291 }
292 
293 bool ShapedType::hasStaticShape() const {
294   return hasRank() && llvm::none_of(getShape(), isDynamic);
295 }
296 
297 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
298   return hasStaticShape() && getShape() == shape;
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // VectorType
303 //===----------------------------------------------------------------------===//
304 
305 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
306   return Base::get(elementType.getContext(), shape, elementType);
307 }
308 
309 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
310                                   Location location) {
311   return Base::getChecked(location, shape, elementType);
312 }
313 
314 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
315                                                        ArrayRef<int64_t> shape,
316                                                        Type elementType) {
317   if (shape.empty())
318     return emitError(loc, "vector types must have at least one dimension");
319 
320   if (!isValidElementType(elementType))
321     return emitError(loc, "vector elements must be int or float type");
322 
323   if (any_of(shape, [](int64_t i) { return i <= 0; }))
324     return emitError(loc, "vector types must have positive constant sizes");
325 
326   return success();
327 }
328 
329 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
330 
331 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
332   if (!scale)
333     return VectorType();
334   if (auto et = getElementType().dyn_cast<IntegerType>())
335     if (auto scaledEt = et.scaleElementBitwidth(scale))
336       return VectorType::get(getShape(), scaledEt);
337   if (auto et = getElementType().dyn_cast<FloatType>())
338     if (auto scaledEt = et.scaleElementBitwidth(scale))
339       return VectorType::get(getShape(), scaledEt);
340   return VectorType();
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // TensorType
345 //===----------------------------------------------------------------------===//
346 
347 // Check if "elementType" can be an element type of a tensor. Emit errors if
348 // location is not nullptr.  Returns failure if check failed.
349 static LogicalResult checkTensorElementType(Location location,
350                                             Type elementType) {
351   if (!TensorType::isValidElementType(elementType))
352     return emitError(location, "invalid tensor element type: ") << elementType;
353   return success();
354 }
355 
356 /// Return true if the specified element type is ok in a tensor.
357 bool TensorType::isValidElementType(Type type) {
358   // Note: Non standard/builtin types are allowed to exist within tensor
359   // types. Dialects are expected to verify that tensor types have a valid
360   // element type within that dialect.
361   return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
362                   IndexType>() ||
363          !type.getDialect().getNamespace().empty();
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // RankedTensorType
368 //===----------------------------------------------------------------------===//
369 
370 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
371                                        Type elementType) {
372   return Base::get(elementType.getContext(), shape, elementType);
373 }
374 
375 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
376                                               Type elementType,
377                                               Location location) {
378   return Base::getChecked(location, shape, elementType);
379 }
380 
381 LogicalResult RankedTensorType::verifyConstructionInvariants(
382     Location loc, ArrayRef<int64_t> shape, Type elementType) {
383   for (int64_t s : shape) {
384     if (s < -1)
385       return emitError(loc, "invalid tensor dimension size");
386   }
387   return checkTensorElementType(loc, elementType);
388 }
389 
390 ArrayRef<int64_t> RankedTensorType::getShape() const {
391   return getImpl()->getShape();
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // UnrankedTensorType
396 //===----------------------------------------------------------------------===//
397 
398 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
399   return Base::get(elementType.getContext(), elementType);
400 }
401 
402 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
403                                                   Location location) {
404   return Base::getChecked(location, elementType);
405 }
406 
407 LogicalResult
408 UnrankedTensorType::verifyConstructionInvariants(Location loc,
409                                                  Type elementType) {
410   return checkTensorElementType(loc, elementType);
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // BaseMemRefType
415 //===----------------------------------------------------------------------===//
416 
417 unsigned BaseMemRefType::getMemorySpace() const {
418   return static_cast<ImplType *>(impl)->memorySpace;
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // MemRefType
423 //===----------------------------------------------------------------------===//
424 
425 /// Get or create a new MemRefType based on shape, element type, affine
426 /// map composition, and memory space.  Assumes the arguments define a
427 /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
428 /// construction failures.
429 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
430                            ArrayRef<AffineMap> affineMapComposition,
431                            unsigned memorySpace) {
432   auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
433                         /*location=*/llvm::None);
434   assert(result && "Failed to construct instance of MemRefType.");
435   return result;
436 }
437 
438 /// Get or create a new MemRefType based on shape, element type, affine
439 /// map composition, and memory space declared at the given location.
440 /// If the location is unknown, the last argument should be an instance of
441 /// UnknownLoc.  If the MemRefType defined by the arguments would be
442 /// ill-formed, emits errors (to the handler registered with the context or to
443 /// the error stream) and returns nullptr.
444 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
445                                   ArrayRef<AffineMap> affineMapComposition,
446                                   unsigned memorySpace, Location location) {
447   return getImpl(shape, elementType, affineMapComposition, memorySpace,
448                  location);
449 }
450 
451 /// Get or create a new MemRefType defined by the arguments.  If the resulting
452 /// type would be ill-formed, return nullptr.  If the location is provided,
453 /// emit detailed error messages.  To emit errors when the location is unknown,
454 /// pass in an instance of UnknownLoc.
455 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
456                                ArrayRef<AffineMap> affineMapComposition,
457                                unsigned memorySpace,
458                                Optional<Location> location) {
459   auto *context = elementType.getContext();
460 
461   if (!BaseMemRefType::isValidElementType(elementType))
462     return emitOptionalError(location, "invalid memref element type"),
463            MemRefType();
464 
465   for (int64_t s : shape) {
466     // Negative sizes are not allowed except for `-1` that means dynamic size.
467     if (s < -1)
468       return emitOptionalError(location, "invalid memref size"), MemRefType();
469   }
470 
471   // Check that the structure of the composition is valid, i.e. that each
472   // subsequent affine map has as many inputs as the previous map has results.
473   // Take the dimensionality of the MemRef for the first map.
474   auto dim = shape.size();
475   unsigned i = 0;
476   for (const auto &affineMap : affineMapComposition) {
477     if (affineMap.getNumDims() != dim) {
478       if (location)
479         emitError(*location)
480             << "memref affine map dimension mismatch between "
481             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
482             << " and affine map" << i + 1 << ": " << dim
483             << " != " << affineMap.getNumDims();
484       return nullptr;
485     }
486 
487     dim = affineMap.getNumResults();
488     ++i;
489   }
490 
491   // Drop identity maps from the composition.
492   // This may lead to the composition becoming empty, which is interpreted as an
493   // implicit identity.
494   SmallVector<AffineMap, 2> cleanedAffineMapComposition;
495   for (const auto &map : affineMapComposition) {
496     if (map.isIdentity())
497       continue;
498     cleanedAffineMapComposition.push_back(map);
499   }
500 
501   return Base::get(context, shape, elementType, cleanedAffineMapComposition,
502                    memorySpace);
503 }
504 
505 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
506 
507 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
508   return getImpl()->getAffineMaps();
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // UnrankedMemRefType
513 //===----------------------------------------------------------------------===//
514 
515 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
516                                            unsigned memorySpace) {
517   return Base::get(elementType.getContext(), elementType, memorySpace);
518 }
519 
520 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
521                                                   unsigned memorySpace,
522                                                   Location location) {
523   return Base::getChecked(location, elementType, memorySpace);
524 }
525 
526 LogicalResult
527 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
528                                                  unsigned memorySpace) {
529   if (!BaseMemRefType::isValidElementType(elementType))
530     return emitError(loc, "invalid memref element type");
531   return success();
532 }
533 
534 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
535 // i.e. single term). Accumulate the AffineExpr into the existing one.
536 static void extractStridesFromTerm(AffineExpr e,
537                                    AffineExpr multiplicativeFactor,
538                                    MutableArrayRef<AffineExpr> strides,
539                                    AffineExpr &offset) {
540   if (auto dim = e.dyn_cast<AffineDimExpr>())
541     strides[dim.getPosition()] =
542         strides[dim.getPosition()] + multiplicativeFactor;
543   else
544     offset = offset + e * multiplicativeFactor;
545 }
546 
547 /// Takes a single AffineExpr `e` and populates the `strides` array with the
548 /// strides expressions for each dim position.
549 /// The convention is that the strides for dimensions d0, .. dn appear in
550 /// order to make indexing intuitive into the result.
551 static LogicalResult extractStrides(AffineExpr e,
552                                     AffineExpr multiplicativeFactor,
553                                     MutableArrayRef<AffineExpr> strides,
554                                     AffineExpr &offset) {
555   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
556   if (!bin) {
557     extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
558     return success();
559   }
560 
561   if (bin.getKind() == AffineExprKind::CeilDiv ||
562       bin.getKind() == AffineExprKind::FloorDiv ||
563       bin.getKind() == AffineExprKind::Mod)
564     return failure();
565 
566   if (bin.getKind() == AffineExprKind::Mul) {
567     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
568     if (dim) {
569       strides[dim.getPosition()] =
570           strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
571       return success();
572     }
573     // LHS and RHS may both contain complex expressions of dims. Try one path
574     // and if it fails try the other. This is guaranteed to succeed because
575     // only one path may have a `dim`, otherwise this is not an AffineExpr in
576     // the first place.
577     if (bin.getLHS().isSymbolicOrConstant())
578       return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
579                             strides, offset);
580     return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
581                           strides, offset);
582   }
583 
584   if (bin.getKind() == AffineExprKind::Add) {
585     auto res1 =
586         extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
587     auto res2 =
588         extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
589     return success(succeeded(res1) && succeeded(res2));
590   }
591 
592   llvm_unreachable("unexpected binary operation");
593 }
594 
595 LogicalResult mlir::getStridesAndOffset(MemRefType t,
596                                         SmallVectorImpl<AffineExpr> &strides,
597                                         AffineExpr &offset) {
598   auto affineMaps = t.getAffineMaps();
599   // For now strides are only computed on a single affine map with a single
600   // result (i.e. the closed subset of linearization maps that are compatible
601   // with striding semantics).
602   // TODO: support more forms on a per-need basis.
603   if (affineMaps.size() > 1)
604     return failure();
605   if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
606     return failure();
607 
608   auto zero = getAffineConstantExpr(0, t.getContext());
609   auto one = getAffineConstantExpr(1, t.getContext());
610   offset = zero;
611   strides.assign(t.getRank(), zero);
612 
613   AffineMap m;
614   if (!affineMaps.empty()) {
615     m = affineMaps.front();
616     assert(!m.isIdentity() && "unexpected identity map");
617   }
618 
619   // Canonical case for empty map.
620   if (!m) {
621     // 0-D corner case, offset is already 0.
622     if (t.getRank() == 0)
623       return success();
624     auto stridedExpr =
625         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
626     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
627       return success();
628     assert(false && "unexpected failure: extract strides in canonical layout");
629   }
630 
631   // Non-canonical case requires more work.
632   auto stridedExpr =
633       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
634   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
635     offset = AffineExpr();
636     strides.clear();
637     return failure();
638   }
639 
640   // Simplify results to allow folding to constants and simple checks.
641   unsigned numDims = m.getNumDims();
642   unsigned numSymbols = m.getNumSymbols();
643   offset = simplifyAffineExpr(offset, numDims, numSymbols);
644   for (auto &stride : strides)
645     stride = simplifyAffineExpr(stride, numDims, numSymbols);
646 
647   /// In practice, a strided memref must be internally non-aliasing. Test
648   /// against 0 as a proxy.
649   /// TODO: static cases can have more advanced checks.
650   /// TODO: dynamic cases would require a way to compare symbolic
651   /// expressions and would probably need an affine set context propagated
652   /// everywhere.
653   if (llvm::any_of(strides, [](AffineExpr e) {
654         return e == getAffineConstantExpr(0, e.getContext());
655       })) {
656     offset = AffineExpr();
657     strides.clear();
658     return failure();
659   }
660 
661   return success();
662 }
663 
664 LogicalResult mlir::getStridesAndOffset(MemRefType t,
665                                         SmallVectorImpl<int64_t> &strides,
666                                         int64_t &offset) {
667   AffineExpr offsetExpr;
668   SmallVector<AffineExpr, 4> strideExprs;
669   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
670     return failure();
671   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
672     offset = cst.getValue();
673   else
674     offset = ShapedType::kDynamicStrideOrOffset;
675   for (auto e : strideExprs) {
676     if (auto c = e.dyn_cast<AffineConstantExpr>())
677       strides.push_back(c.getValue());
678     else
679       strides.push_back(ShapedType::kDynamicStrideOrOffset);
680   }
681   return success();
682 }
683 
684 //===----------------------------------------------------------------------===//
// TupleType
686 //===----------------------------------------------------------------------===//
687 
688 /// Get or create a new TupleType with the provided element types. Assumes the
689 /// arguments define a well-formed type.
690 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
691   return Base::get(context, elementTypes);
692 }
693 
694 /// Get or create an empty tuple type.
695 TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
696 
697 /// Return the elements types for this tuple.
698 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
699 
700 /// Accumulate the types contained in this tuple and tuples nested within it.
701 /// Note that this only flattens nested tuples, not any other container type,
702 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
703 /// (i32, tensor<i32>, f32, i64)
704 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
705   for (Type type : getTypes()) {
706     if (auto nestedTuple = type.dyn_cast<TupleType>())
707       nestedTuple.getFlattenedTypes(types);
708     else
709       types.push_back(type);
710   }
711 }
712 
713 /// Return the number of element types.
714 size_t TupleType::size() const { return getImpl()->size(); }
715 
716 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
717                                            int64_t offset,
718                                            MLIRContext *context) {
719   AffineExpr expr;
720   unsigned nSymbols = 0;
721 
722   // AffineExpr for offset.
723   // Static case.
724   if (offset != MemRefType::getDynamicStrideOrOffset()) {
725     auto cst = getAffineConstantExpr(offset, context);
726     expr = cst;
727   } else {
728     // Dynamic case, new symbol for the offset.
729     auto sym = getAffineSymbolExpr(nSymbols++, context);
730     expr = sym;
731   }
732 
733   // AffineExpr for strides.
734   for (auto en : llvm::enumerate(strides)) {
735     auto dim = en.index();
736     auto stride = en.value();
737     assert(stride != 0 && "Invalid stride specification");
738     auto d = getAffineDimExpr(dim, context);
739     AffineExpr mult;
740     // Static case.
741     if (stride != MemRefType::getDynamicStrideOrOffset())
742       mult = getAffineConstantExpr(stride, context);
743     else
744       // Dynamic case, new symbol for each new stride.
745       mult = getAffineSymbolExpr(nSymbols++, context);
746     expr = expr + d * mult;
747   }
748 
749   return AffineMap::get(strides.size(), nSymbols, expr);
750 }
751 
752 /// Return a version of `t` with identity layout if it can be determined
753 /// statically that the layout is the canonical contiguous strided layout.
754 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
755 /// `t` with simplified layout.
756 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
757 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
758   auto affineMaps = t.getAffineMaps();
759   // Already in canonical form.
760   if (affineMaps.empty())
761     return t;
762 
763   // Can't reduce to canonical identity form, return in canonical form.
764   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
765     return t;
766 
767   // If the canonical strided layout for the sizes of `t` is equal to the
768   // simplified layout of `t` we can just return an empty layout. Otherwise,
769   // just simplify the existing layout.
770   AffineExpr expr =
771       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
772   auto m = affineMaps[0];
773   auto simplifiedLayoutExpr =
774       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
775   if (expr != simplifiedLayoutExpr)
776     return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
777         m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
778   return MemRefType::Builder(t).setAffineMaps({});
779 }
780 
781 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
782                                                 ArrayRef<AffineExpr> exprs,
783                                                 MLIRContext *context) {
784   // Size 0 corner case is useful for canonicalizations.
785   if (llvm::is_contained(sizes, 0))
786     return getAffineConstantExpr(0, context);
787 
788   auto maps = AffineMap::inferFromExprList(exprs);
789   assert(!maps.empty() && "Expected one non-empty map");
790   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
791 
792   AffineExpr expr;
793   bool dynamicPoisonBit = false;
794   int64_t runningSize = 1;
795   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
796     int64_t size = std::get<1>(en);
797     // Degenerate case, no size =-> no stride
798     if (size == 0)
799       continue;
800     AffineExpr dimExpr = std::get<0>(en);
801     AffineExpr stride = dynamicPoisonBit
802                             ? getAffineSymbolExpr(nSymbols++, context)
803                             : getAffineConstantExpr(runningSize, context);
804     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
805     if (size > 0)
806       runningSize *= size;
807     else
808       dynamicPoisonBit = true;
809   }
810   return simplifyAffineExpr(expr, numDims, nSymbols);
811 }
812 
813 /// Return a version of `t` with a layout that has all dynamic offset and
814 /// strides. This is used to erase the static layout.
815 MemRefType mlir::eraseStridedLayout(MemRefType t) {
816   auto val = ShapedType::kDynamicStrideOrOffset;
817   return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
818       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
819 }
820 
821 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
822                                                 MLIRContext *context) {
823   SmallVector<AffineExpr, 4> exprs;
824   exprs.reserve(sizes.size());
825   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
826     exprs.push_back(getAffineDimExpr(dim, context));
827   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
828 }
829 
830 /// Return true if the layout for `t` is compatible with strided semantics.
831 bool mlir::isStrided(MemRefType t) {
832   int64_t offset;
833   SmallVector<int64_t, 4> stridesAndOffset;
834   auto res = getStridesAndOffset(t, stridesAndOffset, offset);
835   return succeeded(res);
836 }
837