//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// static constexpr members need an out-of-line definition until C++17, where
// they become implicitly inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
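
// Informal usage sketch (not part of the original file; `ctx` stands for an
// MLIRContext*): scaling preserves signedness, and a zero scale yields a null
// type that converts to false:
//   IntegerType i8Ty = IntegerType::get(ctx, 8);
//   IntegerType i32Ty = i8Ty.scaleElementBitwidth(4); // i32
//   if (!i8Ty.scaleElementBitwidth(0)) { /* scale of 0 gives a null type */ }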

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}
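
// Informal usage sketch (hypothetical values): f16/bf16 scale by 2 to f32 and
// by 4 to f64, and f32 scales by 2 to f64; every other combination yields a
// null FloatType:
//   FloatType f16Ty = FloatType::getF16(ctx); // `ctx` is an MLIRContext*.
//   FloatType f32Ty = f16Ty.scaleElementBitwidth(2); // f32
//   FloatType none = f16Ty.scaleElementBitwidth(3);  // null FloatType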

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}
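
// Informal example: with totalIndices = 5 and indices = {3, 1, 3}, the
// callback runs on 0, 2, and 4, in increasing order; duplicates and ordering
// in `indices` are irrelevant because they collapse into the bit vector.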

/// Returns a new function type with the specified arguments and results
/// inserted. The given indices are interpreted relative to the original type
/// and are expected to be in non-decreasing order.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  assert(argIndices.size() == argTypes.size());
  assert(resultIndices.size() == resultTypes.size());

  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    const auto *fromIt = newInputTypes.begin();
    for (auto it : llvm::zip(argIndices, argTypes)) {
      const auto *toIt = newInputTypes.begin() + std::get<0>(it);
      newInputTypesBuffer.append(fromIt, toIt);
      newInputTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newInputTypesBuffer.append(fromIt, newInputTypes.end());
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    const auto *fromIt = newResultTypes.begin();
    for (auto it : llvm::zip(resultIndices, resultTypes)) {
      const auto *toIt = newResultTypes.begin() + std::get<0>(it);
      newResultTypesBuffer.append(fromIt, toIt);
      newResultTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newResultTypesBuffer.append(fromIt, newResultTypes.end());
    newResultTypes = newResultTypesBuffer;
  }

  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
}
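
// Informal example: starting from (i1, i8) -> (f32), inserting argument i16 at
// index 1 and result f64 at index 0 produces (i1, i16, i8) -> (f64, f32).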

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}
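
// Informal example: starting from (i1, i8, i16) -> (f32), erasing argIndices
// {0, 2} produces (i8) -> (f32).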

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or pass -allow-unregistered-dialect to "
              "mlir-opt";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (auto vecTy = dyn_cast<VectorType>())
    return VectorType::get(shape, elementType, vecTy.getNumScalableDims());

  llvm_unreachable("Unhandled ShapedType clone case");
}
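
// Informal example: cloning memref<4x?xf32> with shape {2, 2} and element type
// i8 yields memref<2x2xi8>; cloning an unranked memref with an explicit shape
// intentionally produces a *ranked* memref in the same memory space.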

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (auto vecTy = dyn_cast<VectorType>())
    return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (auto vecTy = dyn_cast<VectorType>())
    return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "expected a dynamic dim");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
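
// Informal example: for a shape of (4, ?, 3, ?), getDynamicDimIndex(3) returns
// 1, since exactly one dynamic dimension precedes index 3.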

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
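
// Informal example: tensor<4xvector<4xf32>> occupies 4 * (4 * 32) = 512 bits,
// and tensor<2xcomplex<f64>> occupies 2 * 2 * 64 = 256 bits.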

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 unsigned numScalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  return VectorType();
}
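
// Informal example: vector<4xf16> scaled by 2 becomes vector<4xf32>; when the
// element type cannot be scaled (e.g. f64 by 2), a null VectorType is
// returned.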

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is valid within a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: types other than the standard/builtin ones are allowed to exist
  // within tensor types. Dialects are expected to verify that tensor types
  // have a valid element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}
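
// Informal example: originalShape = (1, 4, 1, 8) and reducedShape = (4, 8)
// produce the mask {0, 2}; reducedShape = (4, 4) produces llvm::None because
// only unit dimensions may be dropped.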

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalType.getShape();
  ArrayRef<int64_t> candidateReducedShape = candidateReducedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // If no rank-reduction mask could be computed, the sizes do not match.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalType.getElementType() != candidateReducedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which denotes a dynamic
  // size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order, to make indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}
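
// Informal example: memref<3x4xf32> with the default identity layout yields
// strides = {4, 1} and offset = 0; any stride or offset that does not fold to
// a constant comes back as ShapedType::kDynamicStrideOrOffset.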

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types of this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (const auto &en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}
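
// Informal example: strides = {4, 1} with offset = 0 produce the map
// (d0, d1) -> (d0 * 4 + d1); a dynamic offset and a dynamic leading stride
// produce (d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1), the offset symbol being
// allocated first.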

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t`'s layout map has multiple results, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset must be preserved; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
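
// Informal example: sizes = (3, 4, 5) give an expression equivalent to the
// row-major d0 * 20 + d1 * 5 + d2; a dynamic size "poisons" the strides of
// all dimensions to its left, turning each of them into a fresh symbol.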

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}

/// Return the AffineExpr representation of the offset, assuming `memrefType`
/// is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strides;
  AffineExpr offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    assert(false && "expected strided memref");
  return offset;
}

/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
/// `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
  AffineExpr contiguousRowMajor = canonical + offset;
  AffineMap contiguousRowMajorMap =
      AffineMap::inferFromExprList({contiguousRowMajor})[0];
  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}

/// Helper determining if a memref is static-shape and contiguous-row-major
/// layout, while still allowing for an arbitrary offset (any static or
/// dynamic value).
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;
  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}
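
// Informal example: memref<2x3xf32> passes this check, while the column-major
// memref<2x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 2)>> does not, because its
// canonicalized layout differs from the contiguous row-major one.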