//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// Until C++17 (and inline variables), a static constexpr member needs an
// out-of-class definition.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

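// A minimal usage sketch (our own illustration, not part of the original
// source; assumes an MLIRContext *ctx in scope). Scaling multiplies the
// bitwidth and preserves signedness:
//
//   IntegerType i8Ty = IntegerType::get(ctx, 8);
//   IntegerType i32Ty = i8Ty.scaleElementBitwidth(4); // i32
//   IntegerType none = i8Ty.scaleElementBitwidth(0);  // null IntegerType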
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating-point semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

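// Illustrative behavior (values ours, assuming an MLIRContext *ctx): f16 and
// bf16 scale by 2 to f32 and by 4 to f64; f32 scales by 2 to f64. Any other
// combination returns a null FloatType:
//
//   FloatType f16Ty = FloatType::getF16(ctx);
//   FloatType f32Ty = f16Ty.scaleElementBitwidth(2); // f32
//   FloatType none = f16Ty.scaleElementBitwidth(3);  // null FloatType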
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
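/// For example (an illustration of ours, not normative): with
/// `totalIndices` = 5 and `indices` = {3, 1}, the callback is invoked on
/// 0, 2, and 4, in that order.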
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
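/// For example (illustrative values): dropping argIndices = {1} and
/// resultIndices = {} from `(i32, f32, i64) -> (f32)` produces
/// `(i32, i64) -> (f32)`.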
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";
  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

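// A hedged sketch of cloning behavior (types chosen by us): for a memref with
// the default layout, cloning memref<4x?xf32> with shape = {2, 8} and an i32
// element type yields memref<2x8xi32>; the memory space is carried over by
// MemRefType::Builder.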
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setShape(shape);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

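/// Illustrative example (our own): for `memref<2x?x4x?xf32>`,
/// getDynamicDimIndex(3) returns 1, because exactly one dynamic dimension
/// precedes index 3.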
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
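/// For instance (our own example), `tensor<2x3xvector<4xf32>>` holds 6 vector
/// elements of 4 x 32 = 128 bits each, so getSizeInBits() returns 768.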
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError() << "vector elements must be int or float type";

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError() << "vector types must have positive constant sizes";

  return success();
}

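// Illustrative behavior (values ours): `vector<4xf16>` scaled by 2 becomes
// `vector<4xf32>`; if the element type cannot be scaled (e.g. f64 by 2), a
// null VectorType is returned.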
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is allowed within a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
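/// For example (values ours): originalShape = [1, 4, 1, 8] with
/// reducedShape = [4, 8] returns the mask {0, 2}, while
/// reducedShape = [4, 2] returns None, since the mismatching entry 8 is not a
/// unit dimension and so cannot be dropped.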
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if no match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

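// A quick sketch of the round-trip (our own illustration, assuming an
// MLIRContext *ctx): wrapIntegerMemorySpace(0, ctx) returns a null Attribute
// (the default space), wrapIntegerMemorySpace(2, ctx) returns the i64
// IntegerAttr 2, and skipDefaultMemorySpace maps a 0-valued IntegerAttr back
// to the null Attribute.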
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with a non-integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<AffineMap> affineMapComposition,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means a dynamic
  // size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  size_t dim = shape.size();
  for (auto it : llvm::enumerate(affineMapComposition)) {
    AffineMap map = it.value();
    if (map.getNumDims() == dim) {
      dim = map.getNumResults();
      continue;
    }
    return emitError() << "memref affine map dimension mismatch between "
                       << (it.index() == 0 ? Twine("memref rank")
                                           : "affine map " + Twine(it.index()))
                       << " and affine map " << it.index() + 1 << ": " << dim
                       << " != " << map.getNumDims();
  }

  if (!isSupportedMemorySpace(memorySpace)) {
    return emitError() << "unsupported memory space Attribute";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op,
// i.e. a single term. Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, ..., dn appear in
/// order, which makes indexing into the result intuitive.
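/// For example (an illustration of ours): e = d0 * 4 + d1 with
/// `multiplicativeFactor` = 1 accumulates 4 into strides[0] and 1 into
/// strides[1], leaving `offset` untouched.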
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();

  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
    return failure();

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.back();
    for (size_t i = affineMaps.size() - 1; i > 0; --i)
      m = m.compose(affineMaps[i - 1]);
    assert(!m.isIdentity() && "unexpected identity map");
  }

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

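// Illustrative example (values ours): for
// memref<3x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>, this overload
// fills strides = [4, 1] and offset = 8; any non-constant stride or offset
// expression is mapped to ShapedType::kDynamicStrideOrOffset.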
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

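// A hedged sketch of the resulting maps (examples ours, modulo expression
// canonicalization): strides = [128, 1] with offset = 4 yields
// affine_map<(d0, d1) -> (d0 * 128 + d1 + 4)>. Dynamic values each consume a
// fresh symbol, the offset first, so strides = [?, 1] with a dynamic offset
// yields affine_map<(d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1)>.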
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
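/// For example (illustrative),
/// memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 8 + d1)>> canonicalizes to
/// memref<4x8xf32> with the layout dropped, since the map matches the
/// canonical contiguous strides [8, 1].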
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to the canonical identity form; return `t` unchanged.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  auto m = affineMaps[0];
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setAffineMaps({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a one-element memref
  // whose offset needs to remain; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

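// Illustrative example (values ours): sizes = [4, 8, 2] with dim exprs
// d0, d1, d2 produces d0 * 16 + d1 * 2 + d2, i.e. row-major strides
// [16, 2, 1]. A dynamic size poisons every stride computed after it (those of
// the dimensions to its left), each becoming a fresh symbol.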
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}