1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/TensorEncoding.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 using namespace mlir;
26 using namespace mlir::detail;
27 
28 //===----------------------------------------------------------------------===//
29 /// Tablegen Type Definitions
30 //===----------------------------------------------------------------------===//
31 
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // BuiltinDialect
37 //===----------------------------------------------------------------------===//
38 
void BuiltinDialect::registerTypes() {
  // Register every ODS-generated builtin type with the dialect. The type
  // class list is expanded from the tablegen-generated .cpp.inc file.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
45 
46 //===----------------------------------------------------------------------===//
47 /// ComplexType
48 //===----------------------------------------------------------------------===//
49 
/// Verify the construction of a complex type: the element type must be a
/// builtin integer or float type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}
57 
58 //===----------------------------------------------------------------------===//
59 // Integer Type
60 //===----------------------------------------------------------------------===//
61 
// static constexpr must have an out-of-class definition until C++17, where
// constexpr static data members become implicitly inline.
constexpr unsigned IntegerType::kMaxWidth;
64 
65 /// Verify the construction of an integer type.
66 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
67                                   unsigned width,
68                                   SignednessSemantics signedness) {
69   if (width > IntegerType::kMaxWidth) {
70     return emitError() << "integer bitwidth is limited to "
71                        << IntegerType::kMaxWidth << " bits";
72   }
73   return success();
74 }
75 
/// Return the bitwidth of this integer type.
unsigned IntegerType::getWidth() const { return getImpl()->width; }
77 
/// Return the signedness semantics (signless/signed/unsigned) of this type.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
81 
82 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
83   if (!scale)
84     return IntegerType();
85   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Float Type
90 //===----------------------------------------------------------------------===//
91 
92 unsigned FloatType::getWidth() {
93   if (isa<Float16Type, BFloat16Type>())
94     return 16;
95   if (isa<Float32Type>())
96     return 32;
97   if (isa<Float64Type>())
98     return 64;
99   if (isa<Float80Type>())
100     return 80;
101   if (isa<Float128Type>())
102     return 128;
103   llvm_unreachable("unexpected float type");
104 }
105 
/// Returns the llvm::fltSemantics describing this float type, as required
/// when constructing APFloat values of this type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}
122 
123 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
124   if (!scale)
125     return FloatType();
126   MLIRContext *ctx = getContext();
127   if (isF16() || isBF16()) {
128     if (scale == 2)
129       return FloatType::getF32(ctx);
130     if (scale == 4)
131       return FloatType::getF64(ctx);
132   }
133   if (isF32())
134     if (scale == 2)
135       return FloatType::getF64(ctx);
136   return FloatType();
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // FunctionType
141 //===----------------------------------------------------------------------===//
142 
/// Return the number of inputs of this function type.
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

/// Return the input types of this function type.
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

/// Return the number of results of this function type.
unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

