1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include <limits>
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
24 /// Tablegen Type Definitions
25 //===----------------------------------------------------------------------===//
26 
27 #define GET_TYPEDEF_CLASSES
28 #include "mlir/IR/BuiltinTypes.cpp.inc"
29 
30 //===----------------------------------------------------------------------===//
31 /// ComplexType
32 //===----------------------------------------------------------------------===//
33 
34 /// Verify the construction of an integer type.
35 LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
36                                                         Type elementType) {
37   if (!elementType.isIntOrFloat())
38     return emitError(loc, "invalid element type for complex");
39   return success();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Integer Type
44 //===----------------------------------------------------------------------===//
45 
// Out-of-line definition of the static constexpr data member. Before C++17
// (which makes static constexpr data members implicitly inline), any ODR-use
// of kMaxWidth requires exactly one such definition in some translation unit.
constexpr unsigned IntegerType::kMaxWidth;
48 
49 /// Verify the construction of an integer type.
50 LogicalResult
51 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
52                                           SignednessSemantics signedness) {
53   if (width > IntegerType::kMaxWidth) {
54     return emitError(loc) << "integer bitwidth is limited to "
55                           << IntegerType::kMaxWidth << " bits";
56   }
57   return success();
58 }
59 
60 unsigned IntegerType::getWidth() const { return getImpl()->width; }
61 
62 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
63   return getImpl()->signedness;
64 }
65 
66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
67   if (!scale)
68     return IntegerType();
69   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // Float Type
74 //===----------------------------------------------------------------------===//
75 
76 unsigned FloatType::getWidth() {
77   if (isa<Float16Type, BFloat16Type>())
78     return 16;
79   if (isa<Float32Type>())
80     return 32;
81   if (isa<Float64Type>())
82     return 64;
83   if (isa<Float80Type>())
84     return 80;
85   if (isa<Float128Type>())
86     return 128;
87   llvm_unreachable("unexpected float type");
88 }
89 
90 /// Returns the floating semantics for the given type.
91 const llvm::fltSemantics &FloatType::getFloatSemantics() {
92   if (isa<BFloat16Type>())
93     return APFloat::BFloat();
94   if (isa<Float16Type>())
95     return APFloat::IEEEhalf();
96   if (isa<Float32Type>())
97     return APFloat::IEEEsingle();
98   if (isa<Float64Type>())
99     return APFloat::IEEEdouble();
100   if (isa<Float80Type>())
101     return APFloat::x87DoubleExtended();
102   if (isa<Float128Type>())
103     return APFloat::IEEEquad();
104   llvm_unreachable("non-floating point type used");
105 }
106 
107 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
108   if (!scale)
109     return FloatType();
110   MLIRContext *ctx = getContext();
111   if (isF16() || isBF16()) {
112     if (scale == 2)
113       return FloatType::getF32(ctx);
114     if (scale == 4)
115       return FloatType::getF64(ctx);
116   }
117   if (isF32())
118     if (scale == 2)
119       return FloatType::getF64(ctx);
120   return FloatType();
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // FunctionType
125 //===----------------------------------------------------------------------===//
126 
127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
128 
129 ArrayRef<Type> FunctionType::getInputs() const {
130   return getImpl()->getInputs();
131 }
132 
133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
134 
135 ArrayRef<Type> FunctionType::getResults() const {
136   return getImpl()->getResults();
137 }
138 
139 /// Helper to call a callback once on each index in the range
140 /// [0, `totalIndices`), *except* for the indices given in `indices`.
141 /// `indices` is allowed to have duplicates and can be in any order.
142 inline void iterateIndicesExcept(unsigned totalIndices,
143                                  ArrayRef<unsigned> indices,
144                                  function_ref<void(unsigned)> callback) {
145   llvm::BitVector skipIndices(totalIndices);
146   for (unsigned i : indices)
147     skipIndices.set(i);
148 
149   for (unsigned i = 0; i < totalIndices; ++i)
150     if (!skipIndices.test(i))
151       callback(i);
152 }
153 
154 /// Returns a new function type without the specified arguments and results.
155 FunctionType
156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
157                                        ArrayRef<unsigned> resultIndices) {
158   ArrayRef<Type> newInputTypes = getInputs();
159   SmallVector<Type, 4> newInputTypesBuffer;
160   if (!argIndices.empty()) {
161     unsigned originalNumArgs = getNumInputs();
162     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
163       newInputTypesBuffer.emplace_back(getInput(i));
164     });
165     newInputTypes = newInputTypesBuffer;
166   }
167 
168   ArrayRef<Type> newResultTypes = getResults();
169   SmallVector<Type, 4> newResultTypesBuffer;
170   if (!resultIndices.empty()) {
171     unsigned originalNumResults = getNumResults();
172     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
173       newResultTypesBuffer.emplace_back(getResult(i));
174     });
175     newResultTypes = newResultTypesBuffer;
176   }
177 
178   return get(getContext(), newInputTypes, newResultTypes);
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // OpaqueType
183 //===----------------------------------------------------------------------===//
184 
185 /// Verify the construction of an opaque type.
186 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
187                                                        Identifier dialect,
188                                                        StringRef typeData) {
189   if (!Dialect::isValidNamespace(dialect.strref()))
190     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
191   return success();
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // ShapedType
196 //===----------------------------------------------------------------------===//
// Out-of-line definitions of the static constexpr sentinel members. Before
// C++17 (inline variables), ODR-uses of these constants need a definition.
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
199 
200 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
201   if (auto other = dyn_cast<MemRefType>()) {
202     MemRefType::Builder b(other);
203     b.setShape(shape);
204     b.setElementType(elementType);
205     return b;
206   }
207 
208   if (auto other = dyn_cast<UnrankedMemRefType>()) {
209     MemRefType::Builder b(shape, elementType);
210     b.setMemorySpace(other.getMemorySpace());
211     return b;
212   }
213 
214   if (isa<TensorType>())
215     return RankedTensorType::get(shape, elementType);
216 
217   if (isa<VectorType>())
218     return VectorType::get(shape, elementType);
219 
220   llvm_unreachable("Unhandled ShapedType clone case");
221 }
222 
223 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
224   if (auto other = dyn_cast<MemRefType>()) {
225     MemRefType::Builder b(other);
226     b.setShape(shape);
227     return b;
228   }
229 
230   if (auto other = dyn_cast<UnrankedMemRefType>()) {
231     MemRefType::Builder b(shape, other.getElementType());
232     b.setShape(shape);
233     b.setMemorySpace(other.getMemorySpace());
234     return b;
235   }
236 
237   if (isa<TensorType>())
238     return RankedTensorType::get(shape, getElementType());
239 
240   if (isa<VectorType>())
241     return VectorType::get(shape, getElementType());
242 
243   llvm_unreachable("Unhandled ShapedType clone case");
244 }
245 
246 ShapedType ShapedType::clone(Type elementType) {
247   if (auto other = dyn_cast<MemRefType>()) {
248     MemRefType::Builder b(other);
249     b.setElementType(elementType);
250     return b;
251   }
252 
253   if (auto other = dyn_cast<UnrankedMemRefType>()) {
254     return UnrankedMemRefType::get(elementType, other.getMemorySpace());
255   }
256 
257   if (isa<TensorType>()) {
258     if (hasRank())
259       return RankedTensorType::get(getShape(), elementType);
260     return UnrankedTensorType::get(elementType);
261   }
262 
263   if (isa<VectorType>())
264     return VectorType::get(getShape(), elementType);
265 
266   llvm_unreachable("Unhandled ShapedType clone hit");
267 }
268 
269 Type ShapedType::getElementType() const {
270   return static_cast<ImplType *>(impl)->elementType;
271 }
272 
273 unsigned ShapedType::getElementTypeBitWidth() const {
274   return getElementType().getIntOrFloatBitWidth();
275 }
276 
277 int64_t ShapedType::getNumElements() const {
278   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
279   auto shape = getShape();
280   int64_t num = 1;
281   for (auto dim : shape) {
282     num *= dim;
283     assert(num >= 0 && "integer overflow in element count computation");
284   }
285   return num;
286 }
287 
288 int64_t ShapedType::getRank() const {
289   assert(hasRank() && "cannot query rank of unranked shaped type");
290   return getShape().size();
291 }
292 
293 bool ShapedType::hasRank() const {
294   return !isa<UnrankedMemRefType, UnrankedTensorType>();
295 }
296 
297 int64_t ShapedType::getDimSize(unsigned idx) const {
298   assert(idx < getRank() && "invalid index for shaped type");
299   return getShape()[idx];
300 }
301 
302 bool ShapedType::isDynamicDim(unsigned idx) const {
303   assert(idx < getRank() && "invalid index for shaped type");
304   return isDynamic(getShape()[idx]);
305 }
306 
307 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
308   assert(index < getRank() && "invalid index");
309   assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
310   return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
311 }
312 
313 /// Get the number of bits require to store a value of the given shaped type.
314 /// Compute the value recursively since tensors are allowed to have vectors as
315 /// elements.
316 int64_t ShapedType::getSizeInBits() const {
317   assert(hasStaticShape() &&
318          "cannot get the bit size of an aggregate with a dynamic shape");
319 
320   auto elementType = getElementType();
321   if (elementType.isIntOrFloat())
322     return elementType.getIntOrFloatBitWidth() * getNumElements();
323 
324   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
325     elementType = complexType.getElementType();
326     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
327   }
328 
329   // Tensors can have vectors and other tensors as elements, other shaped types
330   // cannot.
331   assert(isa<TensorType>() && "unsupported element type");
332   assert((elementType.isa<VectorType, TensorType>()) &&
333          "unsupported tensor element type");
334   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
335 }
336 
337 ArrayRef<int64_t> ShapedType::getShape() const {
338   if (auto vectorType = dyn_cast<VectorType>())
339     return vectorType.getShape();
340   if (auto tensorType = dyn_cast<RankedTensorType>())
341     return tensorType.getShape();
342   return cast<MemRefType>().getShape();
343 }
344 
345 int64_t ShapedType::getNumDynamicDims() const {
346   return llvm::count_if(getShape(), isDynamic);
347 }
348 
349 bool ShapedType::hasStaticShape() const {
350   return hasRank() && llvm::none_of(getShape(), isDynamic);
351 }
352 
353 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
354   return hasStaticShape() && getShape() == shape;
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // VectorType
359 //===----------------------------------------------------------------------===//
360 
361 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
362   return Base::get(elementType.getContext(), shape, elementType);
363 }
364 
365 VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
366                                   Type elementType) {
367   return Base::getChecked(location, shape, elementType);
368 }
369 
370 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
371                                                        ArrayRef<int64_t> shape,
372                                                        Type elementType) {
373   if (shape.empty())
374     return emitError(loc, "vector types must have at least one dimension");
375 
376   if (!isValidElementType(elementType))
377     return emitError(loc, "vector elements must be int or float type");
378 
379   if (any_of(shape, [](int64_t i) { return i <= 0; }))
380     return emitError(loc, "vector types must have positive constant sizes");
381 
382   return success();
383 }
384 
385 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
386 
387 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
388   if (!scale)
389     return VectorType();
390   if (auto et = getElementType().dyn_cast<IntegerType>())
391     if (auto scaledEt = et.scaleElementBitwidth(scale))
392       return VectorType::get(getShape(), scaledEt);
393   if (auto et = getElementType().dyn_cast<FloatType>())
394     if (auto scaledEt = et.scaleElementBitwidth(scale))
395       return VectorType::get(getShape(), scaledEt);
396   return VectorType();
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // TensorType
401 //===----------------------------------------------------------------------===//
402 
403 // Check if "elementType" can be an element type of a tensor. Emit errors if
404 // location is not nullptr.  Returns failure if check failed.
405 static LogicalResult checkTensorElementType(Location location,
406                                             Type elementType) {
407   if (!TensorType::isValidElementType(elementType))
408     return emitError(location, "invalid tensor element type: ") << elementType;
409   return success();
410 }
411 
412 /// Return true if the specified element type is ok in a tensor.
413 bool TensorType::isValidElementType(Type type) {
414   // Note: Non standard/builtin types are allowed to exist within tensor
415   // types. Dialects are expected to verify that tensor types have a valid
416   // element type within that dialect.
417   return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
418                   IndexType>() ||
419          !type.getDialect().getNamespace().empty();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // RankedTensorType
424 //===----------------------------------------------------------------------===//
425 
426 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
427                                        Type elementType) {
428   return Base::get(elementType.getContext(), shape, elementType);
429 }
430 
431 RankedTensorType RankedTensorType::getChecked(Location location,
432                                               ArrayRef<int64_t> shape,
433                                               Type elementType) {
434   return Base::getChecked(location, shape, elementType);
435 }
436 
437 LogicalResult RankedTensorType::verifyConstructionInvariants(
438     Location loc, ArrayRef<int64_t> shape, Type elementType) {
439   for (int64_t s : shape) {
440     if (s < -1)
441       return emitError(loc, "invalid tensor dimension size");
442   }
443   return checkTensorElementType(loc, elementType);
444 }
445 
446 ArrayRef<int64_t> RankedTensorType::getShape() const {
447   return getImpl()->getShape();
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // UnrankedTensorType
452 //===----------------------------------------------------------------------===//
453 
454 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
455   return Base::get(elementType.getContext(), elementType);
456 }
457 
458 UnrankedTensorType UnrankedTensorType::getChecked(Location location,
459                                                   Type elementType) {
460   return Base::getChecked(location, elementType);
461 }
462 
463 LogicalResult
464 UnrankedTensorType::verifyConstructionInvariants(Location loc,
465                                                  Type elementType) {
466   return checkTensorElementType(loc, elementType);
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // BaseMemRefType
471 //===----------------------------------------------------------------------===//
472 
473 unsigned BaseMemRefType::getMemorySpace() const {
474   return static_cast<ImplType *>(impl)->memorySpace;
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // MemRefType
479 //===----------------------------------------------------------------------===//
480 
481 /// Get or create a new MemRefType based on shape, element type, affine
482 /// map composition, and memory space.  Assumes the arguments define a
483 /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
484 /// construction failures.
485 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
486                            ArrayRef<AffineMap> affineMapComposition,
487                            unsigned memorySpace) {
488   auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
489                         /*location=*/llvm::None);
490   assert(result && "Failed to construct instance of MemRefType.");
491   return result;
492 }
493 
494 /// Get or create a new MemRefType based on shape, element type, affine
495 /// map composition, and memory space declared at the given location.
496 /// If the location is unknown, the last argument should be an instance of
497 /// UnknownLoc.  If the MemRefType defined by the arguments would be
498 /// ill-formed, emits errors (to the handler registered with the context or to
499 /// the error stream) and returns nullptr.
500 MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
501                                   Type elementType,
502                                   ArrayRef<AffineMap> affineMapComposition,
503                                   unsigned memorySpace) {
504   return getImpl(shape, elementType, affineMapComposition, memorySpace,
505                  location);
506 }
507 
508 /// Get or create a new MemRefType defined by the arguments.  If the resulting
509 /// type would be ill-formed, return nullptr.  If the location is provided,
510 /// emit detailed error messages.  To emit errors when the location is unknown,
511 /// pass in an instance of UnknownLoc.
512 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
513                                ArrayRef<AffineMap> affineMapComposition,
514                                unsigned memorySpace,
515                                Optional<Location> location) {
516   auto *context = elementType.getContext();
517 
518   if (!BaseMemRefType::isValidElementType(elementType))
519     return (void)emitOptionalError(location, "invalid memref element type"),
520            MemRefType();
521 
522   for (int64_t s : shape) {
523     // Negative sizes are not allowed except for `-1` that means dynamic size.
524     if (s < -1)
525       return (void)emitOptionalError(location, "invalid memref size"),
526              MemRefType();
527   }
528 
529   // Check that the structure of the composition is valid, i.e. that each
530   // subsequent affine map has as many inputs as the previous map has results.
531   // Take the dimensionality of the MemRef for the first map.
532   auto dim = shape.size();
533   unsigned i = 0;
534   for (const auto &affineMap : affineMapComposition) {
535     if (affineMap.getNumDims() != dim) {
536       if (location)
537         emitError(*location)
538             << "memref affine map dimension mismatch between "
539             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
540             << " and affine map" << i + 1 << ": " << dim
541             << " != " << affineMap.getNumDims();
542       return nullptr;
543     }
544 
545     dim = affineMap.getNumResults();
546     ++i;
547   }
548 
549   // Drop identity maps from the composition.
550   // This may lead to the composition becoming empty, which is interpreted as an
551   // implicit identity.
552   SmallVector<AffineMap, 2> cleanedAffineMapComposition;
553   for (const auto &map : affineMapComposition) {
554     if (map.isIdentity())
555       continue;
556     cleanedAffineMapComposition.push_back(map);
557   }
558 
559   return Base::get(context, shape, elementType, cleanedAffineMapComposition,
560                    memorySpace);
561 }
562 
563 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
564 
565 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
566   return getImpl()->getAffineMaps();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // UnrankedMemRefType
571 //===----------------------------------------------------------------------===//
572 
573 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
574                                            unsigned memorySpace) {
575   return Base::get(elementType.getContext(), elementType, memorySpace);
576 }
577 
578 UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
579                                                   Type elementType,
580                                                   unsigned memorySpace) {
581   return Base::getChecked(location, elementType, memorySpace);
582 }
583 
584 LogicalResult
585 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
586                                                  unsigned memorySpace) {
587   if (!BaseMemRefType::isValidElementType(elementType))
588     return emitError(loc, "invalid memref element type");
589   return success();
590 }
591 
592 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
593 // i.e. single term). Accumulate the AffineExpr into the existing one.
594 static void extractStridesFromTerm(AffineExpr e,
595                                    AffineExpr multiplicativeFactor,
596                                    MutableArrayRef<AffineExpr> strides,
597                                    AffineExpr &offset) {
598   if (auto dim = e.dyn_cast<AffineDimExpr>())
599     strides[dim.getPosition()] =
600         strides[dim.getPosition()] + multiplicativeFactor;
601   else
602     offset = offset + e * multiplicativeFactor;
603 }
604 
605 /// Takes a single AffineExpr `e` and populates the `strides` array with the
606 /// strides expressions for each dim position.
607 /// The convention is that the strides for dimensions d0, .. dn appear in
608 /// order to make indexing intuitive into the result.
609 static LogicalResult extractStrides(AffineExpr e,
610                                     AffineExpr multiplicativeFactor,
611                                     MutableArrayRef<AffineExpr> strides,
612                                     AffineExpr &offset) {
613   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
614   if (!bin) {
615     extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
616     return success();
617   }
618 
619   if (bin.getKind() == AffineExprKind::CeilDiv ||
620       bin.getKind() == AffineExprKind::FloorDiv ||
621       bin.getKind() == AffineExprKind::Mod)
622     return failure();
623 
624   if (bin.getKind() == AffineExprKind::Mul) {
625     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
626     if (dim) {
627       strides[dim.getPosition()] =
628           strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
629       return success();
630     }
631     // LHS and RHS may both contain complex expressions of dims. Try one path
632     // and if it fails try the other. This is guaranteed to succeed because
633     // only one path may have a `dim`, otherwise this is not an AffineExpr in
634     // the first place.
635     if (bin.getLHS().isSymbolicOrConstant())
636       return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
637                             strides, offset);
638     return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
639                           strides, offset);
640   }
641 
642   if (bin.getKind() == AffineExprKind::Add) {
643     auto res1 =
644         extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
645     auto res2 =
646         extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
647     return success(succeeded(res1) && succeeded(res2));
648   }
649 
650   llvm_unreachable("unexpected binary operation");
651 }
652 
653 LogicalResult mlir::getStridesAndOffset(MemRefType t,
654                                         SmallVectorImpl<AffineExpr> &strides,
655                                         AffineExpr &offset) {
656   auto affineMaps = t.getAffineMaps();
657   // For now strides are only computed on a single affine map with a single
658   // result (i.e. the closed subset of linearization maps that are compatible
659   // with striding semantics).
660   // TODO: support more forms on a per-need basis.
661   if (affineMaps.size() > 1)
662     return failure();
663   if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
664     return failure();
665 
666   auto zero = getAffineConstantExpr(0, t.getContext());
667   auto one = getAffineConstantExpr(1, t.getContext());
668   offset = zero;
669   strides.assign(t.getRank(), zero);
670 
671   AffineMap m;
672   if (!affineMaps.empty()) {
673     m = affineMaps.front();
674     assert(!m.isIdentity() && "unexpected identity map");
675   }
676 
677   // Canonical case for empty map.
678   if (!m) {
679     // 0-D corner case, offset is already 0.
680     if (t.getRank() == 0)
681       return success();
682     auto stridedExpr =
683         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
684     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
685       return success();
686     assert(false && "unexpected failure: extract strides in canonical layout");
687   }
688 
689   // Non-canonical case requires more work.
690   auto stridedExpr =
691       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
692   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
693     offset = AffineExpr();
694     strides.clear();
695     return failure();
696   }
697 
698   // Simplify results to allow folding to constants and simple checks.
699   unsigned numDims = m.getNumDims();
700   unsigned numSymbols = m.getNumSymbols();
701   offset = simplifyAffineExpr(offset, numDims, numSymbols);
702   for (auto &stride : strides)
703     stride = simplifyAffineExpr(stride, numDims, numSymbols);
704 
705   /// In practice, a strided memref must be internally non-aliasing. Test
706   /// against 0 as a proxy.
707   /// TODO: static cases can have more advanced checks.
708   /// TODO: dynamic cases would require a way to compare symbolic
709   /// expressions and would probably need an affine set context propagated
710   /// everywhere.
711   if (llvm::any_of(strides, [](AffineExpr e) {
712         return e == getAffineConstantExpr(0, e.getContext());
713       })) {
714     offset = AffineExpr();
715     strides.clear();
716     return failure();
717   }
718 
719   return success();
720 }
721 
722 LogicalResult mlir::getStridesAndOffset(MemRefType t,
723                                         SmallVectorImpl<int64_t> &strides,
724                                         int64_t &offset) {
725   AffineExpr offsetExpr;
726   SmallVector<AffineExpr, 4> strideExprs;
727   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
728     return failure();
729   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
730     offset = cst.getValue();
731   else
732     offset = ShapedType::kDynamicStrideOrOffset;
733   for (auto e : strideExprs) {
734     if (auto c = e.dyn_cast<AffineConstantExpr>())
735       strides.push_back(c.getValue());
736     else
737       strides.push_back(ShapedType::kDynamicStrideOrOffset);
738   }
739   return success();
740 }
741 
742 //===----------------------------------------------------------------------===//
743 /// TupleType
744 //===----------------------------------------------------------------------===//
745 
746 /// Return the elements types for this tuple.
747 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
748 
749 /// Accumulate the types contained in this tuple and tuples nested within it.
750 /// Note that this only flattens nested tuples, not any other container type,
751 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
752 /// (i32, tensor<i32>, f32, i64)
753 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
754   for (Type type : getTypes()) {
755     if (auto nestedTuple = type.dyn_cast<TupleType>())
756       nestedTuple.getFlattenedTypes(types);
757     else
758       types.push_back(type);
759   }
760 }
761 
762 /// Return the number of element types.
763 size_t TupleType::size() const { return getImpl()->size(); }
764 
765 //===----------------------------------------------------------------------===//
766 // Type Utilities
767 //===----------------------------------------------------------------------===//
768 
769 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
770                                            int64_t offset,
771                                            MLIRContext *context) {
772   AffineExpr expr;
773   unsigned nSymbols = 0;
774 
775   // AffineExpr for offset.
776   // Static case.
777   if (offset != MemRefType::getDynamicStrideOrOffset()) {
778     auto cst = getAffineConstantExpr(offset, context);
779     expr = cst;
780   } else {
781     // Dynamic case, new symbol for the offset.
782     auto sym = getAffineSymbolExpr(nSymbols++, context);
783     expr = sym;
784   }
785 
786   // AffineExpr for strides.
787   for (auto en : llvm::enumerate(strides)) {
788     auto dim = en.index();
789     auto stride = en.value();
790     assert(stride != 0 && "Invalid stride specification");
791     auto d = getAffineDimExpr(dim, context);
792     AffineExpr mult;
793     // Static case.
794     if (stride != MemRefType::getDynamicStrideOrOffset())
795       mult = getAffineConstantExpr(stride, context);
796     else
797       // Dynamic case, new symbol for each new stride.
798       mult = getAffineSymbolExpr(nSymbols++, context);
799     expr = expr + d * mult;
800   }
801 
802   return AffineMap::get(strides.size(), nSymbols, expr);
803 }
804 
805 /// Return a version of `t` with identity layout if it can be determined
806 /// statically that the layout is the canonical contiguous strided layout.
807 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
808 /// `t` with simplified layout.
809 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
810 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
811   auto affineMaps = t.getAffineMaps();
812   // Already in canonical form.
813   if (affineMaps.empty())
814     return t;
815 
816   // Can't reduce to canonical identity form, return in canonical form.
817   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
818     return t;
819 
820   // Corner-case for 0-D affine maps.
821   auto m = affineMaps[0];
822   if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
823     if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
824       if (cst.getValue() == 0)
825         return MemRefType::Builder(t).setAffineMaps({});
826     return t;
827   }
828 
829   // 0-D corner case for empty shape that still have an affine map. Example:
830   // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
831   // offset needs to remain, just return t.
832   if (t.getShape().empty())
833     return t;
834 
835   // If the canonical strided layout for the sizes of `t` is equal to the
836   // simplified layout of `t` we can just return an empty layout. Otherwise,
837   // just simplify the existing layout.
838   AffineExpr expr =
839       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
840   auto simplifiedLayoutExpr =
841       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
842   if (expr != simplifiedLayoutExpr)
843     return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
844         m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
845   return MemRefType::Builder(t).setAffineMaps({});
846 }
847 
848 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
849                                                 ArrayRef<AffineExpr> exprs,
850                                                 MLIRContext *context) {
851   assert(!sizes.empty() && !exprs.empty() &&
852          "expected non-empty sizes and exprs");
853 
854   // Size 0 corner case is useful for canonicalizations.
855   if (llvm::is_contained(sizes, 0))
856     return getAffineConstantExpr(0, context);
857 
858   auto maps = AffineMap::inferFromExprList(exprs);
859   assert(!maps.empty() && "Expected one non-empty map");
860   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
861 
862   AffineExpr expr;
863   bool dynamicPoisonBit = false;
864   int64_t runningSize = 1;
865   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
866     int64_t size = std::get<1>(en);
867     // Degenerate case, no size =-> no stride
868     if (size == 0)
869       continue;
870     AffineExpr dimExpr = std::get<0>(en);
871     AffineExpr stride = dynamicPoisonBit
872                             ? getAffineSymbolExpr(nSymbols++, context)
873                             : getAffineConstantExpr(runningSize, context);
874     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
875     if (size > 0) {
876       runningSize *= size;
877       assert(runningSize > 0 && "integer overflow in size computation");
878     } else {
879       dynamicPoisonBit = true;
880     }
881   }
882   return simplifyAffineExpr(expr, numDims, nSymbols);
883 }
884 
885 /// Return a version of `t` with a layout that has all dynamic offset and
886 /// strides. This is used to erase the static layout.
887 MemRefType mlir::eraseStridedLayout(MemRefType t) {
888   auto val = ShapedType::kDynamicStrideOrOffset;
889   return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
890       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
891 }
892 
893 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
894                                                 MLIRContext *context) {
895   SmallVector<AffineExpr, 4> exprs;
896   exprs.reserve(sizes.size());
897   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
898     exprs.push_back(getAffineDimExpr(dim, context));
899   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
900 }
901 
902 /// Return true if the layout for `t` is compatible with strided semantics.
903 bool mlir::isStrided(MemRefType t) {
904   int64_t offset;
905   SmallVector<int64_t, 4> strides;
906   auto res = getStridesAndOffset(t, strides, offset);
907   return succeeded(res);
908 }
909 
910 /// Return the layout map in strided linear layout AffineMap form.
911 /// Return null if the layout is not compatible with a strided layout.
912 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
913   int64_t offset;
914   SmallVector<int64_t, 4> strides;
915   if (failed(getStridesAndOffset(t, strides, offset)))
916     return AffineMap();
917   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
918 }
919