//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member needs an out-of-line definition until C++17
// introduces inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

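/// Return this type scaled to `scale * getWidth()` bits, preserving the
/// signedness; a zero scale yields a null type so callers can test the
/// result. For example, scaling i8 by 4 produces i32, and scaling si16 by 2
/// produces si32.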
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

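/// Return this type scaled to a wider standard float type, or a null type if
/// the scaling is unsupported. Only the combinations spelled out below are
/// handled; for example, f16 or bf16 scaled by 2 yields f32, and f32 scaled
/// by 2 yields f64.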
FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
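/// For example, with `totalIndices` = 5 and `indices` = {1, 3}, the callback
/// runs for indices 0, 2, and 4, in that order.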
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
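/// Each index identifies the position in the *original* signature before
/// which the corresponding type is inserted; the indices are assumed to be
/// sorted in ascending order. For example, inserting i64 at argument index 1
/// of `(i32, f32) -> (f32)` produces `(i32, i64, f32) -> (f32)`.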
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  assert(argIndices.size() == argTypes.size());
  assert(resultIndices.size() == resultTypes.size());

  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    const auto *fromIt = newInputTypes.begin();
    for (auto it : llvm::zip(argIndices, argTypes)) {
      const auto *toIt = newInputTypes.begin() + std::get<0>(it);
      newInputTypesBuffer.append(fromIt, toIt);
      newInputTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newInputTypesBuffer.append(fromIt, newInputTypes.end());
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    const auto *fromIt = newResultTypes.begin();
    for (auto it : llvm::zip(resultIndices, resultTypes)) {
      const auto *toIt = newResultTypes.begin() + std::get<0>(it);
      newResultTypesBuffer.append(fromIt, toIt);
      newResultTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newResultTypesBuffer.append(fromIt, newResultTypes.end());
    newResultTypes = newResultTypesBuffer;
  }

  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
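/// Duplicate indices are tolerated. For example, erasing argument index 1
/// from `(i32, i64, f32) -> (f32)` produces `(i32, f32) -> (f32)`.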
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR tool in use";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

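/// Return a copy of this shaped type with the given shape and element type.
/// The memory space (and, for ranked memrefs, the layout) carries over via
/// the builder; for example, cloning memref<4x?xf32> with shape {2, 2} and
/// element type f64 yields memref<2x2xf64>.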
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setShape(shape);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
  }

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

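/// Return the position of the given dynamic dimension among the dynamic
/// dimensions only. For example, in memref<2x?x4x?xf32> the dynamic dimension
/// at index 3 has dynamic-dim index 1, since exactly one dynamic dimension
/// precedes it.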
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
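/// For example, tensor<4xvector<8xf32>> takes 4 * 8 * 32 = 1024 bits, and a
/// complex element counts twice the width of its component type.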
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is valid within a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
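/// For example, with originalShape = [1, 4, 1, 8] and reducedShape = [4, 8],
/// the returned mask is {0, 2}.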
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // If no mask could be computed, the sizes could not be matched.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

// Fallback cases for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, to make indexing into the result intuitive.
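/// For example, applied to `d0 * 4 + d1` with a multiplicative factor of 1,
/// this yields strides [4, 1] and offset 0.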
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

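/// Populate `strides` and `offset` from the layout of `t`, failing if the
/// layout is not compatible with strided semantics. For example, the identity
/// layout of memref<4x8xf32> yields strides [8, 1] and offset 0, while a
/// layout such as (d0, d1)[s0] -> (d0 * s0 + d1 + 2) yields strides [s0, 1]
/// and offset 2.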
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

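/// Build the strided-layout AffineMap for the given `strides` and `offset`,
/// introducing one symbol per dynamic value. For example, strides [?, 1] with
/// offset 0 produce the map (d0, d1)[s0] -> (d0 * s0 + d1); a dynamic offset
/// would bind the first symbol before any stride symbols.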
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain, so just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

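/// Convenience overload that uses one dim expression per size. For example,
/// sizes [3, 4, 5] produce the canonical row-major expression
/// d0 * 20 + d1 * 5 + d2.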
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}

/// Return the AffineExpr representation of the offset, assuming `memrefType`
/// is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strides;
  AffineExpr offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    assert(false && "expected strided memref");
  return offset;
}

/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
/// `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
  AffineExpr contiguousRowMajor = canonical + offset;
  AffineMap contiguousRowMajorMap =
      AffineMap::inferFromExprList({contiguousRowMajor})[0];
  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}

/// Helper determining if a memref has a static shape and a contiguous
/// row-major layout, while still allowing for an arbitrary offset (any static
/// or dynamic value).
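/// For example, memref<2x3xf32> qualifies, as does
/// memref<2x3xf32, affine_map<(d0, d1) -> (d0 * 3 + d1 + 4)>> (offset 4),
/// whereas a column-major layout such as (d0, d1) -> (d1 * 2 + d0) does not.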
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;
  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}