//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr data member must have an out-of-line definition prior to
// C++17 and inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

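/// Return this integer type with its bitwidth multiplied by `scale`, keeping
/// the signedness, or a null type for a zero scale. A sketch of the intended
/// use, where `ctx` is a placeholder MLIRContext pointer:
///
///   IntegerType i8 = IntegerType::get(ctx, 8);
///   IntegerType i32 = i8.scaleElementBitwidth(4); // i32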
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

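/// Return this float type scaled to `scale` times its bitwidth, or a null
/// type for unsupported combinations. Only the widenings implemented below
/// exist, e.g. (illustrative):
///
///   f16.scaleElementBitwidth(2); // f32
///   f32.scaleElementBitwidth(2); // f64
///   f64.scaleElementBitwidth(2); // null FloatType, not supported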
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
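/// For example (a sketch), `iterateIndicesExcept(5, {1, 3}, cb)` invokes
/// `cb(0)`, `cb(2)`, and `cb(4)`.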
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
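/// For example (illustrative), dropping argument 1 and result 0 from
/// `(i32, f32, i64) -> (f32, i1)`:
///
///   funcTy.getWithoutArgsAndResults({1}, {0}); // (i32, i64) -> (i1)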
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "mlir-opt";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

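/// The `clone` overloads below return a shaped type of the same kind with the
/// given shape and/or element type substituted, preserving the remaining
/// properties (e.g. memory space). For example (a sketch), cloning
/// memref<4x?xf32> with shape {2, 2} and element type i32 yields
/// memref<2x2xi32>.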
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    // The builder constructor already takes the new shape.
    MemRefType::Builder b(shape, other.getElementType());
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

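/// Return the position of the given dimension among the dynamic dimensions.
/// For example (a sketch), in a shape (4, ?, 3, ?) the dynamic-dimension
/// index of dimension 3 is 1, since one dynamic dimension precedes it.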
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
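/// For example (illustrative), tensor<4xvector<2xf32>> occupies
/// 4 * 2 * 32 = 256 bits.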
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError() << "vector elements must be int/index/float type";

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError() << "vector types must have positive constant sizes";

  return success();
}

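/// Scale the element bitwidth while keeping the vector shape, e.g.
/// (illustrative) vector<4xf16> scaled by 2 becomes vector<4xf32>; a null
/// type is returned when the element type cannot be scaled.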
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
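/// For example (a sketch), with `originalShape = {1, 4, 1, 8}` and
/// `reducedShape = {4, 8}` the dropped unit dimensions are at positions 0
/// and 2:
///
///   computeRankReductionMask({1, 4, 1, 8}, {4, 8}); // -> {0, 2}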
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if no match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

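/// Wrap an integer memory space as an i64 IntegerAttr, mapping the default
/// space 0 to the null attribute. For example (illustrative),
/// `wrapIntegerMemorySpace(3, ctx)` yields the attribute `3 : i64`, while
/// `wrapIntegerMemorySpace(0, ctx)` yields a null Attribute.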
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 ArrayRef<AffineMap> affineMapComposition,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1` that means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  size_t dim = shape.size();
  for (auto it : llvm::enumerate(affineMapComposition)) {
    AffineMap map = it.value();
    if (map.getNumDims() == dim) {
      dim = map.getNumResults();
      continue;
    }
    return emitError() << "memref affine map dimension mismatch between "
                       << (it.index() == 0 ? Twine("memref rank")
                                           : "affine map " + Twine(it.index()))
                       << " and affine map " << it.index() + 1 << ": " << dim
                       << " != " << map.getNumDims();
  }

  if (!isSupportedMemorySpace(memorySpace)) {
    return emitError() << "unsupported memory space Attribute";
  }

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
  for (AffineMap map : getAffineMaps())
    walkAttrsFn(AffineMapAttr::get(map));
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing into the result intuitive.
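/// For example (a sketch), walking `e = d0 * 4 + d1 + 42` with a
/// multiplicative factor of 1 and zero-initialized strides/offset
/// accumulates strides {4, 1} and offset 42.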
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();

  if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
    return failure();

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.back();
    for (size_t i = affineMaps.size() - 1; i > 0; --i)
      m = m.compose(affineMaps[i - 1]);
    assert(!m.isIdentity() && "unexpected identity map");
  }

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

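/// Integer-valued variant: constant strides and offset are returned directly,
/// and any symbolic component becomes ShapedType::kDynamicStrideOrOffset.
/// For example (a sketch), memref<3x4xf32> with the default layout yields
/// strides {4, 1} and offset 0.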
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

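/// Build the single-result strided layout map for `strides` and `offset`.
/// Each dynamic value introduces a symbol, allocated in order: the offset
/// first, then the dynamic strides. For example (a sketch, modulo local
/// expression simplification), strides {kDynamicStrideOrOffset, 1} with
/// offset 0 produce roughly:
///
///   affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>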
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
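/// For example (illustrative), memref<4x4xf32, affine_map<(d0, d1) ->
/// (d0 * 4 + d1)>> carries the canonical contiguous layout and therefore
/// canonicalizes to memref<4x4xf32>.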
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form; return `t` unchanged.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  auto m = affineMaps[0];
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setAffineMaps({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a one-element memref
  // whose offset needs to remain, so just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

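/// Compute the canonical row-major strided layout expression for `sizes` over
/// the given dim expressions: each stride is the product of the sizes to its
/// right, with a fresh symbol once a dynamic size poisons the running
/// product. For example (a sketch), sizes {3, 4, 5} with exprs {d0, d1, d2}
/// yield `d0 * 20 + d1 * 5 + d2`.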
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
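/// For example (a sketch, modulo the exact term ordering produced by
/// expression simplification), applying this to memref<4x4xf32> yields
/// roughly:
///
///   memref<4x4xf32, affine_map<(d0, d1)[s0, s1, s2]
///                                 -> (d0 * s1 + d1 * s2 + s0)>>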
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}