1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 /// Tablegen Type Definitions
25 //===----------------------------------------------------------------------===//
26 
27 #define GET_TYPEDEF_CLASSES
28 #include "mlir/IR/BuiltinTypes.cpp.inc"
29 
30 //===----------------------------------------------------------------------===//
31 /// ComplexType
32 //===----------------------------------------------------------------------===//
33 
34 ComplexType ComplexType::get(Type elementType) {
35   return Base::get(elementType.getContext(), elementType);
36 }
37 
/// Get or create a ComplexType, emitting diagnostics at `location` (via
/// verifyConstructionInvariants) and returning null on failure.
ComplexType ComplexType::getChecked(Type elementType, Location location) {
  return Base::getChecked(location, elementType);
}
41 
/// Verify the construction of a complex type: the element type must be a
/// builtin integer or float type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}
49 
50 Type ComplexType::getElementType() { return getImpl()->elementType; }
51 
52 //===----------------------------------------------------------------------===//
53 // Integer Type
54 //===----------------------------------------------------------------------===//
55 
// A static constexpr data member needs an out-of-line definition prior to
// C++17, where inline variables make this unnecessary.
constexpr unsigned IntegerType::kMaxWidth;
58 
59 /// Verify the construction of an integer type.
60 LogicalResult
61 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
62                                           SignednessSemantics signedness) {
63   if (width > IntegerType::kMaxWidth) {
64     return emitError(loc) << "integer bitwidth is limited to "
65                           << IntegerType::kMaxWidth << " bits";
66   }
67   return success();
68 }
69 
70 unsigned IntegerType::getWidth() const { return getImpl()->width; }
71 
/// Return the signedness semantics (signless/signed/unsigned) of this type.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
75 
76 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
77   if (!scale)
78     return IntegerType();
79   return IntegerType::get(scale * getWidth(), getSignedness(), getContext());
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Float Type
84 //===----------------------------------------------------------------------===//
85 
86 unsigned FloatType::getWidth() {
87   if (isa<Float16Type, BFloat16Type>())
88     return 16;
89   if (isa<Float32Type>())
90     return 32;
91   if (isa<Float64Type>())
92     return 64;
93   llvm_unreachable("unexpected float type");
94 }
95 
96 /// Returns the floating semantics for the given type.
97 const llvm::fltSemantics &FloatType::getFloatSemantics() {
98   if (isa<BFloat16Type>())
99     return APFloat::BFloat();
100   if (isa<Float16Type>())
101     return APFloat::IEEEhalf();
102   if (isa<Float32Type>())
103     return APFloat::IEEEsingle();
104   if (isa<Float64Type>())
105     return APFloat::IEEEdouble();
106   llvm_unreachable("non-floating point type used");
107 }
108 
109 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
110   if (!scale)
111     return FloatType();
112   MLIRContext *ctx = getContext();
113   if (isF16() || isBF16()) {
114     if (scale == 2)
115       return FloatType::getF32(ctx);
116     if (scale == 4)
117       return FloatType::getF64(ctx);
118   }
119   if (isF32())
120     if (scale == 2)
121       return FloatType::getF64(ctx);
122   return FloatType();
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // FunctionType
127 //===----------------------------------------------------------------------===//
128 
/// Get or create a FunctionType with the given input and result types,
/// uniqued in `context`.
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
                               MLIRContext *context) {
  return Base::get(context, inputs, results);
}
133 
134 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
135 
/// Return the input types of this function type.
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}
139 
140 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
141 
/// Return the result types of this function type.
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
145 
146 /// Helper to call a callback once on each index in the range
147 /// [0, `totalIndices`), *except* for the indices given in `indices`.
148 /// `indices` is allowed to have duplicates and can be in any order.
149 inline void iterateIndicesExcept(unsigned totalIndices,
150                                  ArrayRef<unsigned> indices,
151                                  function_ref<void(unsigned)> callback) {
152   llvm::BitVector skipIndices(totalIndices);
153   for (unsigned i : indices)
154     skipIndices.set(i);
155 
156   for (unsigned i = 0; i < totalIndices; ++i)
157     if (!skipIndices.test(i))
158       callback(i);
159 }
160 
161 /// Returns a new function type without the specified arguments and results.
162 FunctionType
163 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
164                                        ArrayRef<unsigned> resultIndices) {
165   ArrayRef<Type> newInputTypes = getInputs();
166   SmallVector<Type, 4> newInputTypesBuffer;
167   if (!argIndices.empty()) {
168     unsigned originalNumArgs = getNumInputs();
169     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
170       newInputTypesBuffer.emplace_back(getInput(i));
171     });
172     newInputTypes = newInputTypesBuffer;
173   }
174 
175   ArrayRef<Type> newResultTypes = getResults();
176   SmallVector<Type, 4> newResultTypesBuffer;
177   if (!resultIndices.empty()) {
178     unsigned originalNumResults = getNumResults();
179     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
180       newResultTypesBuffer.emplace_back(getResult(i));
181     });
182     newResultTypes = newResultTypesBuffer;
183   }
184 
185   return get(newInputTypes, newResultTypes, getContext());
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // OpaqueType
190 //===----------------------------------------------------------------------===//
191 
/// Get or create an OpaqueType for the given dialect namespace and raw type
/// string, uniqued in `context`.
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
                           MLIRContext *context) {
  return Base::get(context, dialect, typeData);
}
196 
/// Get or create an OpaqueType, emitting diagnostics at `location` (via
/// verifyConstructionInvariants) and returning null on failure.
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
                                  MLIRContext *context, Location location) {
  return Base::getChecked(location, dialect, typeData);
}
201 
/// Returns the dialect namespace of the opaque type.
Identifier OpaqueType::getDialectNamespace() const {
  return getImpl()->dialectNamespace;
}
206 
/// Returns the raw type data of the opaque type.
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
209 
210 /// Verify the construction of an opaque type.
211 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
212                                                        Identifier dialect,
213                                                        StringRef typeData) {
214   if (!Dialect::isValidNamespace(dialect.strref()))
215     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
216   return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // ShapedType
221 //===----------------------------------------------------------------------===//
// Out-of-line definitions for the static constexpr sentinels (required prior
// to C++17, where inline variables make the definitions unnecessary).
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
224 
/// Return the element type of this shaped type.
Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}
228 
/// Return the bitwidth of the element type; asserts (inside
/// getIntOrFloatBitWidth) that the element type is int or float.
unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}
232 
233 int64_t ShapedType::getNumElements() const {
234   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
235   auto shape = getShape();
236   int64_t num = 1;
237   for (auto dim : shape)
238     num *= dim;
239   return num;
240 }
241 
242 int64_t ShapedType::getRank() const { return getShape().size(); }
243 
/// Return true unless this is one of the unranked builtin shaped types.
bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}
247 
/// Return the size of dimension `idx`; may be the dynamic-size sentinel.
int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}
252 
/// Return true if dimension `idx` has a dynamic size.
bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}
257 
/// Given the index of a dynamic dimension, return its position among the
/// dynamic dimensions only (i.e. how many dynamic dims precede it).
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
263 
/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  // A complex value stores two components of the underlying element type,
  // hence the factor of 2.
  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  // Recurse into the shaped element type.
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
287 
288 ArrayRef<int64_t> ShapedType::getShape() const {
289   if (auto vectorType = dyn_cast<VectorType>())
290     return vectorType.getShape();
291   if (auto tensorType = dyn_cast<RankedTensorType>())
292     return tensorType.getShape();
293   return cast<MemRefType>().getShape();
294 }
295 
/// Return the number of dynamic dimensions in the shape.
int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}
299 
/// Return true if this type is ranked and every dimension size is static.
bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}
303 
/// Return true if this type has a static shape that equals `shape` exactly.
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}
307 
308 //===----------------------------------------------------------------------===//
309 // VectorType
310 //===----------------------------------------------------------------------===//
311 
312 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
313   return Base::get(elementType.getContext(), shape, elementType);
314 }
315 
/// Get or create a VectorType, emitting diagnostics at `location` (via
/// verifyConstructionInvariants) and returning null on failure.
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  Location location) {
  return Base::getChecked(location, shape, elementType);
}
320 
321 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
322                                                        ArrayRef<int64_t> shape,
323                                                        Type elementType) {
324   if (shape.empty())
325     return emitError(loc, "vector types must have at least one dimension");
326 
327   if (!isValidElementType(elementType))
328     return emitError(loc, "vector elements must be int or float type");
329 
330   if (any_of(shape, [](int64_t i) { return i <= 0; }))
331     return emitError(loc, "vector types must have positive constant sizes");
332 
333   return success();
334 }
335 
336 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
337 
338 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
339   if (!scale)
340     return VectorType();
341   if (auto et = getElementType().dyn_cast<IntegerType>())
342     if (auto scaledEt = et.scaleElementBitwidth(scale))
343       return VectorType::get(getShape(), scaledEt);
344   if (auto et = getElementType().dyn_cast<FloatType>())
345     if (auto scaledEt = et.scaleElementBitwidth(scale))
346       return VectorType::get(getShape(), scaledEt);
347   return VectorType();
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // TensorType
352 //===----------------------------------------------------------------------===//
353 
// Check if "elementType" can be an element type of a tensor, emitting an
// error at `location` if it cannot. Returns failure if the check failed.
static LogicalResult checkTensorElementType(Location location,
                                            Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type: ") << elementType;
  return success();
}
362 
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}
372 
373 //===----------------------------------------------------------------------===//
374 // RankedTensorType
375 //===----------------------------------------------------------------------===//
376 
377 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
378                                        Type elementType) {
379   return Base::get(elementType.getContext(), shape, elementType);
380 }
381 
/// Get or create a RankedTensorType, emitting diagnostics at `location` (via
/// verifyConstructionInvariants) and returning null on failure.
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                              Type elementType,
                                              Location location) {
  return Base::getChecked(location, shape, elementType);
}
387 
388 LogicalResult RankedTensorType::verifyConstructionInvariants(
389     Location loc, ArrayRef<int64_t> shape, Type elementType) {
390   for (int64_t s : shape) {
391     if (s < -1)
392       return emitError(loc, "invalid tensor dimension size");
393   }
394   return checkTensorElementType(loc, elementType);
395 }
396 
/// Return the shape of this ranked tensor type.
ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}
400 
401 //===----------------------------------------------------------------------===//
402 // UnrankedTensorType
403 //===----------------------------------------------------------------------===//
404 
405 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
406   return Base::get(elementType.getContext(), elementType);
407 }
408 
/// Get or create an UnrankedTensorType, emitting diagnostics at `location`
/// (via verifyConstructionInvariants) and returning null on failure.
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                  Location location) {
  return Base::getChecked(location, elementType);
}
413 
/// Verify the construction of an unranked tensor type: only the element type
/// needs checking, as there is no shape.
LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}
419 
420 //===----------------------------------------------------------------------===//
421 // BaseMemRefType
422 //===----------------------------------------------------------------------===//
423 
/// Return the memory space of this memref type.
unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}
427 
428 //===----------------------------------------------------------------------===//
429 // MemRefType
430 //===----------------------------------------------------------------------===//
431 
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  // No location: getImpl will not emit diagnostics, only return null on
  // malformed input, which the assert below turns into a hard failure.
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}
444 
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace, Location location) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}
457 
/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  // Comma-operator returns below: emit the (optional) diagnostic, then yield
  // a null MemRefType to the caller.
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1` that means dynamic size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map" << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    // The next map in the composition consumes this map's results.
    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
                   memorySpace);
}
511 
512 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
513 
/// Return the affine map composition of this memref type; an empty list is
/// interpreted as the identity layout.
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}
517 
518 //===----------------------------------------------------------------------===//
519 // UnrankedMemRefType
520 //===----------------------------------------------------------------------===//
521 
522 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
523                                            unsigned memorySpace) {
524   return Base::get(elementType.getContext(), elementType, memorySpace);
525 }
526 
/// Get or create an UnrankedMemRefType, emitting diagnostics at `location`
/// (via verifyConstructionInvariants) and returning null on failure.
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, elementType, memorySpace);
}
532 
533 LogicalResult
534 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
535                                                  unsigned memorySpace) {
536   if (!BaseMemRefType::isValidElementType(elementType))
537     return emitError(loc, "invalid memref element type");
538   return success();
539 }
540 
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
// A dimension contributes `multiplicativeFactor` to its stride; any other
// terminal (symbol or constant) contributes to the offset instead.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}
553 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
/// `multiplicativeFactor` is the product of the factors accumulated along the
/// path from the root expression down to `e`. Returns failure when `e`
/// contains an operation (ceildiv/floordiv/mod) that has no strided
/// interpretation.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    // Terminal dim/sym/cst: accumulate directly.
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  // Division and modulo cannot be expressed as strides.
  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      // d_i * expr: the RHS (scaled by the accumulated factor) is added to
      // the stride of d_i.
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Addition distributes: recurse into both sides with the same factor.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
601 
/// Compute the symbolic strides and offset of memref type `t`, returning
/// failure when the layout is not compatible with strided semantics. On
/// success, `strides` has one AffineExpr per dimension and `offset` holds the
/// base offset expression; on failure both outputs are cleared.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    // MemRefType::getImpl drops identity maps from the composition.
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    // Extraction from the canonical layout is expected to always succeed.
    // NOTE(review): in release builds (assert compiled out) this would fall
    // through to the non-canonical path below with a null map `m`.
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}
670 
671 LogicalResult mlir::getStridesAndOffset(MemRefType t,
672                                         SmallVectorImpl<int64_t> &strides,
673                                         int64_t &offset) {
674   AffineExpr offsetExpr;
675   SmallVector<AffineExpr, 4> strideExprs;
676   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
677     return failure();
678   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
679     offset = cst.getValue();
680   else
681     offset = ShapedType::kDynamicStrideOrOffset;
682   for (auto e : strideExprs) {
683     if (auto c = e.dyn_cast<AffineConstantExpr>())
684       strides.push_back(c.getValue());
685     else
686       strides.push_back(ShapedType::kDynamicStrideOrOffset);
687   }
688   return success();
689 }
690 
691 //===----------------------------------------------------------------------===//
692 /// TupleType
693 //===----------------------------------------------------------------------===//
694 
/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
  return Base::get(context, elementTypes);
}
700 
/// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
703 
/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
706 
707 /// Accumulate the types contained in this tuple and tuples nested within it.
708 /// Note that this only flattens nested tuples, not any other container type,
709 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
710 /// (i32, tensor<i32>, f32, i64)
711 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
712   for (Type type : getTypes()) {
713     if (auto nestedTuple = type.dyn_cast<TupleType>())
714       nestedTuple.getFlattenedTypes(types);
715     else
716       types.push_back(type);
717   }
718 }
719 
/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }
722 
/// Build the affine map `offset + sum_i(d_i * stride_i)` from the given
/// strides and offset. Each dynamic stride/offset (the
/// kDynamicStrideOrOffset sentinel) becomes a fresh symbol; static values
/// become constants. Strides must be non-zero.
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}
758 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  // Empty map list == implicit identity layout.
  return MemRefType::Builder(t).setAffineMaps({});
}
787 
/// Build the canonical contiguous (row-major) strided layout expression
/// `sum_i(exprs[i] * stride_i)` for the given sizes. Strides are products of
/// trailing static sizes; once a dynamic size is encountered (walking from
/// the innermost dimension outward), all further strides become fresh
/// symbols.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  // Set once a dynamic size is seen: every stride computed after that point
  // is unknown and is materialized as a new symbol.
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  // Walk dimensions innermost-first so `runningSize` accumulates the product
  // of the sizes to the right of the current dimension.
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
819 
/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  // All strides and the offset use the dynamic sentinel, so the resulting map
  // is fully symbolic.
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}
827 
828 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
829                                                 MLIRContext *context) {
830   SmallVector<AffineExpr, 4> exprs;
831   exprs.reserve(sizes.size());
832   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
833     exprs.push_back(getAffineDimExpr(dim, context));
834   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
835 }
836 
837 /// Return true if the layout for `t` is compatible with strided semantics.
838 bool mlir::isStrided(MemRefType t) {
839   int64_t offset;
840   SmallVector<int64_t, 4> stridesAndOffset;
841   auto res = getStridesAndOffset(t, stridesAndOffset, offset);
842   return succeeded(res);
843 }
844