1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 /// Tablegen Type Definitions
31 //===----------------------------------------------------------------------===//
32 
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // BuiltinDialect
38 //===----------------------------------------------------------------------===//
39 
/// Register all tablegen-defined builtin types with the dialect. The
/// GET_TYPEDEF_LIST macro expands, inside the addTypes<...> argument list,
/// to the comma-separated list of generated type classes.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
46 
47 //===----------------------------------------------------------------------===//
48 /// ComplexType
49 //===----------------------------------------------------------------------===//
50 
/// Verify the construction of a complex type: the element type must be a
/// builtin integer or float type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}
58 
59 //===----------------------------------------------------------------------===//
60 // Integer Type
61 //===----------------------------------------------------------------------===//
62 
// A static constexpr member needs an out-of-class definition prior to C++17
// (where inline variables make this unnecessary).
constexpr unsigned IntegerType::kMaxWidth;
65 
66 /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,SignednessSemantics signedness)67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
68                                   unsigned width,
69                                   SignednessSemantics signedness) {
70   if (width > IntegerType::kMaxWidth) {
71     return emitError() << "integer bitwidth is limited to "
72                        << IntegerType::kMaxWidth << " bits";
73   }
74   return success();
75 }
76 
getWidth() const77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78 
/// Return the signedness semantics (signless/signed/unsigned) recorded in
/// the type's uniqued storage.
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}
82 
scaleElementBitwidth(unsigned scale)83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84   if (!scale)
85     return IntegerType();
86   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
92 
getWidth()93 unsigned FloatType::getWidth() {
94   if (isa<Float16Type, BFloat16Type>())
95     return 16;
96   if (isa<Float32Type>())
97     return 32;
98   if (isa<Float64Type>())
99     return 64;
100   if (isa<Float80Type>())
101     return 80;
102   if (isa<Float128Type>())
103     return 128;
104   llvm_unreachable("unexpected float type");
105 }
106 
/// Returns the llvm::APFloat floating-point semantics object corresponding
/// to this float type; any non-float type is a programming error.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}
123 
scaleElementBitwidth(unsigned scale)124 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
125   if (!scale)
126     return FloatType();
127   MLIRContext *ctx = getContext();
128   if (isF16() || isBF16()) {
129     if (scale == 2)
130       return FloatType::getF32(ctx);
131     if (scale == 4)
132       return FloatType::getF64(ctx);
133   }
134   if (isF32())
135     if (scale == 2)
136       return FloatType::getF64(ctx);
137   return FloatType();
138 }
139 
/// Return the number of bits in this type's significand (mantissa), as
/// defined by its APFloat semantics.
unsigned FloatType::getFPMantissaWidth() {
  return APFloat::semanticsPrecision(getFloatSemantics());
}
143 
144 //===----------------------------------------------------------------------===//
145 // FunctionType
146 //===----------------------------------------------------------------------===//
147 
getNumInputs() const148 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
149 
/// Return the function's input types (backed by uniqued storage).
ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}
153 
getNumResults() const154 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
155 
/// Return the function's result types (backed by uniqued storage).
ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}
159 
/// Return a function type with the given inputs/results in the same context.
FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}
163 
/// Returns a new function type with the specified arguments and results
/// inserted.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  // The storage vectors keep the spliced type lists alive until clone()
  // copies them into the new uniqued type.
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::insertTypesInto(
      getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes = function_interface_impl::insertTypesInto(
      getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}
176 
/// Returns a new function type without the specified arguments and results.
/// The bit vectors select, by position, which inputs/results to drop.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  // The storage vectors keep the filtered type lists alive until clone()
  // copies them into the new uniqued type.
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::filterTypesOut(
      getInputs(), argIndices, argStorage);
  TypeRange newResultTypes = function_interface_impl::filterTypesOut(
      getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}
188 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const189 void FunctionType::walkImmediateSubElements(
190     function_ref<void(Attribute)> walkAttrsFn,
191     function_ref<void(Type)> walkTypesFn) const {
192   for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
193     walkTypesFn(type);
194 }
195 
/// Rebuild the function type from replacement sub-element types. The
/// replacement list holds all inputs followed by all results, mirroring the
/// order produced by walkImmediateSubElements.
Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                               ArrayRef<Type> replTypes) const {
  unsigned numInputs = getNumInputs();
  return get(getContext(), replTypes.take_front(numInputs),
             replTypes.drop_front(numInputs));
}
202 
203 //===----------------------------------------------------------------------===//
204 // OpaqueType
205 //===----------------------------------------------------------------------===//
206 
/// Verify the construction of an opaque type: the dialect namespace must be
/// well-formed and, unless the context allows unregistered dialects, the
/// named dialect must actually be loaded.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}
227 
228 //===----------------------------------------------------------------------===//
229 // VectorType
230 //===----------------------------------------------------------------------===//
231 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,unsigned numScalableDims)232 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
233                                  ArrayRef<int64_t> shape, Type elementType,
234                                  unsigned numScalableDims) {
235   if (!isValidElementType(elementType))
236     return emitError()
237            << "vector elements must be int/index/float type but got "
238            << elementType;
239 
240   if (any_of(shape, [](int64_t i) { return i <= 0; }))
241     return emitError()
242            << "vector types must have positive constant sizes but got "
243            << shape;
244 
245   return success();
246 }
247 
scaleElementBitwidth(unsigned scale)248 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
249   if (!scale)
250     return VectorType();
251   if (auto et = getElementType().dyn_cast<IntegerType>())
252     if (auto scaledEt = et.scaleElementBitwidth(scale))
253       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
254   if (auto et = getElementType().dyn_cast<FloatType>())
255     if (auto scaledEt = et.scaleElementBitwidth(scale))
256       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
257   return VectorType();
258 }
259 
/// Visit the vector's only immediate sub-element: its element type. Vector
/// types carry no attribute sub-elements.
void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}
265 
/// Rebuild the vector with a replacement element type, preserving the shape
/// and scalable-dims count.
Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  return get(getShape(), replTypes.front(), getNumScalableDims());
}
270 
/// Clone this vector type with the given element type, and the given shape
/// when provided (otherwise the current shape); the scalable-dims count is
/// always preserved.
VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getNumScalableDims());
}
276 
277 //===----------------------------------------------------------------------===//
278 // TensorType
279 //===----------------------------------------------------------------------===//
280 
/// Return the element type, dispatching to the ranked or unranked variant.
Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}
286 
hasRank() const287 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
288 
/// Return the shape; only valid on ranked tensors (the cast asserts
/// otherwise).
ArrayRef<int64_t> TensorType::getShape() const {
  return cast<RankedTensorType>().getShape();
}
292 
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const293 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
294                                  Type elementType) const {
295   if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
296     if (shape)
297       return RankedTensorType::get(*shape, elementType);
298     return UnrankedTensorType::get(elementType);
299   }
300 
301   auto rankedTy = cast<RankedTensorType>();
302   if (!shape)
303     return RankedTensorType::get(rankedTy.getShape(), elementType,
304                                  rankedTy.getEncoding());
305   return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
306                                rankedTy.getEncoding());
307 }
308 
// Check if "elementType" can be an element type of a tensor. Shared by the
// ranked and unranked tensor verifiers.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}
317 
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}
327 
328 //===----------------------------------------------------------------------===//
329 // RankedTensorType
330 //===----------------------------------------------------------------------===//
331 
/// Verify a ranked tensor type: each dimension must be non-negative or the
/// dynamic-size sentinel (-1); an encoding that implements
/// VerifiableTensorEncoding must accept the shape/element type; and the
/// element type must be tensor-compatible.
LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  // The encoding may be null (no encoding) or may not implement the
  // interface; only verifiable encodings are checked.
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}
344 
/// Visit the tensor's immediate sub-elements: the element type, and the
/// encoding attribute only when one is present.
void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}
352 
/// Rebuild the tensor from replacement sub-elements. Mirrors the walk order:
/// the encoding, when present, is the last (and only) replacement attribute;
/// an empty attribute list means no encoding.
Type RankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  return get(getShape(), replTypes.front(),
             replAttrs.empty() ? Attribute() : replAttrs.back());
}
358 
359 //===----------------------------------------------------------------------===//
360 // UnrankedTensorType
361 //===----------------------------------------------------------------------===//
362 
/// Verify an unranked tensor type: only the element type is constrained.
LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}
368 
/// Visit the only immediate sub-element: the element type.
void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}
374 
/// Rebuild the unranked tensor with the replacement element type.
Type UnrankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  return get(replTypes.front());
}
379 
380 //===----------------------------------------------------------------------===//
381 // BaseMemRefType
382 //===----------------------------------------------------------------------===//
383 
/// Return the element type, dispatching to the ranked or unranked variant.
Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}
389 
hasRank() const390 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
391 
/// Return the shape; only valid on ranked memrefs (the cast asserts
/// otherwise).
ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return cast<MemRefType>().getShape();
}
395 
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const396 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
397                                          Type elementType) const {
398   if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
399     if (!shape)
400       return UnrankedMemRefType::get(elementType, getMemorySpace());
401     MemRefType::Builder builder(*shape, elementType);
402     builder.setMemorySpace(getMemorySpace());
403     return builder;
404   }
405 
406   MemRefType::Builder builder(cast<MemRefType>());
407   if (shape)
408     builder.setShape(*shape);
409   builder.setElementType(elementType);
410   return builder;
411 }
412 
/// Return the memory space attribute, dispatching to the ranked or unranked
/// variant.
Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}
418 
/// Return the memory space as an integer, dispatching to the ranked or
/// unranked variant.
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}
424 
425 //===----------------------------------------------------------------------===//
426 // MemRefType
427 //===----------------------------------------------------------------------===//
428 
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return None if reducedShape cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  // Single forward pass: greedily match each reduced dimension against the
  // next original dimension; anything unmatched must be a unit dim.
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}
462 
463 SliceVerificationResult
isRankReducedType(ShapedType originalType,ShapedType candidateReducedType)464 mlir::isRankReducedType(ShapedType originalType,
465                         ShapedType candidateReducedType) {
466   if (originalType == candidateReducedType)
467     return SliceVerificationResult::Success;
468 
469   ShapedType originalShapedType = originalType.cast<ShapedType>();
470   ShapedType candidateReducedShapedType =
471       candidateReducedType.cast<ShapedType>();
472 
473   // Rank and size logic is valid for all ShapedTypes.
474   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
475   ArrayRef<int64_t> candidateReducedShape =
476       candidateReducedShapedType.getShape();
477   unsigned originalRank = originalShape.size(),
478            candidateReducedRank = candidateReducedShape.size();
479   if (candidateReducedRank > originalRank)
480     return SliceVerificationResult::RankTooLarge;
481 
482   auto optionalUnusedDimsMask =
483       computeRankReductionMask(originalShape, candidateReducedShape);
484 
485   // Sizes cannot be matched in case empty vector is returned.
486   if (!optionalUnusedDimsMask)
487     return SliceVerificationResult::SizeMismatch;
488 
489   if (originalShapedType.getElementType() !=
490       candidateReducedShapedType.getElementType())
491     return SliceVerificationResult::ElemTypeMismatch;
492 
493   return SliceVerificationResult::Success;
494 }
495 
isSupportedMemorySpace(Attribute memorySpace)496 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
497   // Empty attribute is allowed as default memory space.
498   if (!memorySpace)
499     return true;
500 
501   // Supported built-in attributes.
502   if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
503     return true;
504 
505   // Allow custom dialect attributes.
506   if (!isa<BuiltinDialect>(memorySpace.getDialect()))
507     return true;
508 
509   return false;
510 }
511 
/// Wrap a legacy integer memory space into attribute form: space 0 is the
/// default and maps to the null attribute; any other value becomes an i64
/// IntegerAttr.
Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}
519 
/// Canonicalize the memory space: an integer attribute equal to 0 (the
/// default space) becomes the null attribute; anything else — including null
/// and non-integer attributes — passes through unchanged.
Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}
527 
getMemorySpaceAsInt(Attribute memorySpace)528 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
529   if (!memorySpace)
530     return 0;
531 
532   assert(memorySpace.isa<IntegerAttr>() &&
533          "Using `getMemorySpaceInteger` with non-Integer attribute");
534 
535   return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
536 }
537 
/// Set the builder's memory space from a legacy integer space, wrapping it
/// into attribute form (0 becomes the null/default attribute).
MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}
544 
/// Return the memory space as an integer (0 for the default/null space).
unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}
548 
/// Build a MemRefType from a layout-interface attribute. A null layout is
/// replaced by the default multi-dimensional identity map, and the default
/// memory space is canonicalized to the null attribute before uniquing.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
563 
/// Checked variant of get(layout-interface): performs the same layout and
/// memory-space canonicalization, then constructs via Base::getChecked so
/// verification failures are reported through `emitErrorFn` instead of
/// asserting.
MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
579 
/// Build a MemRefType from an AffineMap layout. A null map is replaced by
/// the multi-dimensional identity map; the map is wrapped into an
/// AffineMapAttr, and the default memory space is canonicalized to the null
/// attribute.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
597 
/// Checked variant of get(AffineMap): same defaulting and wrapping, but
/// constructs via Base::getChecked so verification failures are reported
/// through `emitErrorFn` instead of asserting.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
617 
/// Build a MemRefType from an AffineMap layout and a deprecated integer
/// memory space. The integer space is converted to attribute form (0 becomes
/// the null/default attribute).
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
636 
/// Checked variant of get(AffineMap, unsigned): same defaulting and
/// integer-space conversion, but constructs via Base::getChecked so
/// verification failures are reported through `emitErrorFn`.
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
657 
/// Verify construction invariants of a memref type: a valid element type,
/// dimension sizes >= -1 (-1 encodes a dynamic dimension), a layout that
/// verifies against the shape, and a supported memory space attribute.
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1` that means dynamic size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  // A null layout should have been replaced by the identity map in get().
  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
679 
/// Visit the memref's immediate sub-elements: element type, layout (only
/// when non-identity), and memory space.
void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  // The identity layout is implicit and skipped, so the paired
  // replaceImmediateSubElements can infer layout presence from the number of
  // replacement attributes.
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  // NOTE(review): getMemorySpace() may be the null attribute (default
  // space) and is passed to walkAttrsFn unconditionally — confirm callers
  // tolerate a null Attribute here.
  walkAttrsFn(getMemorySpace());
}
688 
/// Rebuild the memref from replacement sub-elements, mirroring the walk
/// order: a single replacement attribute is the memory space (identity
/// layout); two attributes are the layout followed by the memory space.
Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  bool hasLayout = replAttrs.size() > 1;
  return get(getShape(), replTypes[0],
             hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
                       : MemRefLayoutAttrInterface(),
             hasLayout ? replAttrs[1] : replAttrs[0]);
}
697 
698 //===----------------------------------------------------------------------===//
699 // UnrankedMemRefType
700 //===----------------------------------------------------------------------===//
701 
getMemorySpaceAsInt() const702 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
703   return detail::getMemorySpaceAsInt(getMemorySpace());
704 }
705 
706 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,Attribute memorySpace)707 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
708                            Type elementType, Attribute memorySpace) {
709   if (!BaseMemRefType::isValidElementType(elementType))
710     return emitError() << "invalid memref element type";
711 
712   if (!isSupportedMemorySpace(memorySpace))
713     return emitError() << "unsupported memory space Attribute";
714 
715   return success();
716 }
717 
718 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719 // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)720 static void extractStridesFromTerm(AffineExpr e,
721                                    AffineExpr multiplicativeFactor,
722                                    MutableArrayRef<AffineExpr> strides,
723                                    AffineExpr &offset) {
724   if (auto dim = e.dyn_cast<AffineDimExpr>())
725     strides[dim.getPosition()] =
726         strides[dim.getPosition()] + multiplicativeFactor;
727   else
728     offset = offset + e * multiplicativeFactor;
729 }
730 
/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
/// `multiplicativeFactor` carries the product of factors accumulated while
/// recursing into multiplications; it scales every term found below `e`.
/// Returns failure when `e` contains an operation (div/mod) that cannot be
/// part of a strided form.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    // Terminal dim/symbol/constant: fold it straight into strides/offset.
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  // Divisions and modulos have no strided interpretation.
  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    // Simple `d_i * expr` case: `expr` (scaled by the incoming factor) adds
    // to the stride of dimension i.
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Addition distributes: recurse on both sides with the same factor and
    // succeed only if both sides are strided.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
778 
/// Compute the strides and offset of memref type `t` as AffineExprs, writing
/// them into `strides` and `offset`. Returns failure when the layout is not
/// expressible in strided form; on failure `strides` is cleared and `offset`
/// is reset to a null AffineExpr.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  // Only single-result layouts (or the identity map) can be strided.
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  // Seed offset and all strides with 0; extractStrides accumulates into them.
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    // Materialize the canonical row-major expression for this shape and
    // extract the strides from it; extraction cannot fail on this form.
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    // Reset outputs to a well-defined "no result" state.
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}
836 
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)837 LogicalResult mlir::getStridesAndOffset(MemRefType t,
838                                         SmallVectorImpl<int64_t> &strides,
839                                         int64_t &offset) {
840   AffineExpr offsetExpr;
841   SmallVector<AffineExpr, 4> strideExprs;
842   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
843     return failure();
844   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
845     offset = cst.getValue();
846   else
847     offset = ShapedType::kDynamicStrideOrOffset;
848   for (auto e : strideExprs) {
849     if (auto c = e.dyn_cast<AffineConstantExpr>())
850       strides.push_back(c.getValue());
851     else
852       strides.push_back(ShapedType::kDynamicStrideOrOffset);
853   }
854   return success();
855 }
856 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const857 void UnrankedMemRefType::walkImmediateSubElements(
858     function_ref<void(Attribute)> walkAttrsFn,
859     function_ref<void(Type)> walkTypesFn) const {
860   walkTypesFn(getElementType());
861   walkAttrsFn(getMemorySpace());
862 }
863 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const864 Type UnrankedMemRefType::replaceImmediateSubElements(
865     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
866   return get(replTypes.front(), replAttrs.front());
867 }
868 
869 //===----------------------------------------------------------------------===//
870 /// TupleType
871 //===----------------------------------------------------------------------===//
872 
873 /// Return the elements types for this tuple.
getTypes() const874 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
875 
876 /// Accumulate the types contained in this tuple and tuples nested within it.
877 /// Note that this only flattens nested tuples, not any other container type,
878 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
879 /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)880 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
881   for (Type type : getTypes()) {
882     if (auto nestedTuple = type.dyn_cast<TupleType>())
883       nestedTuple.getFlattenedTypes(types);
884     else
885       types.push_back(type);
886   }
887 }
888 
889 /// Return the number of element types.
size() const890 size_t TupleType::size() const { return getImpl()->size(); }
891 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const892 void TupleType::walkImmediateSubElements(
893     function_ref<void(Attribute)> walkAttrsFn,
894     function_ref<void(Type)> walkTypesFn) const {
895   for (Type type : getTypes())
896     walkTypesFn(type);
897 }
898 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const899 Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
900                                             ArrayRef<Type> replTypes) const {
901   return get(getContext(), replTypes);
902 }
903 
904 //===----------------------------------------------------------------------===//
905 // Type Utilities
906 //===----------------------------------------------------------------------===//
907 
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)908 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
909                                            int64_t offset,
910                                            MLIRContext *context) {
911   AffineExpr expr;
912   unsigned nSymbols = 0;
913 
914   // AffineExpr for offset.
915   // Static case.
916   if (offset != MemRefType::getDynamicStrideOrOffset()) {
917     auto cst = getAffineConstantExpr(offset, context);
918     expr = cst;
919   } else {
920     // Dynamic case, new symbol for the offset.
921     auto sym = getAffineSymbolExpr(nSymbols++, context);
922     expr = sym;
923   }
924 
925   // AffineExpr for strides.
926   for (const auto &en : llvm::enumerate(strides)) {
927     auto dim = en.index();
928     auto stride = en.value();
929     assert(stride != 0 && "Invalid stride specification");
930     auto d = getAffineDimExpr(dim, context);
931     AffineExpr mult;
932     // Static case.
933     if (stride != MemRefType::getDynamicStrideOrOffset())
934       mult = getAffineConstantExpr(stride, context);
935     else
936       // Dynamic case, new symbol for each new stride.
937       mult = getAffineSymbolExpr(nSymbols++, context);
938     expr = expr + d * mult;
939   }
940 
941   return AffineMap::get(strides.size(), nSymbols, expr);
942 }
943 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    // A constant-0 result is the canonical 0-D layout: drop it entirely.
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for empty shape that still have an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
  // offset needs to remain, just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}
986 
/// Build the canonical strided (row-major) layout expression for `sizes`
/// using the caller-supplied per-dimension expressions `exprs`:
/// sum_i(exprs[i] * stride_i), where strides are running products of sizes
/// from innermost to outermost. Once a dynamic size is encountered, every
/// more-major stride becomes a fresh symbol.
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (sizes.empty() || llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  // Walk dimensions innermost-first, accumulating the running product of
  // static sizes as the stride for the next (more major) dimension.
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    // After a dynamic size is seen, strides are no longer statically known:
    // emit a fresh symbol instead of the running product.
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}
1021 
1022 /// Return a version of `t` with a layout that has all dynamic offset and
1023 /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)1024 MemRefType mlir::eraseStridedLayout(MemRefType t) {
1025   auto val = ShapedType::kDynamicStrideOrOffset;
1026   return MemRefType::Builder(t).setLayout(
1027       AffineMapAttr::get(makeStridedLinearLayoutMap(
1028           SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
1029 }
1030 
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)1031 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
1032                                                 MLIRContext *context) {
1033   SmallVector<AffineExpr, 4> exprs;
1034   exprs.reserve(sizes.size());
1035   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
1036     exprs.push_back(getAffineDimExpr(dim, context));
1037   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
1038 }
1039 
1040 /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)1041 bool mlir::isStrided(MemRefType t) {
1042   int64_t offset;
1043   SmallVector<int64_t, 4> strides;
1044   auto res = getStridesAndOffset(t, strides, offset);
1045   return succeeded(res);
1046 }
1047 
1048 /// Return the layout map in strided linear layout AffineMap form.
1049 /// Return null if the layout is not compatible with a strided layout.
getStridedLinearLayoutMap(MemRefType t)1050 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
1051   int64_t offset;
1052   SmallVector<int64_t, 4> strides;
1053   if (failed(getStridesAndOffset(t, strides, offset)))
1054     return AffineMap();
1055   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
1056 }
1057 
1058 /// Return the AffineExpr representation of the offset, assuming `memRefType`
1059 /// is a strided memref.
getOffsetExpr(MemRefType memrefType)1060 static AffineExpr getOffsetExpr(MemRefType memrefType) {
1061   SmallVector<AffineExpr> strides;
1062   AffineExpr offset;
1063   if (failed(getStridesAndOffset(memrefType, strides, offset)))
1064     assert(false && "expected strided memref");
1065   return offset;
1066 }
1067 
1068 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1069 /// `offset` AffineExpr.
makeContiguousRowMajorMemRefType(MLIRContext * context,ArrayRef<int64_t> shape,Type elementType,AffineExpr offset)1070 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1071                                                    ArrayRef<int64_t> shape,
1072                                                    Type elementType,
1073                                                    AffineExpr offset) {
1074   AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1075   AffineExpr contiguousRowMajor = canonical + offset;
1076   AffineMap contiguousRowMajorMap =
1077       AffineMap::inferFromExprList({contiguousRowMajor})[0];
1078   return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1079 }
1080 
1081 /// Helper determining if a memref is static-shape and contiguous-row-major
1082 /// layout, while still allowing for an arbitrary offset (any static or
1083 /// dynamic value).
isStaticShapeAndContiguousRowMajor(MemRefType memrefType)1084 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1085   if (!memrefType.hasStaticShape())
1086     return false;
1087   AffineExpr offset = getOffsetExpr(memrefType);
1088   MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1089       memrefType.getContext(), memrefType.getShape(),
1090       memrefType.getElementType(), offset);
1091   return canonicalizeStridedLayout(memrefType) ==
1092          canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1093 }
1094