//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member must have an out-of-class definition (prior to
// C++17 and inline variables).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

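// Usage sketch (illustrative, assuming a valid `MLIRContext *ctx`): the
// scaled width is simply `scale * width`, and a zero scale yields a null
// type that callers can use to detect failure.
//   IntegerType i8 = IntegerType::get(ctx, 8);
//   IntegerType i32 = i8.scaleElementBitwidth(4); // i32, same signedness.
//   assert(!i8.scaleElementBitwidth(0));          // Null type on zero scale.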
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

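// Usage sketch (illustrative, assuming a valid `MLIRContext *ctx`): only the
// scalings whose result is itself a builtin float type succeed; everything
// else yields a null type.
//   FloatType f16 = FloatType::getF16(ctx);
//   FloatType f32 = f16.scaleElementBitwidth(2); // f16 x2 -> f32.
//   FloatType f64 = f16.scaleElementBitwidth(4); // f16 x4 -> f64.
//   assert(!f16.scaleElementBitwidth(3));        // No f48: null type.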
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

unsigned FloatType::getFPMantissaWidth() {
  return APFloat::semanticsPrecision(getFloatSemantics());
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
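/// For example (illustrative, assuming the indices refer to positions in the
/// original type lists), inserting `f64` at argument index 0 into
/// `(i32) -> (f32)` would produce `(f64, i32) -> (f32)`.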
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::insertTypesInto(
      getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes = function_interface_impl::insertTypesInto(
      getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::filterTypesOut(
      getInputs(), argIndices, argStorage);
  TypeRange newResultTypes = function_interface_impl::filterTypesOut(
      getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 unsigned numScalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

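// Usage sketch (illustrative, assuming a valid `MLIRContext *ctx`): scaling
// applies to the element type only; the shape and scalable-dims count are
// preserved.
//   auto v4f16 = VectorType::get({4}, FloatType::getF16(ctx));
//   auto v4f32 = v4f16.scaleElementBitwidth(2); // vector<4xf32>.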
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getNumScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }

ArrayRef<int64_t> TensorType::getShape() const {
  return cast<RankedTensorType>().getShape();
}

TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = cast<RankedTensorType>();
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
                               rankedTy.getEncoding());
}

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is valid in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return cast<MemRefType>().getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(cast<MemRefType>());
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
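/// For example (illustrative), originalShape = [1, 4, 1, 8] and
/// reducedShape = [4, 8] yield the mask {0, 2}: dropping dimensions 0 and 2
/// of `originalShape` produces `reducedShape`.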
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

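// For example (illustrative), reducing tensor<1x4x1x8xf32> to
// tensor<4x8xf32> is a valid rank reduction (Success, via the mask {0, 2}),
// while reducing it to tensor<4x4xf32> is not (SizeMismatch, since no mask
// of dropped unit dimensions can produce [4, 4]).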
SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

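// Usage sketch (illustrative, assuming a valid `MLIRContext *ctx`): the
// default space 0 maps to a null attribute, any other value to a 64-bit
// IntegerAttr.
//   assert(!wrapIntegerMemorySpace(0, ctx));           // Default space.
//   Attribute space1 = wrapIntegerMemorySpace(1, ctx); // 1 : i64.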
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceInteger` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1` that means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, ..., dn appear in
/// order, to make indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

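// For example (illustrative), for
//   memref<3x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
// this computes strides [4, 1] and offset 8, while any layout involving mod,
// ceildiv or floordiv makes the extraction fail.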
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

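// For example (illustrative), strides [20, 5, 1] with offset 0 produce
//   (d0, d1, d2) -> (d0 * 20 + d1 * 5 + d2),
// while each dynamic stride or offset introduces a fresh symbol, e.g.
// strides [kDynamicStrideOrOffset, 1] with a dynamic offset produce
//   (d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1).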
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (const auto &en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
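/// For example (illustrative),
///   memref<3x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
/// canonicalizes to memref<3x4xf32>, because the layout is exactly the
/// canonical contiguous strided layout for a 3x4 shape.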
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return `t`.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

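// For example (illustrative), sizes [3, 4, 5] with exprs [d0, d1, d2]
// produce the row-major expression d0 * 20 + d1 * 5 + d2; a dynamic size
// switches the strides of all outer (more-major) dimensions to fresh
// symbols, e.g. sizes [3, -1, 5] produce d0 * s0 + d1 * 5 + d2.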
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (sizes.empty() || llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
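/// For example (illustrative), a rank-2 memref gets the layout
///   (d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2),
/// i.e. a symbolic offset `s0` and one symbolic stride per dimension.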
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}

/// Return the AffineExpr representation of the offset, assuming `memrefType`
/// is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strides;
  AffineExpr offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    assert(false && "expected strided memref");
  return offset;
}

/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
/// `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
  AffineExpr contiguousRowMajor = canonical + offset;
  AffineMap contiguousRowMajorMap =
      AffineMap::inferFromExprList({contiguousRowMajor})[0];
  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}

/// Helper determining if a memref has a static shape and a contiguous
/// row-major layout, while still allowing for an arbitrary offset (any static
/// or dynamic value).
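/// For example (illustrative), memref<4x8xf32> and
/// memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 8 + d1 + 100)>> both satisfy
/// this (the offset may be arbitrary), while the column-major
/// memref<4x8xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>> and the
/// dynamically shaped memref<?x8xf32> do not.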
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;
  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}