/// Return the result types of this function type.
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
154 
155 /// Helper to call a callback once on each index in the range
156 /// [0, `totalIndices`), *except* for the indices given in `indices`.
157 /// `indices` is allowed to have duplicates and can be in any order.
158 inline void iterateIndicesExcept(unsigned totalIndices,
159                                  ArrayRef<unsigned> indices,
160                                  function_ref<void(unsigned)> callback) {
161   llvm::BitVector skipIndices(totalIndices);
162   for (unsigned i : indices)
163     skipIndices.set(i);
164 
165   for (unsigned i = 0; i < totalIndices; ++i)
166     if (!skipIndices.test(i))
167       callback(i);
168 }
169 
170 /// Returns a new function type with the specified arguments and results
171 /// inserted.
172 FunctionType FunctionType::getWithArgsAndResults(
173     ArrayRef<unsigned> argIndices, TypeRange argTypes,
174     ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
175   assert(argIndices.size() == argTypes.size());
176   assert(resultIndices.size() == resultTypes.size());
177 
178   ArrayRef<Type> newInputTypes = getInputs();
179   SmallVector<Type, 4> newInputTypesBuffer;
180   if (!argIndices.empty()) {
181     const auto *fromIt = newInputTypes.begin();
182     for (auto it : llvm::zip(argIndices, argTypes)) {
183       const auto *toIt = newInputTypes.begin() + std::get<0>(it);
184       newInputTypesBuffer.append(fromIt, toIt);
185       newInputTypesBuffer.push_back(std::get<1>(it));
186       fromIt = toIt;
187     }
188     newInputTypesBuffer.append(fromIt, newInputTypes.end());
189     newInputTypes = newInputTypesBuffer;
190   }
191 
192   ArrayRef<Type> newResultTypes = getResults();
193   SmallVector<Type, 4> newResultTypesBuffer;
194   if (!resultIndices.empty()) {
195     const auto *fromIt = newResultTypes.begin();
196     for (auto it : llvm::zip(resultIndices, resultTypes)) {
197       const auto *toIt = newResultTypes.begin() + std::get<0>(it);
198       newResultTypesBuffer.append(fromIt, toIt);
199       newResultTypesBuffer.push_back(std::get<1>(it));
200       fromIt = toIt;
201     }
202     newResultTypesBuffer.append(fromIt, newResultTypes.end());
203     newResultTypes = newResultTypesBuffer;
204   }
205 
206   return FunctionType::get(getContext(), newInputTypes, newResultTypes);
207 }
208 
209 /// Returns a new function type without the specified arguments and results.
210 FunctionType
211 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
212                                        ArrayRef<unsigned> resultIndices) {
213   ArrayRef<Type> newInputTypes = getInputs();
214   SmallVector<Type, 4> newInputTypesBuffer;
215   if (!argIndices.empty()) {
216     unsigned originalNumArgs = getNumInputs();
217     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
218       newInputTypesBuffer.emplace_back(getInput(i));
219     });
220     newInputTypes = newInputTypesBuffer;
221   }
222 
223   ArrayRef<Type> newResultTypes = getResults();
224   SmallVector<Type, 4> newResultTypesBuffer;
225   if (!resultIndices.empty()) {
226     unsigned originalNumResults = getNumResults();
227     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
228       newResultTypesBuffer.emplace_back(getResult(i));
229     });
230     newResultTypes = newResultTypesBuffer;
231   }
232 
233   return get(getContext(), newInputTypes, newResultTypes);
234 }
235 
236 void FunctionType::walkImmediateSubElements(
237     function_ref<void(Attribute)> walkAttrsFn,
238     function_ref<void(Type)> walkTypesFn) const {
239   for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
240     walkTypesFn(type);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // OpaqueType
245 //===----------------------------------------------------------------------===//
246 
247 /// Verify the construction of an opaque type.
248 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
249                                  StringAttr dialect, StringRef typeData) {
250   if (!Dialect::isValidNamespace(dialect.strref()))
251     return emitError() << "invalid dialect namespace '" << dialect << "'";
252 
253   // Check that the dialect is actually registered.
254   MLIRContext *context = dialect.getContext();
255   if (!context->allowsUnregisteredDialects() &&
256       !context->getLoadedDialect(dialect.strref())) {
257     return emitError()
258            << "`!" << dialect << "<\"" << typeData << "\">"
259            << "` type created with unregistered dialect. If this is "
260               "intended, please call allowUnregisteredDialects() on the "
261               "MLIRContext, or use -allow-unregistered-dialect with "
262               "the MLIR opt tool used";
263   }
264 
265   return success();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // VectorType
270 //===----------------------------------------------------------------------===//
271 
272 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
273                                  ArrayRef<int64_t> shape, Type elementType,
274                                  unsigned numScalableDims) {
275   if (!isValidElementType(elementType))
276     return emitError()
277            << "vector elements must be int/index/float type but got "
278            << elementType;
279 
280   if (any_of(shape, [](int64_t i) { return i <= 0; }))
281     return emitError()
282            << "vector types must have positive constant sizes but got "
283            << shape;
284 
285   return success();
286 }
287 
288 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
289   if (!scale)
290     return VectorType();
291   if (auto et = getElementType().dyn_cast<IntegerType>())
292     if (auto scaledEt = et.scaleElementBitwidth(scale))
293       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
294   if (auto et = getElementType().dyn_cast<FloatType>())
295     if (auto scaledEt = et.scaleElementBitwidth(scale))
296       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
297   return VectorType();
298 }
299 
/// Walk the element type of this vector type.
void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}
305 
306 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
307                                  Type elementType) const {
308   return VectorType::get(shape.getValueOr(getShape()), elementType,
309                          getNumScalableDims());
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // TensorType
314 //===----------------------------------------------------------------------===//
315 
/// Return the element type, dispatching to the ranked or unranked variant.
Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

/// A tensor type is ranked unless it is an UnrankedTensorType.
bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }

