//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// static constexpr members must have a definition (until C++17 introduced
// inline variables).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

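/// Return an integer type with the same signedness whose bitwidth is scaled
/// by `scale`. Returns a null type if `scale` is zero. For example
/// (illustrative): scaling `i8` by 2 yields `i16`.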
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

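/// Return a float type whose bitwidth is scaled by `scale`. Only scalings
/// that land on a builtin float type are supported; all other combinations
/// (including a zero `scale`) return a null type. For example (illustrative):
/// scaling `f16` by 2 yields `f32`, and scaling `f32` by 2 yields `f64`, but
/// scaling `f64` by 2 returns a null type.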
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
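/// For example (illustrative): inserting argument type `i64` at index 1 into
/// `(i32) -> (f32)` yields `(i32, i64) -> (f32)`.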
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::insertTypesInto(
      getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes = function_interface_impl::insertTypesInto(
      getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
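/// For example (illustrative): erasing argument index 1 from
/// `(i32, i64, f32) -> (f32)` yields `(i32, f32) -> (f32)`.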
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::filterTypesOut(
      getInputs(), argIndices, argStorage);
  TypeRange newResultTypes = function_interface_impl::filterTypesOut(
      getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR tool being used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 unsigned numScalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

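/// Return a vector type with the same shape and scalable dims whose element
/// bitwidth is scaled by `scale`, or a null type if the element type cannot
/// be scaled (see the IntegerType and FloatType overloads above). For example
/// (illustrative): scaling `vector<4xf16>` by 2 yields `vector<4xf32>`.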
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.getValueOr(getShape()), elementType,
                         getNumScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }

ArrayRef<int64_t> TensorType::getShape() const {
  return cast<RankedTensorType>().getShape();
}

TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = cast<RankedTensorType>();
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(*shape, elementType, rankedTy.getEncoding());
}

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is a valid tensor element type.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return cast<MemRefType>().getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(cast<MemRefType>());
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
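/// For example (illustrative): originalShape = [1, 4, 1, 8] and
/// reducedShape = [4, 8] produce the mask {0, 2}, whereas
/// reducedShape = [4, 4] produces None because the mismatching trailing `8`
/// is not a `1` entry.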
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if it matches.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // If no mask was returned, the sizes could not be matched.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position. The convention is that the
/// stride for dimension `di` is stored at position `i` of `strides`, to make
/// indexing into the result intuitive.
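/// For example (illustrative): the 2-D layout expression `d0 * 8 + d1 + 42`
/// accumulates strides = [8, 1] and offset = 42.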
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

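/// Populate `strides` and `offset` with the strided-layout interpretation of
/// `t`'s layout map, or fail if the layout is not strided. For example
/// (illustrative): `memref<3x4xf32>` with the default identity layout yields
/// strides = [4, 1] and offset = 0.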
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

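/// Build the strided linear layout map
/// `(d0, ..., dn) -> (offset + sum_i(di * strides[i]))`. Each stride or
/// offset equal to `MemRefType::getDynamicStrideOrOffset()` introduces a new
/// symbol, with the offset symbol (if any) allocated first. For example
/// (illustrative): strides = [8, 1] with offset = 0 gives
/// `(d0, d1) -> (d0 * 8 + d1)`, while a dynamic offset gives roughly
/// `(d0, d1)[s0] -> (s0 + d0 * 8 + d1)`.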
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (const auto &en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout through `simplifyAffineExpr` and return a copy
/// of `t` with the simplified layout.
/// If `t` has a multi-result layout, just return `t`.
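/// For example (illustrative):
/// `memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 8 + d1)>>` canonicalizes to
/// `memref<4x8xf32>` because its layout equals the canonical strided form.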
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Multi-result layouts can't be reduced to the canonical identity form;
  // return `t` unchanged.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case: an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

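/// Return the canonical row-major strided layout expression for `sizes`,
/// with `exprs` standing in for the dimension expressions. Strides are the
/// running products of the sizes from minor to major; once a dynamic size is
/// encountered, all more-major strides become fresh symbols. For example
/// (illustrative): sizes = [4, ?, 8] with exprs = [d0, d1, d2] yields an
/// expression equivalent to `d0 * s0 + d1 * 8 + d2`.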
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
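/// For example (illustrative): `memref<4x8xf32>` becomes roughly
/// `memref<4x8xf32, affine_map<(d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2)>>`.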
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}

/// Return the AffineExpr representation of the offset, asserting that
/// `memrefType` is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strides;
  AffineExpr offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    assert(false && "expected strided memref");
  return offset;
}

/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
/// `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
  AffineExpr contiguousRowMajor = canonical + offset;
  AffineMap contiguousRowMajorMap =
      AffineMap::inferFromExprList({contiguousRowMajor})[0];
  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}

/// Helper determining if a memref is static-shape and contiguous-row-major
/// layout, while still allowing for an arbitrary offset (any static or
/// dynamic value).
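/// For example (illustrative): `memref<4x8xf32>` and
/// `memref<4x8xf32, affine_map<(d0, d1)[s0] -> (d0 * 8 + d1 + s0)>>` both
/// qualify, while the column-major
/// `memref<4x8xf32, affine_map<(d0, d1) -> (d1 * 4 + d0)>>` does not.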
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;
  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}