//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

ComplexType ComplexType::getChecked(Location location, Type elementType) {
  return Base::getChecked(location, elementType);
}
/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// Until C++17 introduced inline variables, a static constexpr member needed
// an out-of-line definition.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
                                          SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError(loc) << "integer bitwidth is limited to "
                          << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
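
// Illustrative usage sketch (`ctx` is assumed to be a caller-provided
// MLIRContext): scaling multiplies the width and preserves signedness, while
// a zero scale yields the null type.
//
//   IntegerType si8 = IntegerType::get(ctx, 8, IntegerType::Signed);
//   IntegerType si32 = si8.scaleElementBitwidth(4); // 8 * 4 = 32 bits, signed.
//   assert(!si8.scaleElementBitwidth(0));           // Null type on zero scale.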

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}
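
// Illustrative usage sketch (`ctx` is assumed to be a valid MLIRContext):
// only the 2x/4x widenings that land on a supported builtin float type
// succeed; anything else yields the null type.
//
//   FloatType f16 = FloatType::getF16(ctx);
//   assert(f16.scaleElementBitwidth(2) == FloatType::getF32(ctx));
//   assert(f16.scaleElementBitwidth(4) == FloatType::getF64(ctx));
//   assert(!FloatType::getF64(ctx).scaleElementBitwidth(2)); // No f128.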

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
                               TypeRange results) {
  return Base::get(context, inputs, results);
}

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}
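
// Illustrative usage sketch: with five total indices and {1, 3} excluded,
// the callback fires for 0, 2, and 4, in increasing order.
//
//   SmallVector<unsigned, 3> kept;
//   iterateIndicesExcept(5, {1, 3}, [&](unsigned i) { kept.push_back(i); });
//   // kept == {0, 2, 4}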

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}
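
// Illustrative usage sketch: dropping argument 1 and result 0 from a
// function type `(i32, f32, f64) -> (i1, i8)` produces `(i32, f64) -> (i8)`.
// Duplicate indices are tolerated because `iterateIndicesExcept` sets one
// bit per index.
//
//   FunctionType fnTy = ...; // (i32, f32, f64) -> (i1, i8)
//   FunctionType newTy = fnTy.getWithoutArgsAndResults(
//       /*argIndices=*/{1}, /*resultIndices=*/{0});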

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
                           StringRef typeData) {
  return Base::get(context, dialect, typeData);
}

OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
                                  StringRef typeData) {
  return Base::getChecked(location, dialect, typeData);
}

/// Returns the dialect namespace of the opaque type.
Identifier OpaqueType::getDialectNamespace() const {
  return getImpl()->dialectNamespace;
}

/// Returns the raw type data of the opaque type.
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
                                                       Identifier dialect,
                                                       StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape)
    num *= dim;
  return num;
}

int64_t ShapedType::getRank() const { return getShape().size(); }

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
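
// Illustrative sketch: for a shaped type with shape {2, -1, 4, -1} (printed
// 2x?x4x?), the dynamic dimension at overall position 3 is preceded by one
// other dynamic dimension, so its dynamic index is 1.
//
//   assert(shapedTy.getDynamicDimIndex(3) == 1); // `shapedTy` as above.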

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
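
// Illustrative worked example (`ctx` is assumed to be a valid MLIRContext):
// the recursion handles tensors whose elements are themselves shaped. A
// tensor<2x4xvector<4xf32>> stores 2 * 4 vectors of 4 * 32 bits each, i.e.
// 8 * 128 = 1024 bits.
//
//   auto vecTy = VectorType::get({4}, FloatType::getF32(ctx));
//   auto tensorTy = RankedTensorType::get({2, 4}, vecTy);
//   assert(tensorTy.getSizeInBits() == 1024);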

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
                                  Type elementType) {
  return Base::getChecked(location, shape, elementType);
}

LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                       ArrayRef<int64_t> shape,
                                                       Type elementType) {
  if (shape.empty())
    return emitError(loc, "vector types must have at least one dimension");

  if (!isValidElementType(elementType))
    return emitError(loc, "vector elements must be int or float type");

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError(loc, "vector types must have positive constant sizes");

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor. Emits errors at
// the given location and returns failure if it cannot.
static LogicalResult checkTensorElementType(Location location,
                                            Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type: ") << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

RankedTensorType RankedTensorType::getChecked(Location location,
                                              ArrayRef<int64_t> shape,
                                              Type elementType) {
  return Base::getChecked(location, shape, elementType);
}

LogicalResult RankedTensorType::verifyConstructionInvariants(
    Location loc, ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError(loc, "invalid tensor dimension size");
  }
  return checkTensorElementType(loc, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

UnrankedTensorType UnrankedTensorType::getChecked(Location location,
                                                  Type elementType) {
  return Base::getChecked(location, elementType);
}

LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
                                  Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}

/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  if (!BaseMemRefType::isValidElementType(elementType))
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1`, which encodes a dynamic
    // size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map " << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
                   memorySpace);
}
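
// Illustrative usage sketch (`ctx` is assumed to be a valid MLIRContext): a
// rank-2 memref whose single layout map takes 2 dims satisfies the chaining
// check above; an identity map in the same position would simply be dropped.
//
//   auto layout = makeStridedLinearLayoutMap({16, 1}, /*offset=*/0, ctx);
//   auto memrefTy = MemRefType::get({4, 16}, FloatType::getF32(ctx),
//                                   {layout}, /*memorySpace=*/0);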

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
                                                  Type elementType,
                                                  unsigned memorySpace) {
  return Base::getChecked(location, elementType, memorySpace);
}

LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(loc, "invalid memref element type");
  return success();
}
// Fallback cases for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, which makes indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
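
// Illustrative worked example: for e = d0 * 20 + d1 * 5 + d2 + 7 and a
// multiplicative factor of 1, the Add case splits the terms, the Mul case
// records strides[0] = 20 and strides[1] = 5, the bare d2 contributes
// strides[2] = 1 via `extractStridesFromTerm`, and the constant 7
// accumulates into `offset`.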

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    llvm_unreachable(
        "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}
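
// Illustrative usage sketch: a memref with an empty map composition gets the
// canonical row-major strides.
//
//   // memrefTy is memref<42x16xf32> (no layout map).
//   SmallVector<int64_t, 2> strides;
//   int64_t offset;
//   if (succeeded(getStridesAndOffset(memrefTy, strides, offset))) {
//     // strides == {16, 1}, offset == 0.
//   }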

//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
  return Base::get(context, elementTypes);
}

/// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}
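
// Illustrative usage sketch (`ctx` is assumed to be a valid MLIRContext):
// static strides and offsets fold into constants; dynamic ones introduce
// symbols, with the offset symbol allocated first.
//
//   // Fully static: (d0, d1) -> (d0 * 16 + d1)
//   auto staticMap = makeStridedLinearLayoutMap({16, 1}, /*offset=*/0, ctx);
//   // Dynamic offset: roughly (d0, d1)[s0] -> (s0 + d0 * 16 + d1)
//   auto dynMap = makeStridedLinearLayoutMap(
//       {16, 1}, MemRefType::getDynamicStrideOrOffset(), ctx);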

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to the canonical identity form; return `t` unchanged.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}
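
// Illustrative sketch: a layout that spells out the canonical row-major
// strides is recognized and dropped, e.g.
// memref<4x5xf32, (d0, d1) -> (d0 * 5 + d1)> canonicalizes to
// memref<4x5xf32>, whereas (d0, d1) -> (d0 * 6 + d1) is merely simplified
// and kept.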

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
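
// Illustrative worked example: for sizes {3, 4, 5}, the running size builds
// strides right-to-left (1, then 5, then 20), so the result simplifies to
// d0 * 20 + d1 * 5 + d2. A dynamic (-1) size sets the poison bit, turning
// every stride to its left into a fresh symbol.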

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> stridesAndOffset;
  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
  return succeeded(res);
}