/// Return the shape; only valid on ranked tensors (asserts otherwise).
ArrayRef<int64_t> TensorType::getShape() const {
  return cast<RankedTensorType>().getShape();
}
327 
328 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
329                                  Type elementType) const {
330   if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
331     if (shape)
332       return RankedTensorType::get(*shape, elementType);
333     return UnrankedTensorType::get(elementType);
334   }
335 
336   auto rankedTy = cast<RankedTensorType>();
337   if (!shape)
338     return RankedTensorType::get(rankedTy.getShape(), elementType,
339                                  rankedTy.getEncoding());
340   return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()),
341                                elementType, rankedTy.getEncoding());
342 }
343 
// Check if "elementType" can be an element type of a tensor. Emits a
// diagnostic through `emitError` when it cannot.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}
352 
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}
362 
363 //===----------------------------------------------------------------------===//
364 // RankedTensorType
365 //===----------------------------------------------------------------------===//
366 
367 LogicalResult
368 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
369                          ArrayRef<int64_t> shape, Type elementType,
370                          Attribute encoding) {
371   for (int64_t s : shape)
372     if (s < -1)
373       return emitError() << "invalid tensor dimension size";
374   if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
375     if (failed(v.verifyEncoding(shape, elementType, emitError)))
376       return failure();
377   return checkTensorElementType(emitError, elementType);
378 }
379 
/// Walk the element type and, when present, the encoding attribute.
void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  // The encoding is optional; only report it when set.
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}
387 
388 //===----------------------------------------------------------------------===//
389 // UnrankedTensorType
390 //===----------------------------------------------------------------------===//
391 
/// Verify an unranked tensor type: only the element type is constrained.
LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

/// Walk the element type of this unranked tensor type.
void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}
403 
404 //===----------------------------------------------------------------------===//
405 // BaseMemRefType
406 //===----------------------------------------------------------------------===//
407 
/// Return the element type, dispatching to the ranked or unranked variant.
Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

/// A memref type is ranked unless it is an UnrankedMemRefType.
bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }

/// Return the shape; only valid on ranked memrefs (asserts otherwise).
ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return cast<MemRefType>().getShape();
}
419 
420 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
421                                          Type elementType) const {
422   if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
423     if (!shape)
424       return UnrankedMemRefType::get(elementType, getMemorySpace());
425     MemRefType::Builder builder(*shape, elementType);
426     builder.setMemorySpace(getMemorySpace());
427     return builder;
428   }
429 
430   MemRefType::Builder builder(cast<MemRefType>());
431   if (shape)
432     builder.setShape(*shape);
433   builder.setElementType(elementType);
434   return builder;
435 }
436 
/// Return the memory space attribute, dispatching on rankedness.
Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

