//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr data member still needs an out-of-line definition until
// C++17 introduces inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

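// Illustrative sketch: scaling i8 by 4 yields i32 with the same signedness,
// while a zero scale yields a null type that callers can test for failure
// (`ctx` below stands for any MLIRContext *):
//   IntegerType i8 = IntegerType::get(ctx, 8);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // i32
//   assert(!i8.scaleElementBitwidth(0));          // null on zero scale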
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

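// For example, f16 and bf16 scale by 2 to f32 and by 4 to f64, and f32
// scales by 2 to f64; any other combination returns a null FloatType.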
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
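/// For example, iterateIndicesExcept(5, {1, 3}, cb) invokes cb(0), cb(2),
/// and cb(4), in increasing index order.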
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
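/// For example (illustrative), inserting f32 at argument index 1 of
/// `(i32, i64) -> (i1)` produces `(i32, f32, i64) -> (i1)`: each index names
/// the position in the original list before which its type is inserted, so
/// multiple indices are expected in non-decreasing order.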
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  assert(argIndices.size() == argTypes.size());
  assert(resultIndices.size() == resultTypes.size());

  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    const auto *fromIt = newInputTypes.begin();
    for (auto it : llvm::zip(argIndices, argTypes)) {
      const auto *toIt = newInputTypes.begin() + std::get<0>(it);
      newInputTypesBuffer.append(fromIt, toIt);
      newInputTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newInputTypesBuffer.append(fromIt, newInputTypes.end());
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    const auto *fromIt = newResultTypes.begin();
    for (auto it : llvm::zip(resultIndices, resultTypes)) {
      const auto *toIt = newResultTypes.begin() + std::get<0>(it);
      newResultTypesBuffer.append(fromIt, toIt);
      newResultTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newResultTypesBuffer.append(fromIt, newResultTypes.end());
    newResultTypes = newResultTypesBuffer;
  }

  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
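/// For example, erasing argument indices {0, 2} from
/// `(i32, f32, i64) -> (i1)` produces `(f32) -> (i1)`; duplicate indices are
/// tolerated and their order does not matter.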
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Identifier dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

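// The clone overloads below derive a type of the same kind with some
// components substituted. For example, cloning memref<4x?xf32> with shape
// {2, 8} yields memref<2x8xf32> (layout and memory space carried over), and
// cloning tensor<*xf32> with an i32 element type yields tensor<*xi32>.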
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

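// For example, in memref<2x?x4x?xf32> the dynamic dimension at index 3 has
// dynamic index 1, since exactly one earlier dimension is dynamic.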
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
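/// For example, tensor<4xvector<2xf32>> stores 4 * 2 * 32 = 256 bits.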
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (shape.empty())
    return emitError() << "vector types must have at least one dimension";

  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

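// For example, scaling vector<4xf16> by 2 yields vector<4xf32>; a null
// VectorType is returned when the element type cannot be scaled.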
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when, e.g., computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
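/// For example, originalShape = {1, 4, 1, 8} with reducedShape = {4, 8}
/// yields the mask {0, 2}, while reducedShape = {4, 4} yields None.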
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if no match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

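// For example, an IntegerAttr holding 0 (the default memory space) is
// dropped to a null attribute; any other attribute passes through unchanged.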
Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1` that means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

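// For example, memref<3x4xf32> with the canonical layout yields
// strides = {4, 1} and offset = 0; any stride or offset that does not fold
// to a constant is reported as ShapedType::kDynamicStrideOrOffset.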
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

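// For example, strides {4, 1} with offset 0 produce
// (d0, d1) -> (d0 * 4 + d1), while dynamic values become symbols: a dynamic
// offset with strides {?, 1} produces (d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1).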
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with the simplified layout. If `t` has a multi-result layout, just
/// return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // A multi-result layout cannot be reduced to the canonical identity form;
  // return `t` unchanged.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

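// For example, sizes {3, 4, 5} with exprs {d0, d1, d2} produce roughly
// d0 * 20 + d1 * 5 + d2. A dynamic size poisons every stride to its left,
// which then becomes a fresh symbol, e.g. sizes {3, ?, 5} produce roughly
// d0 * s0 + d1 * 5 + d2.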
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: a zero size contributes no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
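/// For example, memref<4x5xf32> becomes (roughly, per the construction in
/// makeStridedLinearLayoutMap) memref<4x5xf32,
/// (d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2)>, where s0 stands for
/// the dynamic offset and s1, s2 for the dynamic strides.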
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}