1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/TypeSwitch.h"
25
26 using namespace mlir;
27 using namespace mlir::detail;
28
29 //===----------------------------------------------------------------------===//
30 /// Tablegen Type Definitions
31 //===----------------------------------------------------------------------===//
32
33 #define GET_TYPEDEF_CLASSES
34 #include "mlir/IR/BuiltinTypes.cpp.inc"
35
36 //===----------------------------------------------------------------------===//
37 // BuiltinDialect
38 //===----------------------------------------------------------------------===//
39
/// Register every builtin type class with the builtin dialect.
void BuiltinDialect::registerTypes() {
  // The type list is produced by tablegen: defining GET_TYPEDEF_LIST before
  // including the generated .inc file makes it expand to the type classes,
  // which become the template arguments of addTypes.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
46
47 //===----------------------------------------------------------------------===//
48 /// ComplexType
49 //===----------------------------------------------------------------------===//
50
51 /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)52 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
53 Type elementType) {
54 if (!elementType.isIntOrFloat())
55 return emitError() << "invalid element type for complex";
56 return success();
57 }
58
59 //===----------------------------------------------------------------------===//
60 // Integer Type
61 //===----------------------------------------------------------------------===//
62
// Out-of-line definition of the static constexpr member `kMaxWidth`. Before
// C++17 an odr-used static constexpr data member needs such a namespace-scope
// definition; since C++17 these members are implicitly inline and this line
// is a redundant (but harmless) redeclaration.
constexpr unsigned IntegerType::kMaxWidth;
65
66 /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,SignednessSemantics signedness)67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
68 unsigned width,
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() << "integer bitwidth is limited to "
72 << IntegerType::kMaxWidth << " bits";
73 }
74 return success();
75 }
76
getWidth() const77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78
getSignedness() const79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80 return getImpl()->signedness;
81 }
82
scaleElementBitwidth(unsigned scale)83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84 if (!scale)
85 return IntegerType();
86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
87 }
88
89 //===----------------------------------------------------------------------===//
90 // Float Type
91 //===----------------------------------------------------------------------===//
92
getWidth()93 unsigned FloatType::getWidth() {
94 if (isa<Float16Type, BFloat16Type>())
95 return 16;
96 if (isa<Float32Type>())
97 return 32;
98 if (isa<Float64Type>())
99 return 64;
100 if (isa<Float80Type>())
101 return 80;
102 if (isa<Float128Type>())
103 return 128;
104 llvm_unreachable("unexpected float type");
105 }
106
107 /// Returns the floating semantics for the given type.
getFloatSemantics()108 const llvm::fltSemantics &FloatType::getFloatSemantics() {
109 if (isa<BFloat16Type>())
110 return APFloat::BFloat();
111 if (isa<Float16Type>())
112 return APFloat::IEEEhalf();
113 if (isa<Float32Type>())
114 return APFloat::IEEEsingle();
115 if (isa<Float64Type>())
116 return APFloat::IEEEdouble();
117 if (isa<Float80Type>())
118 return APFloat::x87DoubleExtended();
119 if (isa<Float128Type>())
120 return APFloat::IEEEquad();
121 llvm_unreachable("non-floating point type used");
122 }
123
scaleElementBitwidth(unsigned scale)124 FloatType FloatType::scaleElementBitwidth(unsigned scale) {
125 if (!scale)
126 return FloatType();
127 MLIRContext *ctx = getContext();
128 if (isF16() || isBF16()) {
129 if (scale == 2)
130 return FloatType::getF32(ctx);
131 if (scale == 4)
132 return FloatType::getF64(ctx);
133 }
134 if (isF32())
135 if (scale == 2)
136 return FloatType::getF64(ctx);
137 return FloatType();
138 }
139
getFPMantissaWidth()140 unsigned FloatType::getFPMantissaWidth() {
141 return APFloat::semanticsPrecision(getFloatSemantics());
142 }
143
144 //===----------------------------------------------------------------------===//
145 // FunctionType
146 //===----------------------------------------------------------------------===//
147
getNumInputs() const148 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
149
getInputs() const150 ArrayRef<Type> FunctionType::getInputs() const {
151 return getImpl()->getInputs();
152 }
153
getNumResults() const154 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
155
getResults() const156 ArrayRef<Type> FunctionType::getResults() const {
157 return getImpl()->getResults();
158 }
159
clone(TypeRange inputs,TypeRange results) const160 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
161 return get(getContext(), inputs, results);
162 }
163
164 /// Returns a new function type with the specified arguments and results
165 /// inserted.
getWithArgsAndResults(ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<unsigned> resultIndices,TypeRange resultTypes)166 FunctionType FunctionType::getWithArgsAndResults(
167 ArrayRef<unsigned> argIndices, TypeRange argTypes,
168 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
169 SmallVector<Type> argStorage, resultStorage;
170 TypeRange newArgTypes = function_interface_impl::insertTypesInto(
171 getInputs(), argIndices, argTypes, argStorage);
172 TypeRange newResultTypes = function_interface_impl::insertTypesInto(
173 getResults(), resultIndices, resultTypes, resultStorage);
174 return clone(newArgTypes, newResultTypes);
175 }
176
177 /// Returns a new function type without the specified arguments and results.
178 FunctionType
getWithoutArgsAndResults(const BitVector & argIndices,const BitVector & resultIndices)179 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
180 const BitVector &resultIndices) {
181 SmallVector<Type> argStorage, resultStorage;
182 TypeRange newArgTypes = function_interface_impl::filterTypesOut(
183 getInputs(), argIndices, argStorage);
184 TypeRange newResultTypes = function_interface_impl::filterTypesOut(
185 getResults(), resultIndices, resultStorage);
186 return clone(newArgTypes, newResultTypes);
187 }
188
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const189 void FunctionType::walkImmediateSubElements(
190 function_ref<void(Attribute)> walkAttrsFn,
191 function_ref<void(Type)> walkTypesFn) const {
192 for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
193 walkTypesFn(type);
194 }
195
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const196 Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
197 ArrayRef<Type> replTypes) const {
198 unsigned numInputs = getNumInputs();
199 return get(getContext(), replTypes.take_front(numInputs),
200 replTypes.drop_front(numInputs));
201 }
202
203 //===----------------------------------------------------------------------===//
204 // OpaqueType
205 //===----------------------------------------------------------------------===//
206
207 /// Verify the construction of an opaque type.
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef typeData)208 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
209 StringAttr dialect, StringRef typeData) {
210 if (!Dialect::isValidNamespace(dialect.strref()))
211 return emitError() << "invalid dialect namespace '" << dialect << "'";
212
213 // Check that the dialect is actually registered.
214 MLIRContext *context = dialect.getContext();
215 if (!context->allowsUnregisteredDialects() &&
216 !context->getLoadedDialect(dialect.strref())) {
217 return emitError()
218 << "`!" << dialect << "<\"" << typeData << "\">"
219 << "` type created with unregistered dialect. If this is "
220 "intended, please call allowUnregisteredDialects() on the "
221 "MLIRContext, or use -allow-unregistered-dialect with "
222 "the MLIR opt tool used";
223 }
224
225 return success();
226 }
227
228 //===----------------------------------------------------------------------===//
229 // VectorType
230 //===----------------------------------------------------------------------===//
231
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,unsigned numScalableDims)232 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
233 ArrayRef<int64_t> shape, Type elementType,
234 unsigned numScalableDims) {
235 if (!isValidElementType(elementType))
236 return emitError()
237 << "vector elements must be int/index/float type but got "
238 << elementType;
239
240 if (any_of(shape, [](int64_t i) { return i <= 0; }))
241 return emitError()
242 << "vector types must have positive constant sizes but got "
243 << shape;
244
245 return success();
246 }
247
scaleElementBitwidth(unsigned scale)248 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
249 if (!scale)
250 return VectorType();
251 if (auto et = getElementType().dyn_cast<IntegerType>())
252 if (auto scaledEt = et.scaleElementBitwidth(scale))
253 return VectorType::get(getShape(), scaledEt, getNumScalableDims());
254 if (auto et = getElementType().dyn_cast<FloatType>())
255 if (auto scaledEt = et.scaleElementBitwidth(scale))
256 return VectorType::get(getShape(), scaledEt, getNumScalableDims());
257 return VectorType();
258 }
259
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const260 void VectorType::walkImmediateSubElements(
261 function_ref<void(Attribute)> walkAttrsFn,
262 function_ref<void(Type)> walkTypesFn) const {
263 walkTypesFn(getElementType());
264 }
265
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const266 Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
267 ArrayRef<Type> replTypes) const {
268 return get(getShape(), replTypes.front(), getNumScalableDims());
269 }
270
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const271 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
272 Type elementType) const {
273 return VectorType::get(shape.value_or(getShape()), elementType,
274 getNumScalableDims());
275 }
276
277 //===----------------------------------------------------------------------===//
278 // TensorType
279 //===----------------------------------------------------------------------===//
280
getElementType() const281 Type TensorType::getElementType() const {
282 return llvm::TypeSwitch<TensorType, Type>(*this)
283 .Case<RankedTensorType, UnrankedTensorType>(
284 [](auto type) { return type.getElementType(); });
285 }
286
hasRank() const287 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
288
getShape() const289 ArrayRef<int64_t> TensorType::getShape() const {
290 return cast<RankedTensorType>().getShape();
291 }
292
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const293 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
294 Type elementType) const {
295 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
296 if (shape)
297 return RankedTensorType::get(*shape, elementType);
298 return UnrankedTensorType::get(elementType);
299 }
300
301 auto rankedTy = cast<RankedTensorType>();
302 if (!shape)
303 return RankedTensorType::get(rankedTy.getShape(), elementType,
304 rankedTy.getEncoding());
305 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
306 rankedTy.getEncoding());
307 }
308
309 // Check if "elementType" can be an element type of a tensor.
310 static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic ()> emitError,Type elementType)311 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
312 Type elementType) {
313 if (!TensorType::isValidElementType(elementType))
314 return emitError() << "invalid tensor element type: " << elementType;
315 return success();
316 }
317
318 /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)319 bool TensorType::isValidElementType(Type type) {
320 // Note: Non standard/builtin types are allowed to exist within tensor
321 // types. Dialects are expected to verify that tensor types have a valid
322 // element type within that dialect.
323 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
324 IndexType>() ||
325 !llvm::isa<BuiltinDialect>(type.getDialect());
326 }
327
328 //===----------------------------------------------------------------------===//
329 // RankedTensorType
330 //===----------------------------------------------------------------------===//
331
332 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,Attribute encoding)333 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
334 ArrayRef<int64_t> shape, Type elementType,
335 Attribute encoding) {
336 for (int64_t s : shape)
337 if (s < -1)
338 return emitError() << "invalid tensor dimension size";
339 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
340 if (failed(v.verifyEncoding(shape, elementType, emitError)))
341 return failure();
342 return checkTensorElementType(emitError, elementType);
343 }
344
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const345 void RankedTensorType::walkImmediateSubElements(
346 function_ref<void(Attribute)> walkAttrsFn,
347 function_ref<void(Type)> walkTypesFn) const {
348 walkTypesFn(getElementType());
349 if (Attribute encoding = getEncoding())
350 walkAttrsFn(encoding);
351 }
352
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const353 Type RankedTensorType::replaceImmediateSubElements(
354 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
355 return get(getShape(), replTypes.front(),
356 replAttrs.empty() ? Attribute() : replAttrs.back());
357 }
358
359 //===----------------------------------------------------------------------===//
360 // UnrankedTensorType
361 //===----------------------------------------------------------------------===//
362
363 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)364 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
365 Type elementType) {
366 return checkTensorElementType(emitError, elementType);
367 }
368
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const369 void UnrankedTensorType::walkImmediateSubElements(
370 function_ref<void(Attribute)> walkAttrsFn,
371 function_ref<void(Type)> walkTypesFn) const {
372 walkTypesFn(getElementType());
373 }
374
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const375 Type UnrankedTensorType::replaceImmediateSubElements(
376 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
377 return get(replTypes.front());
378 }
379
380 //===----------------------------------------------------------------------===//
381 // BaseMemRefType
382 //===----------------------------------------------------------------------===//
383
getElementType() const384 Type BaseMemRefType::getElementType() const {
385 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
386 .Case<MemRefType, UnrankedMemRefType>(
387 [](auto type) { return type.getElementType(); });
388 }
389
hasRank() const390 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
391
getShape() const392 ArrayRef<int64_t> BaseMemRefType::getShape() const {
393 return cast<MemRefType>().getShape();
394 }
395
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const396 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
397 Type elementType) const {
398 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
399 if (!shape)
400 return UnrankedMemRefType::get(elementType, getMemorySpace());
401 MemRefType::Builder builder(*shape, elementType);
402 builder.setMemorySpace(getMemorySpace());
403 return builder;
404 }
405
406 MemRefType::Builder builder(cast<MemRefType>());
407 if (shape)
408 builder.setShape(*shape);
409 builder.setElementType(elementType);
410 return builder;
411 }
412
getMemorySpace() const413 Attribute BaseMemRefType::getMemorySpace() const {
414 if (auto rankedMemRefTy = dyn_cast<MemRefType>())
415 return rankedMemRefTy.getMemorySpace();
416 return cast<UnrankedMemRefType>().getMemorySpace();
417 }
418
getMemorySpaceAsInt() const419 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
420 if (auto rankedMemRefTy = dyn_cast<MemRefType>())
421 return rankedMemRefTy.getMemorySpaceAsInt();
422 return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
423 }
424
425 //===----------------------------------------------------------------------===//
426 // MemRefType
427 //===----------------------------------------------------------------------===//
428
429 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
430 /// `originalShape` with some `1` entries erased, return the set of indices
431 /// that specifies which of the entries of `originalShape` are dropped to obtain
432 /// `reducedShape`. The returned mask can be applied as a projection to
433 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
434 /// which dimensions must be kept when e.g. compute MemRef strides under
435 /// rank-reducing operations. Return None if reducedShape cannot be obtained
436 /// by dropping only `1` entries in `originalShape`.
437 llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> reducedShape)438 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
439 ArrayRef<int64_t> reducedShape) {
440 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
441 llvm::SmallDenseSet<unsigned> unusedDims;
442 unsigned reducedIdx = 0;
443 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
444 // Greedily insert `originalIdx` if match.
445 if (reducedIdx < reducedRank &&
446 originalShape[originalIdx] == reducedShape[reducedIdx]) {
447 reducedIdx++;
448 continue;
449 }
450
451 unusedDims.insert(originalIdx);
452 // If no match on `originalIdx`, the `originalShape` at this dimension
453 // must be 1, otherwise we bail.
454 if (originalShape[originalIdx] != 1)
455 return llvm::None;
456 }
457 // The whole reducedShape must be scanned, otherwise we bail.
458 if (reducedIdx != reducedRank)
459 return llvm::None;
460 return unusedDims;
461 }
462
463 SliceVerificationResult
isRankReducedType(ShapedType originalType,ShapedType candidateReducedType)464 mlir::isRankReducedType(ShapedType originalType,
465 ShapedType candidateReducedType) {
466 if (originalType == candidateReducedType)
467 return SliceVerificationResult::Success;
468
469 ShapedType originalShapedType = originalType.cast<ShapedType>();
470 ShapedType candidateReducedShapedType =
471 candidateReducedType.cast<ShapedType>();
472
473 // Rank and size logic is valid for all ShapedTypes.
474 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
475 ArrayRef<int64_t> candidateReducedShape =
476 candidateReducedShapedType.getShape();
477 unsigned originalRank = originalShape.size(),
478 candidateReducedRank = candidateReducedShape.size();
479 if (candidateReducedRank > originalRank)
480 return SliceVerificationResult::RankTooLarge;
481
482 auto optionalUnusedDimsMask =
483 computeRankReductionMask(originalShape, candidateReducedShape);
484
485 // Sizes cannot be matched in case empty vector is returned.
486 if (!optionalUnusedDimsMask)
487 return SliceVerificationResult::SizeMismatch;
488
489 if (originalShapedType.getElementType() !=
490 candidateReducedShapedType.getElementType())
491 return SliceVerificationResult::ElemTypeMismatch;
492
493 return SliceVerificationResult::Success;
494 }
495
isSupportedMemorySpace(Attribute memorySpace)496 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
497 // Empty attribute is allowed as default memory space.
498 if (!memorySpace)
499 return true;
500
501 // Supported built-in attributes.
502 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
503 return true;
504
505 // Allow custom dialect attributes.
506 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
507 return true;
508
509 return false;
510 }
511
wrapIntegerMemorySpace(unsigned memorySpace,MLIRContext * ctx)512 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
513 MLIRContext *ctx) {
514 if (memorySpace == 0)
515 return nullptr;
516
517 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
518 }
519
skipDefaultMemorySpace(Attribute memorySpace)520 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
521 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
522 if (intMemorySpace && intMemorySpace.getValue() == 0)
523 return nullptr;
524
525 return memorySpace;
526 }
527
getMemorySpaceAsInt(Attribute memorySpace)528 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
529 if (!memorySpace)
530 return 0;
531
532 assert(memorySpace.isa<IntegerAttr>() &&
533 "Using `getMemorySpaceInteger` with non-Integer attribute");
534
535 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
536 }
537
538 MemRefType::Builder &
setMemorySpace(unsigned newMemorySpace)539 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
540 memorySpace =
541 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
542 return *this;
543 }
544
getMemorySpaceAsInt() const545 unsigned MemRefType::getMemorySpaceAsInt() const {
546 return detail::getMemorySpaceAsInt(getMemorySpace());
547 }
548
get(ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)549 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
550 MemRefLayoutAttrInterface layout,
551 Attribute memorySpace) {
552 // Use default layout for empty attribute.
553 if (!layout)
554 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
555 shape.size(), elementType.getContext()));
556
557 // Drop default memory space value and replace it with empty attribute.
558 memorySpace = skipDefaultMemorySpace(memorySpace);
559
560 return Base::get(elementType.getContext(), shape, elementType, layout,
561 memorySpace);
562 }
563
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)564 MemRefType MemRefType::getChecked(
565 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
566 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
567
568 // Use default layout for empty attribute.
569 if (!layout)
570 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
571 shape.size(), elementType.getContext()));
572
573 // Drop default memory space value and replace it with empty attribute.
574 memorySpace = skipDefaultMemorySpace(memorySpace);
575
576 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
577 elementType, layout, memorySpace);
578 }
579
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)580 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
581 AffineMap map, Attribute memorySpace) {
582
583 // Use default layout for empty map.
584 if (!map)
585 map = AffineMap::getMultiDimIdentityMap(shape.size(),
586 elementType.getContext());
587
588 // Wrap AffineMap into Attribute.
589 Attribute layout = AffineMapAttr::get(map);
590
591 // Drop default memory space value and replace it with empty attribute.
592 memorySpace = skipDefaultMemorySpace(memorySpace);
593
594 return Base::get(elementType.getContext(), shape, elementType, layout,
595 memorySpace);
596 }
597
598 MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)599 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
600 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
601 Attribute memorySpace) {
602
603 // Use default layout for empty map.
604 if (!map)
605 map = AffineMap::getMultiDimIdentityMap(shape.size(),
606 elementType.getContext());
607
608 // Wrap AffineMap into Attribute.
609 Attribute layout = AffineMapAttr::get(map);
610
611 // Drop default memory space value and replace it with empty attribute.
612 memorySpace = skipDefaultMemorySpace(memorySpace);
613
614 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
615 elementType, layout, memorySpace);
616 }
617
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)618 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
619 AffineMap map, unsigned memorySpaceInd) {
620
621 // Use default layout for empty map.
622 if (!map)
623 map = AffineMap::getMultiDimIdentityMap(shape.size(),
624 elementType.getContext());
625
626 // Wrap AffineMap into Attribute.
627 Attribute layout = AffineMapAttr::get(map);
628
629 // Convert deprecated integer-like memory space to Attribute.
630 Attribute memorySpace =
631 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
632
633 return Base::get(elementType.getContext(), shape, elementType, layout,
634 memorySpace);
635 }
636
637 MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)638 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
639 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
640 unsigned memorySpaceInd) {
641
642 // Use default layout for empty map.
643 if (!map)
644 map = AffineMap::getMultiDimIdentityMap(shape.size(),
645 elementType.getContext());
646
647 // Wrap AffineMap into Attribute.
648 Attribute layout = AffineMapAttr::get(map);
649
650 // Convert deprecated integer-like memory space to Attribute.
651 Attribute memorySpace =
652 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
653
654 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
655 elementType, layout, memorySpace);
656 }
657
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)658 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
659 ArrayRef<int64_t> shape, Type elementType,
660 MemRefLayoutAttrInterface layout,
661 Attribute memorySpace) {
662 if (!BaseMemRefType::isValidElementType(elementType))
663 return emitError() << "invalid memref element type";
664
665 // Negative sizes are not allowed except for `-1` that means dynamic size.
666 for (int64_t s : shape)
667 if (s < -1)
668 return emitError() << "invalid memref size";
669
670 assert(layout && "missing layout specification");
671 if (failed(layout.verifyLayout(shape, emitError)))
672 return failure();
673
674 if (!isSupportedMemorySpace(memorySpace))
675 return emitError() << "unsupported memory space Attribute";
676
677 return success();
678 }
679
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const680 void MemRefType::walkImmediateSubElements(
681 function_ref<void(Attribute)> walkAttrsFn,
682 function_ref<void(Type)> walkTypesFn) const {
683 walkTypesFn(getElementType());
684 if (!getLayout().isIdentity())
685 walkAttrsFn(getLayout());
686 walkAttrsFn(getMemorySpace());
687 }
688
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const689 Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
690 ArrayRef<Type> replTypes) const {
691 bool hasLayout = replAttrs.size() > 1;
692 return get(getShape(), replTypes[0],
693 hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
694 : MemRefLayoutAttrInterface(),
695 hasLayout ? replAttrs[1] : replAttrs[0]);
696 }
697
698 //===----------------------------------------------------------------------===//
699 // UnrankedMemRefType
700 //===----------------------------------------------------------------------===//
701
getMemorySpaceAsInt() const702 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
703 return detail::getMemorySpaceAsInt(getMemorySpace());
704 }
705
706 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,Attribute memorySpace)707 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
708 Type elementType, Attribute memorySpace) {
709 if (!BaseMemRefType::isValidElementType(elementType))
710 return emitError() << "invalid memref element type";
711
712 if (!isSupportedMemorySpace(memorySpace))
713 return emitError() << "unsupported memory space Attribute";
714
715 return success();
716 }
717
718 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
719 // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)720 static void extractStridesFromTerm(AffineExpr e,
721 AffineExpr multiplicativeFactor,
722 MutableArrayRef<AffineExpr> strides,
723 AffineExpr &offset) {
724 if (auto dim = e.dyn_cast<AffineDimExpr>())
725 strides[dim.getPosition()] =
726 strides[dim.getPosition()] + multiplicativeFactor;
727 else
728 offset = offset + e * multiplicativeFactor;
729 }
730
731 /// Takes a single AffineExpr `e` and populates the `strides` array with the
732 /// strides expressions for each dim position.
733 /// The convention is that the strides for dimensions d0, .. dn appear in
734 /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)735 static LogicalResult extractStrides(AffineExpr e,
736 AffineExpr multiplicativeFactor,
737 MutableArrayRef<AffineExpr> strides,
738 AffineExpr &offset) {
739 auto bin = e.dyn_cast<AffineBinaryOpExpr>();
740 if (!bin) {
741 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
742 return success();
743 }
744
745 if (bin.getKind() == AffineExprKind::CeilDiv ||
746 bin.getKind() == AffineExprKind::FloorDiv ||
747 bin.getKind() == AffineExprKind::Mod)
748 return failure();
749
750 if (bin.getKind() == AffineExprKind::Mul) {
751 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
752 if (dim) {
753 strides[dim.getPosition()] =
754 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
755 return success();
756 }
757 // LHS and RHS may both contain complex expressions of dims. Try one path
758 // and if it fails try the other. This is guaranteed to succeed because
759 // only one path may have a `dim`, otherwise this is not an AffineExpr in
760 // the first place.
761 if (bin.getLHS().isSymbolicOrConstant())
762 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
763 strides, offset);
764 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
765 strides, offset);
766 }
767
768 if (bin.getKind() == AffineExprKind::Add) {
769 auto res1 =
770 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
771 auto res2 =
772 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
773 return success(succeeded(res1) && succeeded(res2));
774 }
775
776 llvm_unreachable("unexpected binary operation");
777 }
778
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)779 LogicalResult mlir::getStridesAndOffset(MemRefType t,
780 SmallVectorImpl<AffineExpr> &strides,
781 AffineExpr &offset) {
782 AffineMap m = t.getLayout().getAffineMap();
783
784 if (m.getNumResults() != 1 && !m.isIdentity())
785 return failure();
786
787 auto zero = getAffineConstantExpr(0, t.getContext());
788 auto one = getAffineConstantExpr(1, t.getContext());
789 offset = zero;
790 strides.assign(t.getRank(), zero);
791
792 // Canonical case for empty map.
793 if (m.isIdentity()) {
794 // 0-D corner case, offset is already 0.
795 if (t.getRank() == 0)
796 return success();
797 auto stridedExpr =
798 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
799 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
800 return success();
801 assert(false && "unexpected failure: extract strides in canonical layout");
802 }
803
804 // Non-canonical case requires more work.
805 auto stridedExpr =
806 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
807 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
808 offset = AffineExpr();
809 strides.clear();
810 return failure();
811 }
812
813 // Simplify results to allow folding to constants and simple checks.
814 unsigned numDims = m.getNumDims();
815 unsigned numSymbols = m.getNumSymbols();
816 offset = simplifyAffineExpr(offset, numDims, numSymbols);
817 for (auto &stride : strides)
818 stride = simplifyAffineExpr(stride, numDims, numSymbols);
819
820 /// In practice, a strided memref must be internally non-aliasing. Test
821 /// against 0 as a proxy.
822 /// TODO: static cases can have more advanced checks.
823 /// TODO: dynamic cases would require a way to compare symbolic
824 /// expressions and would probably need an affine set context propagated
825 /// everywhere.
826 if (llvm::any_of(strides, [](AffineExpr e) {
827 return e == getAffineConstantExpr(0, e.getContext());
828 })) {
829 offset = AffineExpr();
830 strides.clear();
831 return failure();
832 }
833
834 return success();
835 }
836
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)837 LogicalResult mlir::getStridesAndOffset(MemRefType t,
838 SmallVectorImpl<int64_t> &strides,
839 int64_t &offset) {
840 AffineExpr offsetExpr;
841 SmallVector<AffineExpr, 4> strideExprs;
842 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
843 return failure();
844 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
845 offset = cst.getValue();
846 else
847 offset = ShapedType::kDynamicStrideOrOffset;
848 for (auto e : strideExprs) {
849 if (auto c = e.dyn_cast<AffineConstantExpr>())
850 strides.push_back(c.getValue());
851 else
852 strides.push_back(ShapedType::kDynamicStrideOrOffset);
853 }
854 return success();
855 }
856
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const857 void UnrankedMemRefType::walkImmediateSubElements(
858 function_ref<void(Attribute)> walkAttrsFn,
859 function_ref<void(Type)> walkTypesFn) const {
860 walkTypesFn(getElementType());
861 walkAttrsFn(getMemorySpace());
862 }
863
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const864 Type UnrankedMemRefType::replaceImmediateSubElements(
865 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
866 return get(replTypes.front(), replAttrs.front());
867 }
868
869 //===----------------------------------------------------------------------===//
870 /// TupleType
871 //===----------------------------------------------------------------------===//
872
873 /// Return the elements types for this tuple.
getTypes() const874 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
875
876 /// Accumulate the types contained in this tuple and tuples nested within it.
877 /// Note that this only flattens nested tuples, not any other container type,
878 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
879 /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)880 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
881 for (Type type : getTypes()) {
882 if (auto nestedTuple = type.dyn_cast<TupleType>())
883 nestedTuple.getFlattenedTypes(types);
884 else
885 types.push_back(type);
886 }
887 }
888
889 /// Return the number of element types.
size() const890 size_t TupleType::size() const { return getImpl()->size(); }
891
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const892 void TupleType::walkImmediateSubElements(
893 function_ref<void(Attribute)> walkAttrsFn,
894 function_ref<void(Type)> walkTypesFn) const {
895 for (Type type : getTypes())
896 walkTypesFn(type);
897 }
898
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const899 Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
900 ArrayRef<Type> replTypes) const {
901 return get(getContext(), replTypes);
902 }
903
904 //===----------------------------------------------------------------------===//
905 // Type Utilities
906 //===----------------------------------------------------------------------===//
907
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)908 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
909 int64_t offset,
910 MLIRContext *context) {
911 AffineExpr expr;
912 unsigned nSymbols = 0;
913
914 // AffineExpr for offset.
915 // Static case.
916 if (offset != MemRefType::getDynamicStrideOrOffset()) {
917 auto cst = getAffineConstantExpr(offset, context);
918 expr = cst;
919 } else {
920 // Dynamic case, new symbol for the offset.
921 auto sym = getAffineSymbolExpr(nSymbols++, context);
922 expr = sym;
923 }
924
925 // AffineExpr for strides.
926 for (const auto &en : llvm::enumerate(strides)) {
927 auto dim = en.index();
928 auto stride = en.value();
929 assert(stride != 0 && "Invalid stride specification");
930 auto d = getAffineDimExpr(dim, context);
931 AffineExpr mult;
932 // Static case.
933 if (stride != MemRefType::getDynamicStrideOrOffset())
934 mult = getAffineConstantExpr(stride, context);
935 else
936 // Dynamic case, new symbol for each new stride.
937 mult = getAffineSymbolExpr(nSymbols++, context);
938 expr = expr + d * mult;
939 }
940
941 return AffineMap::get(strides.size(), nSymbols, expr);
942 }
943
944 /// Return a version of `t` with identity layout if it can be determined
945 /// statically that the layout is the canonical contiguous strided layout.
946 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
947 /// `t` with simplified layout.
948 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)949 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
950 AffineMap m = t.getLayout().getAffineMap();
951
952 // Already in canonical form.
953 if (m.isIdentity())
954 return t;
955
956 // Can't reduce to canonical identity form, return in canonical form.
957 if (m.getNumResults() > 1)
958 return t;
959
960 // Corner-case for 0-D affine maps.
961 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
962 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
963 if (cst.getValue() == 0)
964 return MemRefType::Builder(t).setLayout({});
965 return t;
966 }
967
968 // 0-D corner case for empty shape that still have an affine map. Example:
969 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
970 // offset needs to remain, just return t.
971 if (t.getShape().empty())
972 return t;
973
974 // If the canonical strided layout for the sizes of `t` is equal to the
975 // simplified layout of `t` we can just return an empty layout. Otherwise,
976 // just simplify the existing layout.
977 AffineExpr expr =
978 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
979 auto simplifiedLayoutExpr =
980 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
981 if (expr != simplifiedLayoutExpr)
982 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
983 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
984 return MemRefType::Builder(t).setLayout({});
985 }
986
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)987 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
988 ArrayRef<AffineExpr> exprs,
989 MLIRContext *context) {
990 // Size 0 corner case is useful for canonicalizations.
991 if (sizes.empty() || llvm::is_contained(sizes, 0))
992 return getAffineConstantExpr(0, context);
993
994 assert(!exprs.empty() && "expected exprs");
995 auto maps = AffineMap::inferFromExprList(exprs);
996 assert(!maps.empty() && "Expected one non-empty map");
997 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
998
999 AffineExpr expr;
1000 bool dynamicPoisonBit = false;
1001 int64_t runningSize = 1;
1002 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
1003 int64_t size = std::get<1>(en);
1004 // Degenerate case, no size =-> no stride
1005 if (size == 0)
1006 continue;
1007 AffineExpr dimExpr = std::get<0>(en);
1008 AffineExpr stride = dynamicPoisonBit
1009 ? getAffineSymbolExpr(nSymbols++, context)
1010 : getAffineConstantExpr(runningSize, context);
1011 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
1012 if (size > 0) {
1013 runningSize *= size;
1014 assert(runningSize > 0 && "integer overflow in size computation");
1015 } else {
1016 dynamicPoisonBit = true;
1017 }
1018 }
1019 return simplifyAffineExpr(expr, numDims, nSymbols);
1020 }
1021
1022 /// Return a version of `t` with a layout that has all dynamic offset and
1023 /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)1024 MemRefType mlir::eraseStridedLayout(MemRefType t) {
1025 auto val = ShapedType::kDynamicStrideOrOffset;
1026 return MemRefType::Builder(t).setLayout(
1027 AffineMapAttr::get(makeStridedLinearLayoutMap(
1028 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
1029 }
1030
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)1031 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
1032 MLIRContext *context) {
1033 SmallVector<AffineExpr, 4> exprs;
1034 exprs.reserve(sizes.size());
1035 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
1036 exprs.push_back(getAffineDimExpr(dim, context));
1037 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
1038 }
1039
1040 /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)1041 bool mlir::isStrided(MemRefType t) {
1042 int64_t offset;
1043 SmallVector<int64_t, 4> strides;
1044 auto res = getStridesAndOffset(t, strides, offset);
1045 return succeeded(res);
1046 }
1047
1048 /// Return the layout map in strided linear layout AffineMap form.
1049 /// Return null if the layout is not compatible with a strided layout.
getStridedLinearLayoutMap(MemRefType t)1050 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
1051 int64_t offset;
1052 SmallVector<int64_t, 4> strides;
1053 if (failed(getStridesAndOffset(t, strides, offset)))
1054 return AffineMap();
1055 return makeStridedLinearLayoutMap(strides, offset, t.getContext());
1056 }
1057
1058 /// Return the AffineExpr representation of the offset, assuming `memRefType`
1059 /// is a strided memref.
getOffsetExpr(MemRefType memrefType)1060 static AffineExpr getOffsetExpr(MemRefType memrefType) {
1061 SmallVector<AffineExpr> strides;
1062 AffineExpr offset;
1063 if (failed(getStridesAndOffset(memrefType, strides, offset)))
1064 assert(false && "expected strided memref");
1065 return offset;
1066 }
1067
1068 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1069 /// `offset` AffineExpr.
makeContiguousRowMajorMemRefType(MLIRContext * context,ArrayRef<int64_t> shape,Type elementType,AffineExpr offset)1070 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1071 ArrayRef<int64_t> shape,
1072 Type elementType,
1073 AffineExpr offset) {
1074 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1075 AffineExpr contiguousRowMajor = canonical + offset;
1076 AffineMap contiguousRowMajorMap =
1077 AffineMap::inferFromExprList({contiguousRowMajor})[0];
1078 return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1079 }
1080
1081 /// Helper determining if a memref is static-shape and contiguous-row-major
1082 /// layout, while still allowing for an arbitrary offset (any static or
1083 /// dynamic value).
isStaticShapeAndContiguousRowMajor(MemRefType memrefType)1084 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1085 if (!memrefType.hasStaticShape())
1086 return false;
1087 AffineExpr offset = getOffsetExpr(memrefType);
1088 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1089 memrefType.getContext(), memrefType.getShape(),
1090 memrefType.getElementType(), offset);
1091 return canonicalizeStridedLayout(memrefType) ==
1092 canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1093 }
1094