/// Return the memory space as an integer (legacy form), dispatching on
/// rankedness.
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}
448 
449 //===----------------------------------------------------------------------===//
450 // MemRefType
451 //===----------------------------------------------------------------------===//
452 
453 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
454 /// `originalShape` with some `1` entries erased, return the set of indices
455 /// that specifies which of the entries of `originalShape` are dropped to obtain
456 /// `reducedShape`. The returned mask can be applied as a projection to
457 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
458 /// which dimensions must be kept when e.g. compute MemRef strides under
459 /// rank-reducing operations. Return None if reducedShape cannot be obtained
460 /// by dropping only `1` entries in `originalShape`.
461 llvm::Optional<llvm::SmallDenseSet<unsigned>>
462 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
463                                ArrayRef<int64_t> reducedShape) {
464   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
465   llvm::SmallDenseSet<unsigned> unusedDims;
466   unsigned reducedIdx = 0;
467   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
468     // Greedily insert `originalIdx` if match.
469     if (reducedIdx < reducedRank &&
470         originalShape[originalIdx] == reducedShape[reducedIdx]) {
471       reducedIdx++;
472       continue;
473     }
474 
475     unusedDims.insert(originalIdx);
476     // If no match on `originalIdx`, the `originalShape` at this dimension
477     // must be 1, otherwise we bail.
478     if (originalShape[originalIdx] != 1)
479       return llvm::None;
480   }
481   // The whole reducedShape must be scanned, otherwise we bail.
482   if (reducedIdx != reducedRank)
483     return llvm::None;
484   return unusedDims;
485 }
486 
487 SliceVerificationResult
488 mlir::isRankReducedType(ShapedType originalType,
489                         ShapedType candidateReducedType) {
490   if (originalType == candidateReducedType)
491     return SliceVerificationResult::Success;
492 
493   ShapedType originalShapedType = originalType.cast<ShapedType>();
494   ShapedType candidateReducedShapedType =
495       candidateReducedType.cast<ShapedType>();
496 
497   // Rank and size logic is valid for all ShapedTypes.
498   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
499   ArrayRef<int64_t> candidateReducedShape =
500       candidateReducedShapedType.getShape();
501   unsigned originalRank = originalShape.size(),
502            candidateReducedRank = candidateReducedShape.size();
503   if (candidateReducedRank > originalRank)
504     return SliceVerificationResult::RankTooLarge;
505 
506   auto optionalUnusedDimsMask =
507       computeRankReductionMask(originalShape, candidateReducedShape);
508 
509   // Sizes cannot be matched in case empty vector is returned.
510   if (!optionalUnusedDimsMask.hasValue())
511     return SliceVerificationResult::SizeMismatch;
512 
513   if (originalShapedType.getElementType() !=
514       candidateReducedShapedType.getElementType())
515     return SliceVerificationResult::ElemTypeMismatch;
516 
517   return SliceVerificationResult::Success;
518 }
519 
520 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
521   // Empty attribute is allowed as default memory space.
522   if (!memorySpace)
523     return true;
524 
525   // Supported built-in attributes.
526   if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
527     return true;
528 
529   // Allow custom dialect attributes.
530   if (!isa<BuiltinDialect>(memorySpace.getDialect()))
531     return true;
532 
533   return false;
534 }
535 
536 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
537                                                MLIRContext *ctx) {
538   if (memorySpace == 0)
539     return nullptr;
540 
541   return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
542 }
543 
544 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
545   IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
546   if (intMemorySpace && intMemorySpace.getValue() == 0)
547     return nullptr;
548 
549   return memorySpace;
550 }
551 
552 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
553   if (!memorySpace)
554     return 0;
555 
556   assert(memorySpace.isa<IntegerAttr>() &&
557          "Using `getMemorySpaceInteger` with non-Integer attribute");
558 
559   return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
560 }
561 
/// Set the memory space from a legacy integer value, wrapping it into an
/// attribute (0 becomes the empty attribute).
MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

