109f7a55fSRiver Riddle //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
209f7a55fSRiver Riddle //
309f7a55fSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409f7a55fSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
509f7a55fSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609f7a55fSRiver Riddle //
709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
809f7a55fSRiver Riddle 
909f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
1009f7a55fSRiver Riddle #include "TypeDetail.h"
1109f7a55fSRiver Riddle #include "mlir/IR/AffineExpr.h"
1209f7a55fSRiver Riddle #include "mlir/IR/AffineMap.h"
13f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinAttributes.h"
14f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinDialect.h"
1509f7a55fSRiver Riddle #include "mlir/IR/Diagnostics.h"
1609f7a55fSRiver Riddle #include "mlir/IR/Dialect.h"
177ceffae1SRiver Riddle #include "mlir/IR/FunctionInterfaces.h"
18ee090870SMehdi Amini #include "mlir/IR/OpImplementation.h"
1923c9e8bcSAart Bik #include "mlir/IR/TensorEncoding.h"
2009f7a55fSRiver Riddle #include "llvm/ADT/APFloat.h"
2109f7a55fSRiver Riddle #include "llvm/ADT/BitVector.h"
22c7cae0e4SRiver Riddle #include "llvm/ADT/Sequence.h"
2309f7a55fSRiver Riddle #include "llvm/ADT/Twine.h"
240d01dfbcSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
2509f7a55fSRiver Riddle 
2609f7a55fSRiver Riddle using namespace mlir;
2709f7a55fSRiver Riddle using namespace mlir::detail;
2809f7a55fSRiver Riddle 
2909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
3095019de8SRiver Riddle /// Tablegen Type Definitions
3195019de8SRiver Riddle //===----------------------------------------------------------------------===//
3295019de8SRiver Riddle 
3395019de8SRiver Riddle #define GET_TYPEDEF_CLASSES
3495019de8SRiver Riddle #include "mlir/IR/BuiltinTypes.cpp.inc"
3595019de8SRiver Riddle 
3695019de8SRiver Riddle //===----------------------------------------------------------------------===//
3731bb8efdSRiver Riddle // BuiltinDialect
3831bb8efdSRiver Riddle //===----------------------------------------------------------------------===//
3931bb8efdSRiver Riddle 
/// Registers the builtin types generated by ODS with the builtin dialect.
/// Including the TableGen-generated file with GET_TYPEDEF_LIST defined
/// supplies the template argument list of generated type classes for
/// addTypes<>.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
4631bb8efdSRiver Riddle 
4731bb8efdSRiver Riddle //===----------------------------------------------------------------------===//
4809f7a55fSRiver Riddle /// ComplexType
4909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
5009f7a55fSRiver Riddle 
5109f7a55fSRiver Riddle /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)5206e25d56SRiver Riddle LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
5309f7a55fSRiver Riddle                                   Type elementType) {
5409f7a55fSRiver Riddle   if (!elementType.isIntOrFloat())
5506e25d56SRiver Riddle     return emitError() << "invalid element type for complex";
5609f7a55fSRiver Riddle   return success();
5709f7a55fSRiver Riddle }
5809f7a55fSRiver Riddle 
5909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
6009f7a55fSRiver Riddle // Integer Type
6109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
6209f7a55fSRiver Riddle 
// Before C++17, a static constexpr data member that is ODR-used needs an
// out-of-line definition. From C++17 on such members are implicitly inline
// variables, and this definition is redundant but harmless.
constexpr unsigned IntegerType::kMaxWidth;
6509f7a55fSRiver Riddle 
6609f7a55fSRiver Riddle /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,SignednessSemantics signedness)6706e25d56SRiver Riddle LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
6806e25d56SRiver Riddle                                   unsigned width,
6909f7a55fSRiver Riddle                                   SignednessSemantics signedness) {
7009f7a55fSRiver Riddle   if (width > IntegerType::kMaxWidth) {
7106e25d56SRiver Riddle     return emitError() << "integer bitwidth is limited to "
7209f7a55fSRiver Riddle                        << IntegerType::kMaxWidth << " bits";
7309f7a55fSRiver Riddle   }
7409f7a55fSRiver Riddle   return success();
7509f7a55fSRiver Riddle }
7609f7a55fSRiver Riddle 
getWidth() const7709f7a55fSRiver Riddle unsigned IntegerType::getWidth() const { return getImpl()->width; }
7809f7a55fSRiver Riddle 
getSignedness() const7909f7a55fSRiver Riddle IntegerType::SignednessSemantics IntegerType::getSignedness() const {
8009f7a55fSRiver Riddle   return getImpl()->signedness;
8109f7a55fSRiver Riddle }
8209f7a55fSRiver Riddle 
scaleElementBitwidth(unsigned scale)837310501fSNicolas Vasilache IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
847310501fSNicolas Vasilache   if (!scale)
857310501fSNicolas Vasilache     return IntegerType();
861b97cdf8SRiver Riddle   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
877310501fSNicolas Vasilache }
887310501fSNicolas Vasilache 
8909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
9009f7a55fSRiver Riddle // Float Type
9109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
9209f7a55fSRiver Riddle 
getWidth()9309f7a55fSRiver Riddle unsigned FloatType::getWidth() {
9409f7a55fSRiver Riddle   if (isa<Float16Type, BFloat16Type>())
9509f7a55fSRiver Riddle     return 16;
9609f7a55fSRiver Riddle   if (isa<Float32Type>())
9709f7a55fSRiver Riddle     return 32;
9809f7a55fSRiver Riddle   if (isa<Float64Type>())
9909f7a55fSRiver Riddle     return 64;
100cf0173deSValentin Clement   if (isa<Float80Type>())
101cf0173deSValentin Clement     return 80;
102cf0173deSValentin Clement   if (isa<Float128Type>())
103cf0173deSValentin Clement     return 128;
10409f7a55fSRiver Riddle   llvm_unreachable("unexpected float type");
10509f7a55fSRiver Riddle }
10609f7a55fSRiver Riddle 
10709f7a55fSRiver Riddle /// Returns the floating semantics for the given type.
getFloatSemantics()10809f7a55fSRiver Riddle const llvm::fltSemantics &FloatType::getFloatSemantics() {
10909f7a55fSRiver Riddle   if (isa<BFloat16Type>())
11009f7a55fSRiver Riddle     return APFloat::BFloat();
11109f7a55fSRiver Riddle   if (isa<Float16Type>())
11209f7a55fSRiver Riddle     return APFloat::IEEEhalf();
11309f7a55fSRiver Riddle   if (isa<Float32Type>())
11409f7a55fSRiver Riddle     return APFloat::IEEEsingle();
11509f7a55fSRiver Riddle   if (isa<Float64Type>())
11609f7a55fSRiver Riddle     return APFloat::IEEEdouble();
117cf0173deSValentin Clement   if (isa<Float80Type>())
118cf0173deSValentin Clement     return APFloat::x87DoubleExtended();
119cf0173deSValentin Clement   if (isa<Float128Type>())
120cf0173deSValentin Clement     return APFloat::IEEEquad();
12109f7a55fSRiver Riddle   llvm_unreachable("non-floating point type used");
12209f7a55fSRiver Riddle }
12309f7a55fSRiver Riddle 
scaleElementBitwidth(unsigned scale)1247310501fSNicolas Vasilache FloatType FloatType::scaleElementBitwidth(unsigned scale) {
1257310501fSNicolas Vasilache   if (!scale)
1267310501fSNicolas Vasilache     return FloatType();
1277310501fSNicolas Vasilache   MLIRContext *ctx = getContext();
1287310501fSNicolas Vasilache   if (isF16() || isBF16()) {
1297310501fSNicolas Vasilache     if (scale == 2)
1307310501fSNicolas Vasilache       return FloatType::getF32(ctx);
1317310501fSNicolas Vasilache     if (scale == 4)
1327310501fSNicolas Vasilache       return FloatType::getF64(ctx);
1337310501fSNicolas Vasilache   }
1347310501fSNicolas Vasilache   if (isF32())
1357310501fSNicolas Vasilache     if (scale == 2)
1367310501fSNicolas Vasilache       return FloatType::getF64(ctx);
1377310501fSNicolas Vasilache   return FloatType();
1387310501fSNicolas Vasilache }
1397310501fSNicolas Vasilache 
getFPMantissaWidth()1401b2a1f84SWilliam S. Moses unsigned FloatType::getFPMantissaWidth() {
1411b2a1f84SWilliam S. Moses   return APFloat::semanticsPrecision(getFloatSemantics());
1421b2a1f84SWilliam S. Moses }
1431b2a1f84SWilliam S. Moses 
14409f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
14509f7a55fSRiver Riddle // FunctionType
14609f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
14709f7a55fSRiver Riddle 
getNumInputs() const14809f7a55fSRiver Riddle unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
14909f7a55fSRiver Riddle 
getInputs() const15009f7a55fSRiver Riddle ArrayRef<Type> FunctionType::getInputs() const {
15109f7a55fSRiver Riddle   return getImpl()->getInputs();
15209f7a55fSRiver Riddle }
15309f7a55fSRiver Riddle 
getNumResults() const15409f7a55fSRiver Riddle unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
15509f7a55fSRiver Riddle 
getResults() const15609f7a55fSRiver Riddle ArrayRef<Type> FunctionType::getResults() const {
15709f7a55fSRiver Riddle   return getImpl()->getResults();
15809f7a55fSRiver Riddle }
15909f7a55fSRiver Riddle 
clone(TypeRange inputs,TypeRange results) const1607ceffae1SRiver Riddle FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
1617ceffae1SRiver Riddle   return get(getContext(), inputs, results);
16209f7a55fSRiver Riddle }
16309f7a55fSRiver Riddle 
1648066f22cSFabian Schuiki /// Returns a new function type with the specified arguments and results
1658066f22cSFabian Schuiki /// inserted.
getWithArgsAndResults(ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<unsigned> resultIndices,TypeRange resultTypes)1668066f22cSFabian Schuiki FunctionType FunctionType::getWithArgsAndResults(
1678066f22cSFabian Schuiki     ArrayRef<unsigned> argIndices, TypeRange argTypes,
1688066f22cSFabian Schuiki     ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
1697ceffae1SRiver Riddle   SmallVector<Type> argStorage, resultStorage;
1707ceffae1SRiver Riddle   TypeRange newArgTypes = function_interface_impl::insertTypesInto(
1717ceffae1SRiver Riddle       getInputs(), argIndices, argTypes, argStorage);
1727ceffae1SRiver Riddle   TypeRange newResultTypes = function_interface_impl::insertTypesInto(
1737ceffae1SRiver Riddle       getResults(), resultIndices, resultTypes, resultStorage);
1747ceffae1SRiver Riddle   return clone(newArgTypes, newResultTypes);
1758066f22cSFabian Schuiki }
1768066f22cSFabian Schuiki 
17709f7a55fSRiver Riddle /// Returns a new function type without the specified arguments and results.
17809f7a55fSRiver Riddle FunctionType
getWithoutArgsAndResults(const BitVector & argIndices,const BitVector & resultIndices)179d10d49dcSRiver Riddle FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
180d10d49dcSRiver Riddle                                        const BitVector &resultIndices) {
1817ceffae1SRiver Riddle   SmallVector<Type> argStorage, resultStorage;
1827ceffae1SRiver Riddle   TypeRange newArgTypes = function_interface_impl::filterTypesOut(
1837ceffae1SRiver Riddle       getInputs(), argIndices, argStorage);
1847ceffae1SRiver Riddle   TypeRange newResultTypes = function_interface_impl::filterTypesOut(
1857ceffae1SRiver Riddle       getResults(), resultIndices, resultStorage);
1867ceffae1SRiver Riddle   return clone(newArgTypes, newResultTypes);
18709f7a55fSRiver Riddle }
18809f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const189c42dd5dbSRiver Riddle void FunctionType::walkImmediateSubElements(
190c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
191c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
192c42dd5dbSRiver Riddle   for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
193c42dd5dbSRiver Riddle     walkTypesFn(type);
194c42dd5dbSRiver Riddle }
195c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const196*01eedbc7SRiver Riddle Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
197*01eedbc7SRiver Riddle                                                ArrayRef<Type> replTypes) const {
198*01eedbc7SRiver Riddle   unsigned numInputs = getNumInputs();
199*01eedbc7SRiver Riddle   return get(getContext(), replTypes.take_front(numInputs),
200*01eedbc7SRiver Riddle              replTypes.drop_front(numInputs));
201*01eedbc7SRiver Riddle }
202*01eedbc7SRiver Riddle 
20309f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
20409f7a55fSRiver Riddle // OpaqueType
20509f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
20609f7a55fSRiver Riddle 
20709f7a55fSRiver Riddle /// Verify the construction of an opaque type.
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef typeData)20806e25d56SRiver Riddle LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
209195730a6SRiver Riddle                                  StringAttr dialect, StringRef typeData) {
21009f7a55fSRiver Riddle   if (!Dialect::isValidNamespace(dialect.strref()))
21106e25d56SRiver Riddle     return emitError() << "invalid dialect namespace '" << dialect << "'";
212109305e1SRiver Riddle 
213109305e1SRiver Riddle   // Check that the dialect is actually registered.
214109305e1SRiver Riddle   MLIRContext *context = dialect.getContext();
215109305e1SRiver Riddle   if (!context->allowsUnregisteredDialects() &&
216109305e1SRiver Riddle       !context->getLoadedDialect(dialect.strref())) {
217109305e1SRiver Riddle     return emitError()
218109305e1SRiver Riddle            << "`!" << dialect << "<\"" << typeData << "\">"
219109305e1SRiver Riddle            << "` type created with unregistered dialect. If this is "
220109305e1SRiver Riddle               "intended, please call allowUnregisteredDialects() on the "
221109305e1SRiver Riddle               "MLIRContext, or use -allow-unregistered-dialect with "
2220f9e6451SMehdi Amini               "the MLIR opt tool used";
223109305e1SRiver Riddle   }
224109305e1SRiver Riddle 
22509f7a55fSRiver Riddle   return success();
22609f7a55fSRiver Riddle }
22709f7a55fSRiver Riddle 
22809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
22909f7a55fSRiver Riddle // VectorType
23009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
23109f7a55fSRiver Riddle 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,unsigned numScalableDims)23206e25d56SRiver Riddle LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
233a4830d14SJavier Setoain                                  ArrayRef<int64_t> shape, Type elementType,
234a4830d14SJavier Setoain                                  unsigned numScalableDims) {
23509f7a55fSRiver Riddle   if (!isValidElementType(elementType))
236b096ac90SGeoffrey Martin-Noble     return emitError()
237b096ac90SGeoffrey Martin-Noble            << "vector elements must be int/index/float type but got "
238b096ac90SGeoffrey Martin-Noble            << elementType;
23909f7a55fSRiver Riddle 
24009f7a55fSRiver Riddle   if (any_of(shape, [](int64_t i) { return i <= 0; }))
241b096ac90SGeoffrey Martin-Noble     return emitError()
242b096ac90SGeoffrey Martin-Noble            << "vector types must have positive constant sizes but got "
243b096ac90SGeoffrey Martin-Noble            << shape;
24409f7a55fSRiver Riddle 
24509f7a55fSRiver Riddle   return success();
24609f7a55fSRiver Riddle }
24709f7a55fSRiver Riddle 
scaleElementBitwidth(unsigned scale)2487310501fSNicolas Vasilache VectorType VectorType::scaleElementBitwidth(unsigned scale) {
2497310501fSNicolas Vasilache   if (!scale)
2507310501fSNicolas Vasilache     return VectorType();
2517310501fSNicolas Vasilache   if (auto et = getElementType().dyn_cast<IntegerType>())
2527310501fSNicolas Vasilache     if (auto scaledEt = et.scaleElementBitwidth(scale))
253a4830d14SJavier Setoain       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
2547310501fSNicolas Vasilache   if (auto et = getElementType().dyn_cast<FloatType>())
2557310501fSNicolas Vasilache     if (auto scaledEt = et.scaleElementBitwidth(scale))
256a4830d14SJavier Setoain       return VectorType::get(getShape(), scaledEt, getNumScalableDims());
2577310501fSNicolas Vasilache   return VectorType();
2587310501fSNicolas Vasilache }
2597310501fSNicolas Vasilache 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const260c42dd5dbSRiver Riddle void VectorType::walkImmediateSubElements(
261c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
262c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
263c42dd5dbSRiver Riddle   walkTypesFn(getElementType());
264c42dd5dbSRiver Riddle }
265c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const266*01eedbc7SRiver Riddle Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
267*01eedbc7SRiver Riddle                                              ArrayRef<Type> replTypes) const {
268*01eedbc7SRiver Riddle   return get(getShape(), replTypes.front(), getNumScalableDims());
269*01eedbc7SRiver Riddle }
270*01eedbc7SRiver Riddle 
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const271676bfb2aSRiver Riddle VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
272159898d5SAdrian Kuegel                                  Type elementType) const {
27330c67587SKazu Hirata   return VectorType::get(shape.value_or(getShape()), elementType,
274676bfb2aSRiver Riddle                          getNumScalableDims());
275676bfb2aSRiver Riddle }
276676bfb2aSRiver Riddle 
27709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
27809f7a55fSRiver Riddle // TensorType
27909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
28009f7a55fSRiver Riddle 
getElementType() const281676bfb2aSRiver Riddle Type TensorType::getElementType() const {
282676bfb2aSRiver Riddle   return llvm::TypeSwitch<TensorType, Type>(*this)
283676bfb2aSRiver Riddle       .Case<RankedTensorType, UnrankedTensorType>(
284676bfb2aSRiver Riddle           [](auto type) { return type.getElementType(); });
285676bfb2aSRiver Riddle }
286676bfb2aSRiver Riddle 
hasRank() const287676bfb2aSRiver Riddle bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
288676bfb2aSRiver Riddle 
getShape() const289676bfb2aSRiver Riddle ArrayRef<int64_t> TensorType::getShape() const {
290676bfb2aSRiver Riddle   return cast<RankedTensorType>().getShape();
291676bfb2aSRiver Riddle }
292676bfb2aSRiver Riddle 
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const293676bfb2aSRiver Riddle TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
294676bfb2aSRiver Riddle                                  Type elementType) const {
295676bfb2aSRiver Riddle   if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
296676bfb2aSRiver Riddle     if (shape)
297676bfb2aSRiver Riddle       return RankedTensorType::get(*shape, elementType);
298676bfb2aSRiver Riddle     return UnrankedTensorType::get(elementType);
299676bfb2aSRiver Riddle   }
300676bfb2aSRiver Riddle 
301676bfb2aSRiver Riddle   auto rankedTy = cast<RankedTensorType>();
302676bfb2aSRiver Riddle   if (!shape)
303676bfb2aSRiver Riddle     return RankedTensorType::get(rankedTy.getShape(), elementType,
304676bfb2aSRiver Riddle                                  rankedTy.getEncoding());
30530c67587SKazu Hirata   return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
30630c67587SKazu Hirata                                rankedTy.getEncoding());
307676bfb2aSRiver Riddle }
308676bfb2aSRiver Riddle 
30906e25d56SRiver Riddle // Check if "elementType" can be an element type of a tensor.
31006e25d56SRiver Riddle static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic ()> emitError,Type elementType)31106e25d56SRiver Riddle checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
31209f7a55fSRiver Riddle                        Type elementType) {
31309f7a55fSRiver Riddle   if (!TensorType::isValidElementType(elementType))
31406e25d56SRiver Riddle     return emitError() << "invalid tensor element type: " << elementType;
31509f7a55fSRiver Riddle   return success();
31609f7a55fSRiver Riddle }
31709f7a55fSRiver Riddle 
31809f7a55fSRiver Riddle /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)31909f7a55fSRiver Riddle bool TensorType::isValidElementType(Type type) {
32009f7a55fSRiver Riddle   // Note: Non standard/builtin types are allowed to exist within tensor
32109f7a55fSRiver Riddle   // types. Dialects are expected to verify that tensor types have a valid
32209f7a55fSRiver Riddle   // element type within that dialect.
32309f7a55fSRiver Riddle   return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
32409f7a55fSRiver Riddle                   IndexType>() ||
325f8479d9dSRiver Riddle          !llvm::isa<BuiltinDialect>(type.getDialect());
32609f7a55fSRiver Riddle }
32709f7a55fSRiver Riddle 
32809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
32909f7a55fSRiver Riddle // RankedTensorType
33009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
33109f7a55fSRiver Riddle 
33206e25d56SRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,Attribute encoding)33306e25d56SRiver Riddle RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
3347714b405SAart Bik                          ArrayRef<int64_t> shape, Type elementType,
3357714b405SAart Bik                          Attribute encoding) {
3360d01dfbcSRiver Riddle   for (int64_t s : shape)
33709f7a55fSRiver Riddle     if (s < -1)
33806e25d56SRiver Riddle       return emitError() << "invalid tensor dimension size";
33923c9e8bcSAart Bik   if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
34023c9e8bcSAart Bik     if (failed(v.verifyEncoding(shape, elementType, emitError)))
34123c9e8bcSAart Bik       return failure();
34206e25d56SRiver Riddle   return checkTensorElementType(emitError, elementType);
34309f7a55fSRiver Riddle }
34409f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const345c42dd5dbSRiver Riddle void RankedTensorType::walkImmediateSubElements(
346c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
347c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
348c42dd5dbSRiver Riddle   walkTypesFn(getElementType());
349eb6c63cbSVladislav Vinogradov   if (Attribute encoding = getEncoding())
350eb6c63cbSVladislav Vinogradov     walkAttrsFn(encoding);
351c42dd5dbSRiver Riddle }
352c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const353*01eedbc7SRiver Riddle Type RankedTensorType::replaceImmediateSubElements(
354*01eedbc7SRiver Riddle     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
355*01eedbc7SRiver Riddle   return get(getShape(), replTypes.front(),
356*01eedbc7SRiver Riddle              replAttrs.empty() ? Attribute() : replAttrs.back());
357*01eedbc7SRiver Riddle }
358*01eedbc7SRiver Riddle 
35909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
36009f7a55fSRiver Riddle // UnrankedTensorType
36109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
36209f7a55fSRiver Riddle 
36309f7a55fSRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)36406e25d56SRiver Riddle UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
36509f7a55fSRiver Riddle                            Type elementType) {
36606e25d56SRiver Riddle   return checkTensorElementType(emitError, elementType);
36709f7a55fSRiver Riddle }
36809f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const369c42dd5dbSRiver Riddle void UnrankedTensorType::walkImmediateSubElements(
370c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
371c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
372c42dd5dbSRiver Riddle   walkTypesFn(getElementType());
373c42dd5dbSRiver Riddle }
374c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const375*01eedbc7SRiver Riddle Type UnrankedTensorType::replaceImmediateSubElements(
376*01eedbc7SRiver Riddle     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
377*01eedbc7SRiver Riddle   return get(replTypes.front());
378*01eedbc7SRiver Riddle }
379*01eedbc7SRiver Riddle 
38009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
38109f7a55fSRiver Riddle // BaseMemRefType
38209f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
38309f7a55fSRiver Riddle 
getElementType() const384676bfb2aSRiver Riddle Type BaseMemRefType::getElementType() const {
385676bfb2aSRiver Riddle   return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
386676bfb2aSRiver Riddle       .Case<MemRefType, UnrankedMemRefType>(
387676bfb2aSRiver Riddle           [](auto type) { return type.getElementType(); });
388676bfb2aSRiver Riddle }
389676bfb2aSRiver Riddle 
hasRank() const390676bfb2aSRiver Riddle bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
391676bfb2aSRiver Riddle 
getShape() const392676bfb2aSRiver Riddle ArrayRef<int64_t> BaseMemRefType::getShape() const {
393676bfb2aSRiver Riddle   return cast<MemRefType>().getShape();
394676bfb2aSRiver Riddle }
395676bfb2aSRiver Riddle 
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const396676bfb2aSRiver Riddle BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
397676bfb2aSRiver Riddle                                          Type elementType) const {
398676bfb2aSRiver Riddle   if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
399676bfb2aSRiver Riddle     if (!shape)
400676bfb2aSRiver Riddle       return UnrankedMemRefType::get(elementType, getMemorySpace());
401676bfb2aSRiver Riddle     MemRefType::Builder builder(*shape, elementType);
402676bfb2aSRiver Riddle     builder.setMemorySpace(getMemorySpace());
403676bfb2aSRiver Riddle     return builder;
404676bfb2aSRiver Riddle   }
405676bfb2aSRiver Riddle 
406676bfb2aSRiver Riddle   MemRefType::Builder builder(cast<MemRefType>());
407676bfb2aSRiver Riddle   if (shape)
408676bfb2aSRiver Riddle     builder.setShape(*shape);
409676bfb2aSRiver Riddle   builder.setElementType(elementType);
410676bfb2aSRiver Riddle   return builder;
411676bfb2aSRiver Riddle }
412676bfb2aSRiver Riddle 
getMemorySpace() const413f3bf5c05SVladislav Vinogradov Attribute BaseMemRefType::getMemorySpace() const {
414f3bf5c05SVladislav Vinogradov   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
415f3bf5c05SVladislav Vinogradov     return rankedMemRefTy.getMemorySpace();
416f3bf5c05SVladislav Vinogradov   return cast<UnrankedMemRefType>().getMemorySpace();
417f3bf5c05SVladislav Vinogradov }
418f3bf5c05SVladislav Vinogradov 
getMemorySpaceAsInt() const41937eca08eSVladislav Vinogradov unsigned BaseMemRefType::getMemorySpaceAsInt() const {
4200d01dfbcSRiver Riddle   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
4210d01dfbcSRiver Riddle     return rankedMemRefTy.getMemorySpaceAsInt();
4220d01dfbcSRiver Riddle   return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
42309f7a55fSRiver Riddle }
42409f7a55fSRiver Riddle 
42509f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
42609f7a55fSRiver Riddle // MemRefType
42709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
42809f7a55fSRiver Riddle 
4290fb4a201SAlex Zinenko /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
4300fb4a201SAlex Zinenko /// `originalShape` with some `1` entries erased, return the set of indices
4310fb4a201SAlex Zinenko /// that specifies which of the entries of `originalShape` are dropped to obtain
4320fb4a201SAlex Zinenko /// `reducedShape`. The returned mask can be applied as a projection to
4330fb4a201SAlex Zinenko /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
4340fb4a201SAlex Zinenko /// which dimensions must be kept when e.g. compute MemRef strides under
4350fb4a201SAlex Zinenko /// rank-reducing operations. Return None if reducedShape cannot be obtained
4360fb4a201SAlex Zinenko /// by dropping only `1` entries in `originalShape`.
4370fb4a201SAlex Zinenko llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> reducedShape)4380fb4a201SAlex Zinenko mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
4390fb4a201SAlex Zinenko                                ArrayRef<int64_t> reducedShape) {
4400fb4a201SAlex Zinenko   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
4410fb4a201SAlex Zinenko   llvm::SmallDenseSet<unsigned> unusedDims;
4420fb4a201SAlex Zinenko   unsigned reducedIdx = 0;
4430fb4a201SAlex Zinenko   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
444a08b750cSNicolas Vasilache     // Greedily insert `originalIdx` if match.
4450fb4a201SAlex Zinenko     if (reducedIdx < reducedRank &&
4460fb4a201SAlex Zinenko         originalShape[originalIdx] == reducedShape[reducedIdx]) {
4470fb4a201SAlex Zinenko       reducedIdx++;
4480fb4a201SAlex Zinenko       continue;
4490fb4a201SAlex Zinenko     }
4500fb4a201SAlex Zinenko 
4510fb4a201SAlex Zinenko     unusedDims.insert(originalIdx);
4520fb4a201SAlex Zinenko     // If no match on `originalIdx`, the `originalShape` at this dimension
4530fb4a201SAlex Zinenko     // must be 1, otherwise we bail.
4540fb4a201SAlex Zinenko     if (originalShape[originalIdx] != 1)
4550fb4a201SAlex Zinenko       return llvm::None;
4560fb4a201SAlex Zinenko   }
4570fb4a201SAlex Zinenko   // The whole reducedShape must be scanned, otherwise we bail.
4580fb4a201SAlex Zinenko   if (reducedIdx != reducedRank)
4590fb4a201SAlex Zinenko     return llvm::None;
4600fb4a201SAlex Zinenko   return unusedDims;
4610fb4a201SAlex Zinenko }
4620fb4a201SAlex Zinenko 
463a08b750cSNicolas Vasilache SliceVerificationResult
isRankReducedType(ShapedType originalType,ShapedType candidateReducedType)464a08b750cSNicolas Vasilache mlir::isRankReducedType(ShapedType originalType,
465a08b750cSNicolas Vasilache                         ShapedType candidateReducedType) {
466a08b750cSNicolas Vasilache   if (originalType == candidateReducedType)
467a08b750cSNicolas Vasilache     return SliceVerificationResult::Success;
468a08b750cSNicolas Vasilache 
469a08b750cSNicolas Vasilache   ShapedType originalShapedType = originalType.cast<ShapedType>();
470a08b750cSNicolas Vasilache   ShapedType candidateReducedShapedType =
471a08b750cSNicolas Vasilache       candidateReducedType.cast<ShapedType>();
472a08b750cSNicolas Vasilache 
473a08b750cSNicolas Vasilache   // Rank and size logic is valid for all ShapedTypes.
474a08b750cSNicolas Vasilache   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
475a08b750cSNicolas Vasilache   ArrayRef<int64_t> candidateReducedShape =
476a08b750cSNicolas Vasilache       candidateReducedShapedType.getShape();
477a08b750cSNicolas Vasilache   unsigned originalRank = originalShape.size(),
478a08b750cSNicolas Vasilache            candidateReducedRank = candidateReducedShape.size();
479a08b750cSNicolas Vasilache   if (candidateReducedRank > originalRank)
480a08b750cSNicolas Vasilache     return SliceVerificationResult::RankTooLarge;
481a08b750cSNicolas Vasilache 
482a08b750cSNicolas Vasilache   auto optionalUnusedDimsMask =
483a08b750cSNicolas Vasilache       computeRankReductionMask(originalShape, candidateReducedShape);
484a08b750cSNicolas Vasilache 
485a08b750cSNicolas Vasilache   // Sizes cannot be matched in case empty vector is returned.
486037f0995SKazu Hirata   if (!optionalUnusedDimsMask)
487a08b750cSNicolas Vasilache     return SliceVerificationResult::SizeMismatch;
488a08b750cSNicolas Vasilache 
489a08b750cSNicolas Vasilache   if (originalShapedType.getElementType() !=
490a08b750cSNicolas Vasilache       candidateReducedShapedType.getElementType())
491a08b750cSNicolas Vasilache     return SliceVerificationResult::ElemTypeMismatch;
492a08b750cSNicolas Vasilache 
493a08b750cSNicolas Vasilache   return SliceVerificationResult::Success;
494a08b750cSNicolas Vasilache }
495a08b750cSNicolas Vasilache 
isSupportedMemorySpace(Attribute memorySpace)496f3bf5c05SVladislav Vinogradov bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
497f3bf5c05SVladislav Vinogradov   // Empty attribute is allowed as default memory space.
498f3bf5c05SVladislav Vinogradov   if (!memorySpace)
499f3bf5c05SVladislav Vinogradov     return true;
500f3bf5c05SVladislav Vinogradov 
501f3bf5c05SVladislav Vinogradov   // Supported built-in attributes.
502f3bf5c05SVladislav Vinogradov   if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
503f3bf5c05SVladislav Vinogradov     return true;
504f3bf5c05SVladislav Vinogradov 
505f3bf5c05SVladislav Vinogradov   // Allow custom dialect attributes.
506ee090870SMehdi Amini   if (!isa<BuiltinDialect>(memorySpace.getDialect()))
507f3bf5c05SVladislav Vinogradov     return true;
508f3bf5c05SVladislav Vinogradov 
509f3bf5c05SVladislav Vinogradov   return false;
510f3bf5c05SVladislav Vinogradov }
511f3bf5c05SVladislav Vinogradov 
wrapIntegerMemorySpace(unsigned memorySpace,MLIRContext * ctx)512f3bf5c05SVladislav Vinogradov Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
513f3bf5c05SVladislav Vinogradov                                                MLIRContext *ctx) {
514f3bf5c05SVladislav Vinogradov   if (memorySpace == 0)
515f3bf5c05SVladislav Vinogradov     return nullptr;
516f3bf5c05SVladislav Vinogradov 
517f3bf5c05SVladislav Vinogradov   return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
518f3bf5c05SVladislav Vinogradov }
519f3bf5c05SVladislav Vinogradov 
skipDefaultMemorySpace(Attribute memorySpace)520f3bf5c05SVladislav Vinogradov Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
521f3bf5c05SVladislav Vinogradov   IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
522f3bf5c05SVladislav Vinogradov   if (intMemorySpace && intMemorySpace.getValue() == 0)
523f3bf5c05SVladislav Vinogradov     return nullptr;
524f3bf5c05SVladislav Vinogradov 
525f3bf5c05SVladislav Vinogradov   return memorySpace;
526f3bf5c05SVladislav Vinogradov }
527f3bf5c05SVladislav Vinogradov 
getMemorySpaceAsInt(Attribute memorySpace)528f3bf5c05SVladislav Vinogradov unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
529f3bf5c05SVladislav Vinogradov   if (!memorySpace)
530f3bf5c05SVladislav Vinogradov     return 0;
531f3bf5c05SVladislav Vinogradov 
532f3bf5c05SVladislav Vinogradov   assert(memorySpace.isa<IntegerAttr>() &&
533f3bf5c05SVladislav Vinogradov          "Using `getMemorySpaceInteger` with non-Integer attribute");
534f3bf5c05SVladislav Vinogradov 
535f3bf5c05SVladislav Vinogradov   return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
536f3bf5c05SVladislav Vinogradov }
537f3bf5c05SVladislav Vinogradov 
538f3bf5c05SVladislav Vinogradov MemRefType::Builder &
setMemorySpace(unsigned newMemorySpace)539f3bf5c05SVladislav Vinogradov MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
540f3bf5c05SVladislav Vinogradov   memorySpace =
541f3bf5c05SVladislav Vinogradov       wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
542f3bf5c05SVladislav Vinogradov   return *this;
543f3bf5c05SVladislav Vinogradov }
544f3bf5c05SVladislav Vinogradov 
getMemorySpaceAsInt() const545f3bf5c05SVladislav Vinogradov unsigned MemRefType::getMemorySpaceAsInt() const {
546f3bf5c05SVladislav Vinogradov   return detail::getMemorySpaceAsInt(getMemorySpace());
547f3bf5c05SVladislav Vinogradov }
548f3bf5c05SVladislav Vinogradov 
get(ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)549e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
550e41ebbecSVladislav Vinogradov                            MemRefLayoutAttrInterface layout,
551e41ebbecSVladislav Vinogradov                            Attribute memorySpace) {
552e41ebbecSVladislav Vinogradov   // Use default layout for empty attribute.
553e41ebbecSVladislav Vinogradov   if (!layout)
554e41ebbecSVladislav Vinogradov     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
555e41ebbecSVladislav Vinogradov         shape.size(), elementType.getContext()));
556e41ebbecSVladislav Vinogradov 
557e41ebbecSVladislav Vinogradov   // Drop default memory space value and replace it with empty attribute.
558e41ebbecSVladislav Vinogradov   memorySpace = skipDefaultMemorySpace(memorySpace);
559e41ebbecSVladislav Vinogradov 
560e41ebbecSVladislav Vinogradov   return Base::get(elementType.getContext(), shape, elementType, layout,
561e41ebbecSVladislav Vinogradov                    memorySpace);
562e41ebbecSVladislav Vinogradov }
563e41ebbecSVladislav Vinogradov 
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)564e41ebbecSVladislav Vinogradov MemRefType MemRefType::getChecked(
565e41ebbecSVladislav Vinogradov     function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
566e41ebbecSVladislav Vinogradov     Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
567e41ebbecSVladislav Vinogradov 
568e41ebbecSVladislav Vinogradov   // Use default layout for empty attribute.
569e41ebbecSVladislav Vinogradov   if (!layout)
570e41ebbecSVladislav Vinogradov     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
571e41ebbecSVladislav Vinogradov         shape.size(), elementType.getContext()));
572e41ebbecSVladislav Vinogradov 
573e41ebbecSVladislav Vinogradov   // Drop default memory space value and replace it with empty attribute.
574e41ebbecSVladislav Vinogradov   memorySpace = skipDefaultMemorySpace(memorySpace);
575e41ebbecSVladislav Vinogradov 
576e41ebbecSVladislav Vinogradov   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
577e41ebbecSVladislav Vinogradov                           elementType, layout, memorySpace);
578e41ebbecSVladislav Vinogradov }
579e41ebbecSVladislav Vinogradov 
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)580e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
581e41ebbecSVladislav Vinogradov                            AffineMap map, Attribute memorySpace) {
582e41ebbecSVladislav Vinogradov 
583e41ebbecSVladislav Vinogradov   // Use default layout for empty map.
584e41ebbecSVladislav Vinogradov   if (!map)
585e41ebbecSVladislav Vinogradov     map = AffineMap::getMultiDimIdentityMap(shape.size(),
586e41ebbecSVladislav Vinogradov                                             elementType.getContext());
587e41ebbecSVladislav Vinogradov 
588e41ebbecSVladislav Vinogradov   // Wrap AffineMap into Attribute.
589e41ebbecSVladislav Vinogradov   Attribute layout = AffineMapAttr::get(map);
590e41ebbecSVladislav Vinogradov 
591e41ebbecSVladislav Vinogradov   // Drop default memory space value and replace it with empty attribute.
592e41ebbecSVladislav Vinogradov   memorySpace = skipDefaultMemorySpace(memorySpace);
593e41ebbecSVladislav Vinogradov 
594e41ebbecSVladislav Vinogradov   return Base::get(elementType.getContext(), shape, elementType, layout,
595e41ebbecSVladislav Vinogradov                    memorySpace);
596e41ebbecSVladislav Vinogradov }
597e41ebbecSVladislav Vinogradov 
598e41ebbecSVladislav Vinogradov MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)599e41ebbecSVladislav Vinogradov MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
600e41ebbecSVladislav Vinogradov                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
601e41ebbecSVladislav Vinogradov                        Attribute memorySpace) {
602e41ebbecSVladislav Vinogradov 
603e41ebbecSVladislav Vinogradov   // Use default layout for empty map.
604e41ebbecSVladislav Vinogradov   if (!map)
605e41ebbecSVladislav Vinogradov     map = AffineMap::getMultiDimIdentityMap(shape.size(),
606e41ebbecSVladislav Vinogradov                                             elementType.getContext());
607e41ebbecSVladislav Vinogradov 
608e41ebbecSVladislav Vinogradov   // Wrap AffineMap into Attribute.
609e41ebbecSVladislav Vinogradov   Attribute layout = AffineMapAttr::get(map);
610e41ebbecSVladislav Vinogradov 
611e41ebbecSVladislav Vinogradov   // Drop default memory space value and replace it with empty attribute.
612e41ebbecSVladislav Vinogradov   memorySpace = skipDefaultMemorySpace(memorySpace);
613e41ebbecSVladislav Vinogradov 
614e41ebbecSVladislav Vinogradov   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
615e41ebbecSVladislav Vinogradov                           elementType, layout, memorySpace);
616e41ebbecSVladislav Vinogradov }
617e41ebbecSVladislav Vinogradov 
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)618e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
619e41ebbecSVladislav Vinogradov                            AffineMap map, unsigned memorySpaceInd) {
620e41ebbecSVladislav Vinogradov 
621e41ebbecSVladislav Vinogradov   // Use default layout for empty map.
622e41ebbecSVladislav Vinogradov   if (!map)
623e41ebbecSVladislav Vinogradov     map = AffineMap::getMultiDimIdentityMap(shape.size(),
624e41ebbecSVladislav Vinogradov                                             elementType.getContext());
625e41ebbecSVladislav Vinogradov 
626e41ebbecSVladislav Vinogradov   // Wrap AffineMap into Attribute.
627e41ebbecSVladislav Vinogradov   Attribute layout = AffineMapAttr::get(map);
628e41ebbecSVladislav Vinogradov 
629e41ebbecSVladislav Vinogradov   // Convert deprecated integer-like memory space to Attribute.
630e41ebbecSVladislav Vinogradov   Attribute memorySpace =
631e41ebbecSVladislav Vinogradov       wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
632e41ebbecSVladislav Vinogradov 
633e41ebbecSVladislav Vinogradov   return Base::get(elementType.getContext(), shape, elementType, layout,
634e41ebbecSVladislav Vinogradov                    memorySpace);
635e41ebbecSVladislav Vinogradov }
636e41ebbecSVladislav Vinogradov 
637e41ebbecSVladislav Vinogradov MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)638e41ebbecSVladislav Vinogradov MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
639e41ebbecSVladislav Vinogradov                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
640e41ebbecSVladislav Vinogradov                        unsigned memorySpaceInd) {
641e41ebbecSVladislav Vinogradov 
642e41ebbecSVladislav Vinogradov   // Use default layout for empty map.
643e41ebbecSVladislav Vinogradov   if (!map)
644e41ebbecSVladislav Vinogradov     map = AffineMap::getMultiDimIdentityMap(shape.size(),
645e41ebbecSVladislav Vinogradov                                             elementType.getContext());
646e41ebbecSVladislav Vinogradov 
647e41ebbecSVladislav Vinogradov   // Wrap AffineMap into Attribute.
648e41ebbecSVladislav Vinogradov   Attribute layout = AffineMapAttr::get(map);
649e41ebbecSVladislav Vinogradov 
650e41ebbecSVladislav Vinogradov   // Convert deprecated integer-like memory space to Attribute.
651e41ebbecSVladislav Vinogradov   Attribute memorySpace =
652e41ebbecSVladislav Vinogradov       wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
653e41ebbecSVladislav Vinogradov 
654e41ebbecSVladislav Vinogradov   return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
655e41ebbecSVladislav Vinogradov                           elementType, layout, memorySpace);
656e41ebbecSVladislav Vinogradov }
657e41ebbecSVladislav Vinogradov 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)6580d01dfbcSRiver Riddle LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
65906e25d56SRiver Riddle                                  ArrayRef<int64_t> shape, Type elementType,
660e41ebbecSVladislav Vinogradov                                  MemRefLayoutAttrInterface layout,
661f3bf5c05SVladislav Vinogradov                                  Attribute memorySpace) {
66209f7a55fSRiver Riddle   if (!BaseMemRefType::isValidElementType(elementType))
6630d01dfbcSRiver Riddle     return emitError() << "invalid memref element type";
66409f7a55fSRiver Riddle 
66509f7a55fSRiver Riddle   // Negative sizes are not allowed except for `-1` that means dynamic size.
6660d01dfbcSRiver Riddle   for (int64_t s : shape)
66709f7a55fSRiver Riddle     if (s < -1)
6680d01dfbcSRiver Riddle       return emitError() << "invalid memref size";
66909f7a55fSRiver Riddle 
670e41ebbecSVladislav Vinogradov   assert(layout && "missing layout specification");
671e41ebbecSVladislav Vinogradov   if (failed(layout.verifyLayout(shape, emitError)))
672e41ebbecSVladislav Vinogradov     return failure();
673f3bf5c05SVladislav Vinogradov 
674e41ebbecSVladislav Vinogradov   if (!isSupportedMemorySpace(memorySpace))
675f3bf5c05SVladislav Vinogradov     return emitError() << "unsupported memory space Attribute";
676f3bf5c05SVladislav Vinogradov 
6770d01dfbcSRiver Riddle   return success();
67809f7a55fSRiver Riddle }
67909f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const680c42dd5dbSRiver Riddle void MemRefType::walkImmediateSubElements(
681c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
682c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
683c42dd5dbSRiver Riddle   walkTypesFn(getElementType());
684e41ebbecSVladislav Vinogradov   if (!getLayout().isIdentity())
685e41ebbecSVladislav Vinogradov     walkAttrsFn(getLayout());
686c42dd5dbSRiver Riddle   walkAttrsFn(getMemorySpace());
687c42dd5dbSRiver Riddle }
688c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const689*01eedbc7SRiver Riddle Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
690*01eedbc7SRiver Riddle                                              ArrayRef<Type> replTypes) const {
691*01eedbc7SRiver Riddle   bool hasLayout = replAttrs.size() > 1;
692*01eedbc7SRiver Riddle   return get(getShape(), replTypes[0],
693*01eedbc7SRiver Riddle              hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
694*01eedbc7SRiver Riddle                        : MemRefLayoutAttrInterface(),
695*01eedbc7SRiver Riddle              hasLayout ? replAttrs[1] : replAttrs[0]);
696*01eedbc7SRiver Riddle }
697*01eedbc7SRiver Riddle 
69809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
69909f7a55fSRiver Riddle // UnrankedMemRefType
70009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
70109f7a55fSRiver Riddle 
getMemorySpaceAsInt() const702f3bf5c05SVladislav Vinogradov unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
703f3bf5c05SVladislav Vinogradov   return detail::getMemorySpaceAsInt(getMemorySpace());
704f3bf5c05SVladislav Vinogradov }
705f3bf5c05SVladislav Vinogradov 
70609f7a55fSRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,Attribute memorySpace)70706e25d56SRiver Riddle UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
708f3bf5c05SVladislav Vinogradov                            Type elementType, Attribute memorySpace) {
70909f7a55fSRiver Riddle   if (!BaseMemRefType::isValidElementType(elementType))
71006e25d56SRiver Riddle     return emitError() << "invalid memref element type";
711f3bf5c05SVladislav Vinogradov 
712f3bf5c05SVladislav Vinogradov   if (!isSupportedMemorySpace(memorySpace))
713f3bf5c05SVladislav Vinogradov     return emitError() << "unsupported memory space Attribute";
714f3bf5c05SVladislav Vinogradov 
71509f7a55fSRiver Riddle   return success();
71609f7a55fSRiver Riddle }
71709f7a55fSRiver Riddle 
71809f7a55fSRiver Riddle // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
71909f7a55fSRiver Riddle // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)72009f7a55fSRiver Riddle static void extractStridesFromTerm(AffineExpr e,
72109f7a55fSRiver Riddle                                    AffineExpr multiplicativeFactor,
72209f7a55fSRiver Riddle                                    MutableArrayRef<AffineExpr> strides,
72309f7a55fSRiver Riddle                                    AffineExpr &offset) {
72409f7a55fSRiver Riddle   if (auto dim = e.dyn_cast<AffineDimExpr>())
72509f7a55fSRiver Riddle     strides[dim.getPosition()] =
72609f7a55fSRiver Riddle         strides[dim.getPosition()] + multiplicativeFactor;
72709f7a55fSRiver Riddle   else
72809f7a55fSRiver Riddle     offset = offset + e * multiplicativeFactor;
72909f7a55fSRiver Riddle }
73009f7a55fSRiver Riddle 
73109f7a55fSRiver Riddle /// Takes a single AffineExpr `e` and populates the `strides` array with the
73209f7a55fSRiver Riddle /// strides expressions for each dim position.
73309f7a55fSRiver Riddle /// The convention is that the strides for dimensions d0, .. dn appear in
73409f7a55fSRiver Riddle /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)73509f7a55fSRiver Riddle static LogicalResult extractStrides(AffineExpr e,
73609f7a55fSRiver Riddle                                     AffineExpr multiplicativeFactor,
73709f7a55fSRiver Riddle                                     MutableArrayRef<AffineExpr> strides,
73809f7a55fSRiver Riddle                                     AffineExpr &offset) {
73909f7a55fSRiver Riddle   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
74009f7a55fSRiver Riddle   if (!bin) {
74109f7a55fSRiver Riddle     extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
74209f7a55fSRiver Riddle     return success();
74309f7a55fSRiver Riddle   }
74409f7a55fSRiver Riddle 
74509f7a55fSRiver Riddle   if (bin.getKind() == AffineExprKind::CeilDiv ||
74609f7a55fSRiver Riddle       bin.getKind() == AffineExprKind::FloorDiv ||
74709f7a55fSRiver Riddle       bin.getKind() == AffineExprKind::Mod)
74809f7a55fSRiver Riddle     return failure();
74909f7a55fSRiver Riddle 
75009f7a55fSRiver Riddle   if (bin.getKind() == AffineExprKind::Mul) {
75109f7a55fSRiver Riddle     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
75209f7a55fSRiver Riddle     if (dim) {
75309f7a55fSRiver Riddle       strides[dim.getPosition()] =
75409f7a55fSRiver Riddle           strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
75509f7a55fSRiver Riddle       return success();
75609f7a55fSRiver Riddle     }
75709f7a55fSRiver Riddle     // LHS and RHS may both contain complex expressions of dims. Try one path
75809f7a55fSRiver Riddle     // and if it fails try the other. This is guaranteed to succeed because
75909f7a55fSRiver Riddle     // only one path may have a `dim`, otherwise this is not an AffineExpr in
76009f7a55fSRiver Riddle     // the first place.
76109f7a55fSRiver Riddle     if (bin.getLHS().isSymbolicOrConstant())
76209f7a55fSRiver Riddle       return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
76309f7a55fSRiver Riddle                             strides, offset);
76409f7a55fSRiver Riddle     return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
76509f7a55fSRiver Riddle                           strides, offset);
76609f7a55fSRiver Riddle   }
76709f7a55fSRiver Riddle 
76809f7a55fSRiver Riddle   if (bin.getKind() == AffineExprKind::Add) {
76909f7a55fSRiver Riddle     auto res1 =
77009f7a55fSRiver Riddle         extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
77109f7a55fSRiver Riddle     auto res2 =
77209f7a55fSRiver Riddle         extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
77309f7a55fSRiver Riddle     return success(succeeded(res1) && succeeded(res2));
77409f7a55fSRiver Riddle   }
77509f7a55fSRiver Riddle 
77609f7a55fSRiver Riddle   llvm_unreachable("unexpected binary operation");
77709f7a55fSRiver Riddle }
77809f7a55fSRiver Riddle 
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)77909f7a55fSRiver Riddle LogicalResult mlir::getStridesAndOffset(MemRefType t,
78009f7a55fSRiver Riddle                                         SmallVectorImpl<AffineExpr> &strides,
78109f7a55fSRiver Riddle                                         AffineExpr &offset) {
782e41ebbecSVladislav Vinogradov   AffineMap m = t.getLayout().getAffineMap();
78370b6f16eSVladislav Vinogradov 
784e41ebbecSVladislav Vinogradov   if (m.getNumResults() != 1 && !m.isIdentity())
785fab634b4SNicolas Vasilache     return failure();
786fab634b4SNicolas Vasilache 
78709f7a55fSRiver Riddle   auto zero = getAffineConstantExpr(0, t.getContext());
78809f7a55fSRiver Riddle   auto one = getAffineConstantExpr(1, t.getContext());
78909f7a55fSRiver Riddle   offset = zero;
79009f7a55fSRiver Riddle   strides.assign(t.getRank(), zero);
79109f7a55fSRiver Riddle 
79209f7a55fSRiver Riddle   // Canonical case for empty map.
793e41ebbecSVladislav Vinogradov   if (m.isIdentity()) {
79409f7a55fSRiver Riddle     // 0-D corner case, offset is already 0.
79509f7a55fSRiver Riddle     if (t.getRank() == 0)
79609f7a55fSRiver Riddle       return success();
79709f7a55fSRiver Riddle     auto stridedExpr =
79809f7a55fSRiver Riddle         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
79909f7a55fSRiver Riddle     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
80009f7a55fSRiver Riddle       return success();
80109f7a55fSRiver Riddle     assert(false && "unexpected failure: extract strides in canonical layout");
80209f7a55fSRiver Riddle   }
80309f7a55fSRiver Riddle 
80409f7a55fSRiver Riddle   // Non-canonical case requires more work.
80509f7a55fSRiver Riddle   auto stridedExpr =
80609f7a55fSRiver Riddle       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
80709f7a55fSRiver Riddle   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
80809f7a55fSRiver Riddle     offset = AffineExpr();
80909f7a55fSRiver Riddle     strides.clear();
81009f7a55fSRiver Riddle     return failure();
81109f7a55fSRiver Riddle   }
81209f7a55fSRiver Riddle 
81309f7a55fSRiver Riddle   // Simplify results to allow folding to constants and simple checks.
81409f7a55fSRiver Riddle   unsigned numDims = m.getNumDims();
81509f7a55fSRiver Riddle   unsigned numSymbols = m.getNumSymbols();
81609f7a55fSRiver Riddle   offset = simplifyAffineExpr(offset, numDims, numSymbols);
81709f7a55fSRiver Riddle   for (auto &stride : strides)
81809f7a55fSRiver Riddle     stride = simplifyAffineExpr(stride, numDims, numSymbols);
81909f7a55fSRiver Riddle 
82009f7a55fSRiver Riddle   /// In practice, a strided memref must be internally non-aliasing. Test
82109f7a55fSRiver Riddle   /// against 0 as a proxy.
82209f7a55fSRiver Riddle   /// TODO: static cases can have more advanced checks.
82309f7a55fSRiver Riddle   /// TODO: dynamic cases would require a way to compare symbolic
82409f7a55fSRiver Riddle   /// expressions and would probably need an affine set context propagated
82509f7a55fSRiver Riddle   /// everywhere.
82609f7a55fSRiver Riddle   if (llvm::any_of(strides, [](AffineExpr e) {
82709f7a55fSRiver Riddle         return e == getAffineConstantExpr(0, e.getContext());
82809f7a55fSRiver Riddle       })) {
82909f7a55fSRiver Riddle     offset = AffineExpr();
83009f7a55fSRiver Riddle     strides.clear();
83109f7a55fSRiver Riddle     return failure();
83209f7a55fSRiver Riddle   }
83309f7a55fSRiver Riddle 
83409f7a55fSRiver Riddle   return success();
83509f7a55fSRiver Riddle }
83609f7a55fSRiver Riddle 
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)83709f7a55fSRiver Riddle LogicalResult mlir::getStridesAndOffset(MemRefType t,
83809f7a55fSRiver Riddle                                         SmallVectorImpl<int64_t> &strides,
83909f7a55fSRiver Riddle                                         int64_t &offset) {
84009f7a55fSRiver Riddle   AffineExpr offsetExpr;
84109f7a55fSRiver Riddle   SmallVector<AffineExpr, 4> strideExprs;
84209f7a55fSRiver Riddle   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
84309f7a55fSRiver Riddle     return failure();
84409f7a55fSRiver Riddle   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
84509f7a55fSRiver Riddle     offset = cst.getValue();
84609f7a55fSRiver Riddle   else
84709f7a55fSRiver Riddle     offset = ShapedType::kDynamicStrideOrOffset;
84809f7a55fSRiver Riddle   for (auto e : strideExprs) {
84909f7a55fSRiver Riddle     if (auto c = e.dyn_cast<AffineConstantExpr>())
85009f7a55fSRiver Riddle       strides.push_back(c.getValue());
85109f7a55fSRiver Riddle     else
85209f7a55fSRiver Riddle       strides.push_back(ShapedType::kDynamicStrideOrOffset);
85309f7a55fSRiver Riddle   }
85409f7a55fSRiver Riddle   return success();
85509f7a55fSRiver Riddle }
85609f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const857c42dd5dbSRiver Riddle void UnrankedMemRefType::walkImmediateSubElements(
858c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
859c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
860c42dd5dbSRiver Riddle   walkTypesFn(getElementType());
861c42dd5dbSRiver Riddle   walkAttrsFn(getMemorySpace());
862c42dd5dbSRiver Riddle }
863c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const864*01eedbc7SRiver Riddle Type UnrankedMemRefType::replaceImmediateSubElements(
865*01eedbc7SRiver Riddle     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
866*01eedbc7SRiver Riddle   return get(replTypes.front(), replAttrs.front());
867*01eedbc7SRiver Riddle }
868*01eedbc7SRiver Riddle 
86909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
87009f7a55fSRiver Riddle /// TupleType
87109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
87209f7a55fSRiver Riddle 
87309f7a55fSRiver Riddle /// Return the elements types for this tuple.
getTypes() const87409f7a55fSRiver Riddle ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
87509f7a55fSRiver Riddle 
87609f7a55fSRiver Riddle /// Accumulate the types contained in this tuple and tuples nested within it.
87709f7a55fSRiver Riddle /// Note that this only flattens nested tuples, not any other container type,
87809f7a55fSRiver Riddle /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
87909f7a55fSRiver Riddle /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)88009f7a55fSRiver Riddle void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
88109f7a55fSRiver Riddle   for (Type type : getTypes()) {
88209f7a55fSRiver Riddle     if (auto nestedTuple = type.dyn_cast<TupleType>())
88309f7a55fSRiver Riddle       nestedTuple.getFlattenedTypes(types);
88409f7a55fSRiver Riddle     else
88509f7a55fSRiver Riddle       types.push_back(type);
88609f7a55fSRiver Riddle   }
88709f7a55fSRiver Riddle }
88809f7a55fSRiver Riddle 
88909f7a55fSRiver Riddle /// Return the number of element types.
size() const89009f7a55fSRiver Riddle size_t TupleType::size() const { return getImpl()->size(); }
89109f7a55fSRiver Riddle 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const892c42dd5dbSRiver Riddle void TupleType::walkImmediateSubElements(
893c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
894c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) const {
895c42dd5dbSRiver Riddle   for (Type type : getTypes())
896c42dd5dbSRiver Riddle     walkTypesFn(type);
897c42dd5dbSRiver Riddle }
898c42dd5dbSRiver Riddle 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const899*01eedbc7SRiver Riddle Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
900*01eedbc7SRiver Riddle                                             ArrayRef<Type> replTypes) const {
901*01eedbc7SRiver Riddle   return get(getContext(), replTypes);
902*01eedbc7SRiver Riddle }
903*01eedbc7SRiver Riddle 
904d79642b3SRiver Riddle //===----------------------------------------------------------------------===//
905d79642b3SRiver Riddle // Type Utilities
906d79642b3SRiver Riddle //===----------------------------------------------------------------------===//
907d79642b3SRiver Riddle 
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)90809f7a55fSRiver Riddle AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
90909f7a55fSRiver Riddle                                            int64_t offset,
91009f7a55fSRiver Riddle                                            MLIRContext *context) {
91109f7a55fSRiver Riddle   AffineExpr expr;
91209f7a55fSRiver Riddle   unsigned nSymbols = 0;
91309f7a55fSRiver Riddle 
91409f7a55fSRiver Riddle   // AffineExpr for offset.
91509f7a55fSRiver Riddle   // Static case.
91609f7a55fSRiver Riddle   if (offset != MemRefType::getDynamicStrideOrOffset()) {
91709f7a55fSRiver Riddle     auto cst = getAffineConstantExpr(offset, context);
91809f7a55fSRiver Riddle     expr = cst;
91909f7a55fSRiver Riddle   } else {
92009f7a55fSRiver Riddle     // Dynamic case, new symbol for the offset.
92109f7a55fSRiver Riddle     auto sym = getAffineSymbolExpr(nSymbols++, context);
92209f7a55fSRiver Riddle     expr = sym;
92309f7a55fSRiver Riddle   }
92409f7a55fSRiver Riddle 
92509f7a55fSRiver Riddle   // AffineExpr for strides.
926e4853be2SMehdi Amini   for (const auto &en : llvm::enumerate(strides)) {
92709f7a55fSRiver Riddle     auto dim = en.index();
92809f7a55fSRiver Riddle     auto stride = en.value();
92909f7a55fSRiver Riddle     assert(stride != 0 && "Invalid stride specification");
93009f7a55fSRiver Riddle     auto d = getAffineDimExpr(dim, context);
93109f7a55fSRiver Riddle     AffineExpr mult;
93209f7a55fSRiver Riddle     // Static case.
93309f7a55fSRiver Riddle     if (stride != MemRefType::getDynamicStrideOrOffset())
93409f7a55fSRiver Riddle       mult = getAffineConstantExpr(stride, context);
93509f7a55fSRiver Riddle     else
93609f7a55fSRiver Riddle       // Dynamic case, new symbol for each new stride.
93709f7a55fSRiver Riddle       mult = getAffineSymbolExpr(nSymbols++, context);
93809f7a55fSRiver Riddle     expr = expr + d * mult;
93909f7a55fSRiver Riddle   }
94009f7a55fSRiver Riddle 
94109f7a55fSRiver Riddle   return AffineMap::get(strides.size(), nSymbols, expr);
94209f7a55fSRiver Riddle }
94309f7a55fSRiver Riddle 
94409f7a55fSRiver Riddle /// Return a version of `t` with identity layout if it can be determined
94509f7a55fSRiver Riddle /// statically that the layout is the canonical contiguous strided layout.
94609f7a55fSRiver Riddle /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
94709f7a55fSRiver Riddle /// `t` with simplified layout.
94809f7a55fSRiver Riddle /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)94909f7a55fSRiver Riddle MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
950e41ebbecSVladislav Vinogradov   AffineMap m = t.getLayout().getAffineMap();
951e41ebbecSVladislav Vinogradov 
95209f7a55fSRiver Riddle   // Already in canonical form.
953e41ebbecSVladislav Vinogradov   if (m.isIdentity())
95409f7a55fSRiver Riddle     return t;
95509f7a55fSRiver Riddle 
95609f7a55fSRiver Riddle   // Can't reduce to canonical identity form, return in canonical form.
957e41ebbecSVladislav Vinogradov   if (m.getNumResults() > 1)
95809f7a55fSRiver Riddle     return t;
95909f7a55fSRiver Riddle 
9607e6fe5c4SNicolas Vasilache   // Corner-case for 0-D affine maps.
9617e6fe5c4SNicolas Vasilache   if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
9627e6fe5c4SNicolas Vasilache     if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
9637e6fe5c4SNicolas Vasilache       if (cst.getValue() == 0)
964e41ebbecSVladislav Vinogradov         return MemRefType::Builder(t).setLayout({});
9657e6fe5c4SNicolas Vasilache     return t;
9667e6fe5c4SNicolas Vasilache   }
9677e6fe5c4SNicolas Vasilache 
968f4ac9f03SNicolas Vasilache   // 0-D corner case for empty shape that still have an affine map. Example:
969f4ac9f03SNicolas Vasilache   // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
970f4ac9f03SNicolas Vasilache   // offset needs to remain, just return t.
971f4ac9f03SNicolas Vasilache   if (t.getShape().empty())
972f4ac9f03SNicolas Vasilache     return t;
973f4ac9f03SNicolas Vasilache 
97409f7a55fSRiver Riddle   // If the canonical strided layout for the sizes of `t` is equal to the
97509f7a55fSRiver Riddle   // simplified layout of `t` we can just return an empty layout. Otherwise,
97609f7a55fSRiver Riddle   // just simplify the existing layout.
97709f7a55fSRiver Riddle   AffineExpr expr =
97809f7a55fSRiver Riddle       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
97909f7a55fSRiver Riddle   auto simplifiedLayoutExpr =
98009f7a55fSRiver Riddle       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
98109f7a55fSRiver Riddle   if (expr != simplifiedLayoutExpr)
982e41ebbecSVladislav Vinogradov     return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
983e41ebbecSVladislav Vinogradov         m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
984e41ebbecSVladislav Vinogradov   return MemRefType::Builder(t).setLayout({});
98509f7a55fSRiver Riddle }
98609f7a55fSRiver Riddle 
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)98709f7a55fSRiver Riddle AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
98809f7a55fSRiver Riddle                                                 ArrayRef<AffineExpr> exprs,
98909f7a55fSRiver Riddle                                                 MLIRContext *context) {
99009f7a55fSRiver Riddle   // Size 0 corner case is useful for canonicalizations.
991dccb7331SBenjamin Kramer   if (sizes.empty() || llvm::is_contained(sizes, 0))
99209f7a55fSRiver Riddle     return getAffineConstantExpr(0, context);
99309f7a55fSRiver Riddle 
994dccb7331SBenjamin Kramer   assert(!exprs.empty() && "expected exprs");
99509f7a55fSRiver Riddle   auto maps = AffineMap::inferFromExprList(exprs);
99609f7a55fSRiver Riddle   assert(!maps.empty() && "Expected one non-empty map");
99709f7a55fSRiver Riddle   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
99809f7a55fSRiver Riddle 
99909f7a55fSRiver Riddle   AffineExpr expr;
100009f7a55fSRiver Riddle   bool dynamicPoisonBit = false;
100109f7a55fSRiver Riddle   int64_t runningSize = 1;
100209f7a55fSRiver Riddle   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
100309f7a55fSRiver Riddle     int64_t size = std::get<1>(en);
100409f7a55fSRiver Riddle     // Degenerate case, no size =-> no stride
100509f7a55fSRiver Riddle     if (size == 0)
100609f7a55fSRiver Riddle       continue;
100709f7a55fSRiver Riddle     AffineExpr dimExpr = std::get<0>(en);
100809f7a55fSRiver Riddle     AffineExpr stride = dynamicPoisonBit
100909f7a55fSRiver Riddle                             ? getAffineSymbolExpr(nSymbols++, context)
101009f7a55fSRiver Riddle                             : getAffineConstantExpr(runningSize, context);
101109f7a55fSRiver Riddle     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
10125f022ad6SAart Bik     if (size > 0) {
101309f7a55fSRiver Riddle       runningSize *= size;
10145f022ad6SAart Bik       assert(runningSize > 0 && "integer overflow in size computation");
10155f022ad6SAart Bik     } else {
101609f7a55fSRiver Riddle       dynamicPoisonBit = true;
101709f7a55fSRiver Riddle     }
10185f022ad6SAart Bik   }
101909f7a55fSRiver Riddle   return simplifyAffineExpr(expr, numDims, nSymbols);
102009f7a55fSRiver Riddle }
102109f7a55fSRiver Riddle 
102209f7a55fSRiver Riddle /// Return a version of `t` with a layout that has all dynamic offset and
102309f7a55fSRiver Riddle /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)102409f7a55fSRiver Riddle MemRefType mlir::eraseStridedLayout(MemRefType t) {
102509f7a55fSRiver Riddle   auto val = ShapedType::kDynamicStrideOrOffset;
1026e41ebbecSVladislav Vinogradov   return MemRefType::Builder(t).setLayout(
1027e41ebbecSVladislav Vinogradov       AffineMapAttr::get(makeStridedLinearLayoutMap(
1028e41ebbecSVladislav Vinogradov           SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
102909f7a55fSRiver Riddle }
103009f7a55fSRiver Riddle 
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)103109f7a55fSRiver Riddle AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
103209f7a55fSRiver Riddle                                                 MLIRContext *context) {
103309f7a55fSRiver Riddle   SmallVector<AffineExpr, 4> exprs;
103409f7a55fSRiver Riddle   exprs.reserve(sizes.size());
103509f7a55fSRiver Riddle   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
103609f7a55fSRiver Riddle     exprs.push_back(getAffineDimExpr(dim, context));
103709f7a55fSRiver Riddle   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
103809f7a55fSRiver Riddle }
103909f7a55fSRiver Riddle 
104009f7a55fSRiver Riddle /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)104109f7a55fSRiver Riddle bool mlir::isStrided(MemRefType t) {
104209f7a55fSRiver Riddle   int64_t offset;
10435bc4f884SNicolas Vasilache   SmallVector<int64_t, 4> strides;
10445bc4f884SNicolas Vasilache   auto res = getStridesAndOffset(t, strides, offset);
104509f7a55fSRiver Riddle   return succeeded(res);
104609f7a55fSRiver Riddle }
10475bc4f884SNicolas Vasilache 
10485bc4f884SNicolas Vasilache /// Return the layout map in strided linear layout AffineMap form.
10495bc4f884SNicolas Vasilache /// Return null if the layout is not compatible with a strided layout.
getStridedLinearLayoutMap(MemRefType t)10505bc4f884SNicolas Vasilache AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
10515bc4f884SNicolas Vasilache   int64_t offset;
10525bc4f884SNicolas Vasilache   SmallVector<int64_t, 4> strides;
10535bc4f884SNicolas Vasilache   if (failed(getStridesAndOffset(t, strides, offset)))
10545bc4f884SNicolas Vasilache     return AffineMap();
10555bc4f884SNicolas Vasilache   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
10565bc4f884SNicolas Vasilache }
1057aba437ceSBenoit Jacob 
1058aba437ceSBenoit Jacob /// Return the AffineExpr representation of the offset, assuming `memRefType`
1059aba437ceSBenoit Jacob /// is a strided memref.
getOffsetExpr(MemRefType memrefType)1060aba437ceSBenoit Jacob static AffineExpr getOffsetExpr(MemRefType memrefType) {
1061aba437ceSBenoit Jacob   SmallVector<AffineExpr> strides;
1062aba437ceSBenoit Jacob   AffineExpr offset;
1063aba437ceSBenoit Jacob   if (failed(getStridesAndOffset(memrefType, strides, offset)))
1064aba437ceSBenoit Jacob     assert(false && "expected strided memref");
1065aba437ceSBenoit Jacob   return offset;
1066aba437ceSBenoit Jacob }
1067aba437ceSBenoit Jacob 
1068aba437ceSBenoit Jacob /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1069aba437ceSBenoit Jacob /// `offset` AffineExpr.
makeContiguousRowMajorMemRefType(MLIRContext * context,ArrayRef<int64_t> shape,Type elementType,AffineExpr offset)1070aba437ceSBenoit Jacob static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1071aba437ceSBenoit Jacob                                                    ArrayRef<int64_t> shape,
1072aba437ceSBenoit Jacob                                                    Type elementType,
1073aba437ceSBenoit Jacob                                                    AffineExpr offset) {
1074aba437ceSBenoit Jacob   AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1075aba437ceSBenoit Jacob   AffineExpr contiguousRowMajor = canonical + offset;
1076aba437ceSBenoit Jacob   AffineMap contiguousRowMajorMap =
1077aba437ceSBenoit Jacob       AffineMap::inferFromExprList({contiguousRowMajor})[0];
1078aba437ceSBenoit Jacob   return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1079aba437ceSBenoit Jacob }
1080aba437ceSBenoit Jacob 
1081aba437ceSBenoit Jacob /// Helper determining if a memref is static-shape and contiguous-row-major
1082aba437ceSBenoit Jacob /// layout, while still allowing for an arbitrary offset (any static or
1083aba437ceSBenoit Jacob /// dynamic value).
isStaticShapeAndContiguousRowMajor(MemRefType memrefType)1084aba437ceSBenoit Jacob bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1085aba437ceSBenoit Jacob   if (!memrefType.hasStaticShape())
1086aba437ceSBenoit Jacob     return false;
1087aba437ceSBenoit Jacob   AffineExpr offset = getOffsetExpr(memrefType);
1088aba437ceSBenoit Jacob   MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1089aba437ceSBenoit Jacob       memrefType.getContext(), memrefType.getShape(),
1090aba437ceSBenoit Jacob       memrefType.getElementType(), offset);
1091aba437ceSBenoit Jacob   return canonicalizeStridedLayout(memrefType) ==
1092aba437ceSBenoit Jacob          canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1093aba437ceSBenoit Jacob }
1094