/// Return the memory space of this memref as an integer (legacy form).
unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}
572 
573 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
574                            MemRefLayoutAttrInterface layout,
575                            Attribute memorySpace) {
576   // Use default layout for empty attribute.
577   if (!layout)
578     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
579         shape.size(), elementType.getContext()));
580 
581   // Drop default memory space value and replace it with empty attribute.
582   memorySpace = skipDefaultMemorySpace(memorySpace);
583 
584   return Base::get(elementType.getContext(), shape, elementType, layout,
585                    memorySpace);
586 }
587 
588 MemRefType MemRefType::getChecked(
589     function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
590     Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
591 
592   // Use default layout for empty attribute.
593   if (!layout)
594     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
595         shape.size(), elementType.getContext()));
596 
597   // Drop default memory space value and replace it with empty attribute.
598   memorySpace = skipDefaultMemorySpace(memorySpace);
599 
600   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
601                           elementType, layout, memorySpace);
602 }
603 
604 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
605                            AffineMap map, Attribute memorySpace) {
606 
607   // Use default layout for empty map.
608   if (!map)
609     map = AffineMap::getMultiDimIdentityMap(shape.size(),
610                                             elementType.getContext());
611 
612   // Wrap AffineMap into Attribute.
613   Attribute layout = AffineMapAttr::get(map);
614 
615   // Drop default memory space value and replace it with empty attribute.
616   memorySpace = skipDefaultMemorySpace(memorySpace);
617 
618   return Base::get(elementType.getContext(), shape, elementType, layout,
619                    memorySpace);
620 }
621 
622 MemRefType
623 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
624                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
625                        Attribute memorySpace) {
626 
627   // Use default layout for empty map.
628   if (!map)
629     map = AffineMap::getMultiDimIdentityMap(shape.size(),
630                                             elementType.getContext());
631 
632   // Wrap AffineMap into Attribute.
633   Attribute layout = AffineMapAttr::get(map);
634 
635   // Drop default memory space value and replace it with empty attribute.
636   memorySpace = skipDefaultMemorySpace(memorySpace);
637 
638   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
639                           elementType, layout, memorySpace);
640 }
641 
642 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
643                            AffineMap map, unsigned memorySpaceInd) {
644 
645   // Use default layout for empty map.
646   if (!map)
647     map = AffineMap::getMultiDimIdentityMap(shape.size(),
648                                             elementType.getContext());
649 
650   // Wrap AffineMap into Attribute.
651   Attribute layout = AffineMapAttr::get(map);
652 
653   // Convert deprecated integer-like memory space to Attribute.
654   Attribute memorySpace =
655       wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
656 
657   return Base::get(elementType.getContext(), shape, elementType, layout,
658                    memorySpace);
659 }
660 
661 MemRefType
662 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
663                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
664                        unsigned memorySpaceInd) {
665 
666   // Use default layout for empty map.
667   if (!map)
668     map = AffineMap::getMultiDimIdentityMap(shape.size(),
669                                             elementType.getContext());
670 
671   // Wrap AffineMap into Attribute.
672   Attribute layout = AffineMapAttr::get(map);
673 
674   // Convert deprecated integer-like memory space to Attribute.
675   Attribute memorySpace =
676       wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
677 
678   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
679                           elementType, layout, memorySpace);
680 }
681 
/// Verify a memref type: element type, dimension sizes, layout, and memory
/// space must all be valid.
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1` that means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // The builder-level getters always materialize a layout, so a missing one
  // here is a programming error, not a user error.
  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
703 
704 void MemRefType::walkImmediateSubElements(
705     function_ref<void(Attribute)> walkAttrsFn,
706     function_ref<void(Type)> walkTypesFn) const {
707   walkTypesFn(getElementType());
708   if (!getLayout().isIdentity())
709     walkAttrsFn(getLayout());
710   walkAttrsFn(getMemorySpace());
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // UnrankedMemRefType
715 //===----------------------------------------------------------------------===//
716 
/// Return the memory space of this unranked memref as an integer (legacy
/// form).
unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}
720 
/// Verify an unranked memref type: the element type and memory space
/// constraints are shared with ranked memrefs.
LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
732 
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  // A bare dimension contributes `multiplicativeFactor` to that dimension's
  // stride; any other terminal (symbol or constant) contributes to the
  // offset instead.
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}
745 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
/// Recurses through the expression tree, carrying the accumulated
/// `multiplicativeFactor` down each branch; fails on div/mod, which cannot
/// represent strided layouts.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    // Terminal (dim/symbol/constant): accumulate directly.
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  // Division and modulo cannot appear in a strided-layout expression.
  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      // d_i * expr: `expr` (scaled by the incoming factor) is d_i's stride.
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Addition distributes: both sides accumulate independently into
    // `strides`/`offset` with the same factor.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
793 
794 LogicalResult mlir::getStridesAndOffset(MemRefType t,
795                                         SmallVectorImpl<AffineExpr> &strides,
796                                         AffineExpr &offset) {
797   AffineMap m = t.getLayout().getAffineMap();
798 
799   if (m.getNumResults() != 1 && !m.isIdentity())
800     return failure();
801 
802   auto zero = getAffineConstantExpr(0, t.getContext());
803   auto one = getAffineConstantExpr(1, t.getContext());
804   offset = zero;
805   strides.assign(t.getRank(), zero);
806 
807   // Canonical case for empty map.
808   if (m.isIdentity()) {
809     // 0-D corner case, offset is already 0.
810     if (t.getRank() == 0)
811       return success();
812     auto stridedExpr =
813         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
814     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
815       return success();
816     assert(false && "unexpected failure: extract strides in canonical layout");
817   }
818 
819   // Non-canonical case requires more work.
820   auto stridedExpr =
821       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
822   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
823     offset = AffineExpr();
824     strides.clear();
825     return failure();
826   }
827 
828   // Simplify results to allow folding to constants and simple checks.
829   unsigned numDims = m.getNumDims();
830   unsigned numSymbols = m.getNumSymbols();
831   offset = simplifyAffineExpr(offset, numDims, numSymbols);
832   for (auto &stride : strides)
833     stride = simplifyAffineExpr(stride, numDims, numSymbols);
834 
835   /// In practice, a strided memref must be internally non-aliasing. Test
836   /// against 0 as a proxy.
837   /// TODO: static cases can have more advanced checks.
838   /// TODO: dynamic cases would require a way to compare symbolic
839   /// expressions and would probably need an affine set context propagated
840   /// everywhere.
841   if (llvm::any_of(strides, [](AffineExpr e) {
842         return e == getAffineConstantExpr(0, e.getContext());
843       })) {
844     offset = AffineExpr();
845     strides.clear();
846     return failure();
847   }
848 
849   return success();
850 }
851 
852 LogicalResult mlir::getStridesAndOffset(MemRefType t,
853                                         SmallVectorImpl<int64_t> &strides,
854                                         int64_t &offset) {
855   AffineExpr offsetExpr;
856   SmallVector<AffineExpr, 4> strideExprs;
857   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
858     return failure();
859   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
860     offset = cst.getValue();
861   else
862     offset = ShapedType::kDynamicStrideOrOffset;
863   for (auto e : strideExprs) {
864     if (auto c = e.dyn_cast<AffineConstantExpr>())
865       strides.push_back(c.getValue());
866     else
867       strides.push_back(ShapedType::kDynamicStrideOrOffset);
868   }
869   return success();
870 }
871 
/// Walk the immediate sub-elements of an unranked memref: its single element
/// type and its memory-space attribute.
void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}
878 
879 //===----------------------------------------------------------------------===//
// TupleType
881 //===----------------------------------------------------------------------===//
882 
/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
885 
886 /// Accumulate the types contained in this tuple and tuples nested within it.
887 /// Note that this only flattens nested tuples, not any other container type,
888 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
889 /// (i32, tensor<i32>, f32, i64)
890 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
891   for (Type type : getTypes()) {
892     if (auto nestedTuple = type.dyn_cast<TupleType>())
893       nestedTuple.getFlattenedTypes(types);
894     else
895       types.push_back(type);
896   }
897 }
898 
/// Return the number of element types contained in this tuple.
size_t TupleType::size() const { return getImpl()->size(); }
901 
/// Walk the immediate sub-elements of a tuple. Tuples carry no attributes,
/// so only the contained element types are visited; nested tuples are passed
/// as single types, not flattened.
void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}
908 
909 //===----------------------------------------------------------------------===//
910 // Type Utilities
911 //===----------------------------------------------------------------------===//
912 
913 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
914                                            int64_t offset,
915                                            MLIRContext *context) {
916   AffineExpr expr;
917   unsigned nSymbols = 0;
918 
919   // AffineExpr for offset.
920   // Static case.
921   if (offset != MemRefType::getDynamicStrideOrOffset()) {
922     auto cst = getAffineConstantExpr(offset, context);
923     expr = cst;
924   } else {
925     // Dynamic case, new symbol for the offset.
926     auto sym = getAffineSymbolExpr(nSymbols++, context);
927     expr = sym;
928   }
929 
930   // AffineExpr for strides.
931   for (const auto &en : llvm::enumerate(strides)) {
932     auto dim = en.index();
933     auto stride = en.value();
934     assert(stride != 0 && "Invalid stride specification");
935     auto d = getAffineDimExpr(dim, context);
936     AffineExpr mult;
937     // Static case.
938     if (stride != MemRefType::getDynamicStrideOrOffset())
939       mult = getAffineConstantExpr(stride, context);
940     else
941       // Dynamic case, new symbol for each new stride.
942       mult = getAffineSymbolExpr(nSymbols++, context);
943     expr = expr + d * mult;
944   }
945 
946   return AffineMap::get(strides.size(), nSymbols, expr);
947 }
948 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps. A constant-0 result is equivalent to the
  // identity layout and is dropped; any other 0-D result is kept unchanged.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for empty shape that still have an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
  // offset needs to remain, just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}
991 
/// Return `sum_i exprs[i] * stride_i` where the strides are the canonical
/// row-major running products of `sizes` (minor-to-major). Once a dynamic
/// size (negative) is seen, every subsequent (more-major) stride becomes a
/// fresh symbol since it can no longer be computed statically.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  // Infer dim/symbol counts from the supplied expressions so newly created
  // symbols are numbered after the existing ones.
  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  // Walk minor-to-major; each dimension's stride is the product of all
  // more-minor sizes.
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride contribution.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      // Dynamic size: poison all subsequent strides.
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
1028 
1029 /// Return a version of `t` with a layout that has all dynamic offset and
1030 /// strides. This is used to erase the static layout.
1031 MemRefType mlir::eraseStridedLayout(MemRefType t) {
1032   auto val = ShapedType::kDynamicStrideOrOffset;
1033   return MemRefType::Builder(t).setLayout(
1034       AffineMapAttr::get(makeStridedLinearLayoutMap(
1035           SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
1036 }
1037 
1038 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
1039                                                 MLIRContext *context) {
1040   SmallVector<AffineExpr, 4> exprs;
1041   exprs.reserve(sizes.size());
1042   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
1043     exprs.push_back(getAffineDimExpr(dim, context));
1044   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
1045 }
1046 
1047 /// Return true if the layout for `t` is compatible with strided semantics.
1048 bool mlir::isStrided(MemRefType t) {
1049   int64_t offset;
1050   SmallVector<int64_t, 4> strides;
1051   auto res = getStridesAndOffset(t, strides, offset);
1052   return succeeded(res);
1053 }
1054 
1055 /// Return the layout map in strided linear layout AffineMap form.
1056 /// Return null if the layout is not compatible with a strided layout.
1057 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
1058   int64_t offset;
1059   SmallVector<int64_t, 4> strides;
1060   if (failed(getStridesAndOffset(t, strides, offset)))
1061     return AffineMap();
1062   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
1063 }
1064 
1065 /// Return the AffineExpr representation of the offset, assuming `memRefType`
1066 /// is a strided memref.
1067 static AffineExpr getOffsetExpr(MemRefType memrefType) {
1068   SmallVector<AffineExpr> strides;
1069   AffineExpr offset;
1070   if (failed(getStridesAndOffset(memrefType, strides, offset)))
1071     assert(false && "expected strided memref");
1072   return offset;
1073 }
1074 
1075 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1076 /// `offset` AffineExpr.
1077 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1078                                                    ArrayRef<int64_t> shape,
1079                                                    Type elementType,
1080                                                    AffineExpr offset) {
1081   AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1082   AffineExpr contiguousRowMajor = canonical + offset;
1083   AffineMap contiguousRowMajorMap =
1084       AffineMap::inferFromExprList({contiguousRowMajor})[0];
1085   return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1086 }
1087 
1088 /// Helper determining if a memref is static-shape and contiguous-row-major
1089 /// layout, while still allowing for an arbitrary offset (any static or
1090 /// dynamic value).
1091 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1092   if (!memrefType.hasStaticShape())
1093     return false;
1094   AffineExpr offset = getOffsetExpr(memrefType);
1095   MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1096       memrefType.getContext(), memrefType.getShape(),
1097       memrefType.getElementType(), offset);
1098   return canonicalizeStridedLayout(memrefType) ==
1099          canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1100 }
1101