109f7a55fSRiver Riddle //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
209f7a55fSRiver Riddle //
309f7a55fSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409f7a55fSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
509f7a55fSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609f7a55fSRiver Riddle //
709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
809f7a55fSRiver Riddle
909f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
1009f7a55fSRiver Riddle #include "TypeDetail.h"
1109f7a55fSRiver Riddle #include "mlir/IR/AffineExpr.h"
1209f7a55fSRiver Riddle #include "mlir/IR/AffineMap.h"
13f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinAttributes.h"
14f3bf5c05SVladislav Vinogradov #include "mlir/IR/BuiltinDialect.h"
1509f7a55fSRiver Riddle #include "mlir/IR/Diagnostics.h"
1609f7a55fSRiver Riddle #include "mlir/IR/Dialect.h"
177ceffae1SRiver Riddle #include "mlir/IR/FunctionInterfaces.h"
18ee090870SMehdi Amini #include "mlir/IR/OpImplementation.h"
1923c9e8bcSAart Bik #include "mlir/IR/TensorEncoding.h"
2009f7a55fSRiver Riddle #include "llvm/ADT/APFloat.h"
2109f7a55fSRiver Riddle #include "llvm/ADT/BitVector.h"
22c7cae0e4SRiver Riddle #include "llvm/ADT/Sequence.h"
2309f7a55fSRiver Riddle #include "llvm/ADT/Twine.h"
240d01dfbcSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
2509f7a55fSRiver Riddle
2609f7a55fSRiver Riddle using namespace mlir;
2709f7a55fSRiver Riddle using namespace mlir::detail;
2809f7a55fSRiver Riddle
2909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
3095019de8SRiver Riddle /// Tablegen Type Definitions
3195019de8SRiver Riddle //===----------------------------------------------------------------------===//
3295019de8SRiver Riddle
3395019de8SRiver Riddle #define GET_TYPEDEF_CLASSES
3495019de8SRiver Riddle #include "mlir/IR/BuiltinTypes.cpp.inc"
3595019de8SRiver Riddle
3695019de8SRiver Riddle //===----------------------------------------------------------------------===//
3731bb8efdSRiver Riddle // BuiltinDialect
3831bb8efdSRiver Riddle //===----------------------------------------------------------------------===//
3931bb8efdSRiver Riddle
/// Registers the builtin type classes with the builtin dialect. The list of
/// classes passed to addTypes<> is produced by the tablegen-generated
/// BuiltinTypes.cpp.inc when it is included with GET_TYPEDEF_LIST defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
4631bb8efdSRiver Riddle
4731bb8efdSRiver Riddle //===----------------------------------------------------------------------===//
4809f7a55fSRiver Riddle /// ComplexType
4909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
5009f7a55fSRiver Riddle
5109f7a55fSRiver Riddle /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)5206e25d56SRiver Riddle LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
5309f7a55fSRiver Riddle Type elementType) {
5409f7a55fSRiver Riddle if (!elementType.isIntOrFloat())
5506e25d56SRiver Riddle return emitError() << "invalid element type for complex";
5609f7a55fSRiver Riddle return success();
5709f7a55fSRiver Riddle }
5809f7a55fSRiver Riddle
5909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
6009f7a55fSRiver Riddle // Integer Type
6109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
6209f7a55fSRiver Riddle
// An odr-used `static constexpr` data member needs an out-of-line definition
// before C++17 (where such members became implicitly inline variables).
constexpr unsigned IntegerType::kMaxWidth;
6509f7a55fSRiver Riddle
6609f7a55fSRiver Riddle /// Verify the construction of an integer type.
verify(function_ref<InFlightDiagnostic ()> emitError,unsigned width,SignednessSemantics signedness)6706e25d56SRiver Riddle LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
6806e25d56SRiver Riddle unsigned width,
6909f7a55fSRiver Riddle SignednessSemantics signedness) {
7009f7a55fSRiver Riddle if (width > IntegerType::kMaxWidth) {
7106e25d56SRiver Riddle return emitError() << "integer bitwidth is limited to "
7209f7a55fSRiver Riddle << IntegerType::kMaxWidth << " bits";
7309f7a55fSRiver Riddle }
7409f7a55fSRiver Riddle return success();
7509f7a55fSRiver Riddle }
7609f7a55fSRiver Riddle
getWidth() const7709f7a55fSRiver Riddle unsigned IntegerType::getWidth() const { return getImpl()->width; }
7809f7a55fSRiver Riddle
getSignedness() const7909f7a55fSRiver Riddle IntegerType::SignednessSemantics IntegerType::getSignedness() const {
8009f7a55fSRiver Riddle return getImpl()->signedness;
8109f7a55fSRiver Riddle }
8209f7a55fSRiver Riddle
scaleElementBitwidth(unsigned scale)837310501fSNicolas Vasilache IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
847310501fSNicolas Vasilache if (!scale)
857310501fSNicolas Vasilache return IntegerType();
861b97cdf8SRiver Riddle return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
877310501fSNicolas Vasilache }
887310501fSNicolas Vasilache
8909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
9009f7a55fSRiver Riddle // Float Type
9109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
9209f7a55fSRiver Riddle
getWidth()9309f7a55fSRiver Riddle unsigned FloatType::getWidth() {
9409f7a55fSRiver Riddle if (isa<Float16Type, BFloat16Type>())
9509f7a55fSRiver Riddle return 16;
9609f7a55fSRiver Riddle if (isa<Float32Type>())
9709f7a55fSRiver Riddle return 32;
9809f7a55fSRiver Riddle if (isa<Float64Type>())
9909f7a55fSRiver Riddle return 64;
100cf0173deSValentin Clement if (isa<Float80Type>())
101cf0173deSValentin Clement return 80;
102cf0173deSValentin Clement if (isa<Float128Type>())
103cf0173deSValentin Clement return 128;
10409f7a55fSRiver Riddle llvm_unreachable("unexpected float type");
10509f7a55fSRiver Riddle }
10609f7a55fSRiver Riddle
10709f7a55fSRiver Riddle /// Returns the floating semantics for the given type.
getFloatSemantics()10809f7a55fSRiver Riddle const llvm::fltSemantics &FloatType::getFloatSemantics() {
10909f7a55fSRiver Riddle if (isa<BFloat16Type>())
11009f7a55fSRiver Riddle return APFloat::BFloat();
11109f7a55fSRiver Riddle if (isa<Float16Type>())
11209f7a55fSRiver Riddle return APFloat::IEEEhalf();
11309f7a55fSRiver Riddle if (isa<Float32Type>())
11409f7a55fSRiver Riddle return APFloat::IEEEsingle();
11509f7a55fSRiver Riddle if (isa<Float64Type>())
11609f7a55fSRiver Riddle return APFloat::IEEEdouble();
117cf0173deSValentin Clement if (isa<Float80Type>())
118cf0173deSValentin Clement return APFloat::x87DoubleExtended();
119cf0173deSValentin Clement if (isa<Float128Type>())
120cf0173deSValentin Clement return APFloat::IEEEquad();
12109f7a55fSRiver Riddle llvm_unreachable("non-floating point type used");
12209f7a55fSRiver Riddle }
12309f7a55fSRiver Riddle
scaleElementBitwidth(unsigned scale)1247310501fSNicolas Vasilache FloatType FloatType::scaleElementBitwidth(unsigned scale) {
1257310501fSNicolas Vasilache if (!scale)
1267310501fSNicolas Vasilache return FloatType();
1277310501fSNicolas Vasilache MLIRContext *ctx = getContext();
1287310501fSNicolas Vasilache if (isF16() || isBF16()) {
1297310501fSNicolas Vasilache if (scale == 2)
1307310501fSNicolas Vasilache return FloatType::getF32(ctx);
1317310501fSNicolas Vasilache if (scale == 4)
1327310501fSNicolas Vasilache return FloatType::getF64(ctx);
1337310501fSNicolas Vasilache }
1347310501fSNicolas Vasilache if (isF32())
1357310501fSNicolas Vasilache if (scale == 2)
1367310501fSNicolas Vasilache return FloatType::getF64(ctx);
1377310501fSNicolas Vasilache return FloatType();
1387310501fSNicolas Vasilache }
1397310501fSNicolas Vasilache
getFPMantissaWidth()1401b2a1f84SWilliam S. Moses unsigned FloatType::getFPMantissaWidth() {
1411b2a1f84SWilliam S. Moses return APFloat::semanticsPrecision(getFloatSemantics());
1421b2a1f84SWilliam S. Moses }
1431b2a1f84SWilliam S. Moses
14409f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
14509f7a55fSRiver Riddle // FunctionType
14609f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
14709f7a55fSRiver Riddle
getNumInputs() const14809f7a55fSRiver Riddle unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
14909f7a55fSRiver Riddle
getInputs() const15009f7a55fSRiver Riddle ArrayRef<Type> FunctionType::getInputs() const {
15109f7a55fSRiver Riddle return getImpl()->getInputs();
15209f7a55fSRiver Riddle }
15309f7a55fSRiver Riddle
getNumResults() const15409f7a55fSRiver Riddle unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
15509f7a55fSRiver Riddle
getResults() const15609f7a55fSRiver Riddle ArrayRef<Type> FunctionType::getResults() const {
15709f7a55fSRiver Riddle return getImpl()->getResults();
15809f7a55fSRiver Riddle }
15909f7a55fSRiver Riddle
clone(TypeRange inputs,TypeRange results) const1607ceffae1SRiver Riddle FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
1617ceffae1SRiver Riddle return get(getContext(), inputs, results);
16209f7a55fSRiver Riddle }
16309f7a55fSRiver Riddle
1648066f22cSFabian Schuiki /// Returns a new function type with the specified arguments and results
1658066f22cSFabian Schuiki /// inserted.
getWithArgsAndResults(ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<unsigned> resultIndices,TypeRange resultTypes)1668066f22cSFabian Schuiki FunctionType FunctionType::getWithArgsAndResults(
1678066f22cSFabian Schuiki ArrayRef<unsigned> argIndices, TypeRange argTypes,
1688066f22cSFabian Schuiki ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
1697ceffae1SRiver Riddle SmallVector<Type> argStorage, resultStorage;
1707ceffae1SRiver Riddle TypeRange newArgTypes = function_interface_impl::insertTypesInto(
1717ceffae1SRiver Riddle getInputs(), argIndices, argTypes, argStorage);
1727ceffae1SRiver Riddle TypeRange newResultTypes = function_interface_impl::insertTypesInto(
1737ceffae1SRiver Riddle getResults(), resultIndices, resultTypes, resultStorage);
1747ceffae1SRiver Riddle return clone(newArgTypes, newResultTypes);
1758066f22cSFabian Schuiki }
1768066f22cSFabian Schuiki
17709f7a55fSRiver Riddle /// Returns a new function type without the specified arguments and results.
17809f7a55fSRiver Riddle FunctionType
getWithoutArgsAndResults(const BitVector & argIndices,const BitVector & resultIndices)179d10d49dcSRiver Riddle FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
180d10d49dcSRiver Riddle const BitVector &resultIndices) {
1817ceffae1SRiver Riddle SmallVector<Type> argStorage, resultStorage;
1827ceffae1SRiver Riddle TypeRange newArgTypes = function_interface_impl::filterTypesOut(
1837ceffae1SRiver Riddle getInputs(), argIndices, argStorage);
1847ceffae1SRiver Riddle TypeRange newResultTypes = function_interface_impl::filterTypesOut(
1857ceffae1SRiver Riddle getResults(), resultIndices, resultStorage);
1867ceffae1SRiver Riddle return clone(newArgTypes, newResultTypes);
18709f7a55fSRiver Riddle }
18809f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const189c42dd5dbSRiver Riddle void FunctionType::walkImmediateSubElements(
190c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
191c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
192c42dd5dbSRiver Riddle for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
193c42dd5dbSRiver Riddle walkTypesFn(type);
194c42dd5dbSRiver Riddle }
195c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const196*01eedbc7SRiver Riddle Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
197*01eedbc7SRiver Riddle ArrayRef<Type> replTypes) const {
198*01eedbc7SRiver Riddle unsigned numInputs = getNumInputs();
199*01eedbc7SRiver Riddle return get(getContext(), replTypes.take_front(numInputs),
200*01eedbc7SRiver Riddle replTypes.drop_front(numInputs));
201*01eedbc7SRiver Riddle }
202*01eedbc7SRiver Riddle
20309f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
20409f7a55fSRiver Riddle // OpaqueType
20509f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
20609f7a55fSRiver Riddle
20709f7a55fSRiver Riddle /// Verify the construction of an opaque type.
verify(function_ref<InFlightDiagnostic ()> emitError,StringAttr dialect,StringRef typeData)20806e25d56SRiver Riddle LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
209195730a6SRiver Riddle StringAttr dialect, StringRef typeData) {
21009f7a55fSRiver Riddle if (!Dialect::isValidNamespace(dialect.strref()))
21106e25d56SRiver Riddle return emitError() << "invalid dialect namespace '" << dialect << "'";
212109305e1SRiver Riddle
213109305e1SRiver Riddle // Check that the dialect is actually registered.
214109305e1SRiver Riddle MLIRContext *context = dialect.getContext();
215109305e1SRiver Riddle if (!context->allowsUnregisteredDialects() &&
216109305e1SRiver Riddle !context->getLoadedDialect(dialect.strref())) {
217109305e1SRiver Riddle return emitError()
218109305e1SRiver Riddle << "`!" << dialect << "<\"" << typeData << "\">"
219109305e1SRiver Riddle << "` type created with unregistered dialect. If this is "
220109305e1SRiver Riddle "intended, please call allowUnregisteredDialects() on the "
221109305e1SRiver Riddle "MLIRContext, or use -allow-unregistered-dialect with "
2220f9e6451SMehdi Amini "the MLIR opt tool used";
223109305e1SRiver Riddle }
224109305e1SRiver Riddle
22509f7a55fSRiver Riddle return success();
22609f7a55fSRiver Riddle }
22709f7a55fSRiver Riddle
22809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
22909f7a55fSRiver Riddle // VectorType
23009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
23109f7a55fSRiver Riddle
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,unsigned numScalableDims)23206e25d56SRiver Riddle LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
233a4830d14SJavier Setoain ArrayRef<int64_t> shape, Type elementType,
234a4830d14SJavier Setoain unsigned numScalableDims) {
23509f7a55fSRiver Riddle if (!isValidElementType(elementType))
236b096ac90SGeoffrey Martin-Noble return emitError()
237b096ac90SGeoffrey Martin-Noble << "vector elements must be int/index/float type but got "
238b096ac90SGeoffrey Martin-Noble << elementType;
23909f7a55fSRiver Riddle
24009f7a55fSRiver Riddle if (any_of(shape, [](int64_t i) { return i <= 0; }))
241b096ac90SGeoffrey Martin-Noble return emitError()
242b096ac90SGeoffrey Martin-Noble << "vector types must have positive constant sizes but got "
243b096ac90SGeoffrey Martin-Noble << shape;
24409f7a55fSRiver Riddle
24509f7a55fSRiver Riddle return success();
24609f7a55fSRiver Riddle }
24709f7a55fSRiver Riddle
scaleElementBitwidth(unsigned scale)2487310501fSNicolas Vasilache VectorType VectorType::scaleElementBitwidth(unsigned scale) {
2497310501fSNicolas Vasilache if (!scale)
2507310501fSNicolas Vasilache return VectorType();
2517310501fSNicolas Vasilache if (auto et = getElementType().dyn_cast<IntegerType>())
2527310501fSNicolas Vasilache if (auto scaledEt = et.scaleElementBitwidth(scale))
253a4830d14SJavier Setoain return VectorType::get(getShape(), scaledEt, getNumScalableDims());
2547310501fSNicolas Vasilache if (auto et = getElementType().dyn_cast<FloatType>())
2557310501fSNicolas Vasilache if (auto scaledEt = et.scaleElementBitwidth(scale))
256a4830d14SJavier Setoain return VectorType::get(getShape(), scaledEt, getNumScalableDims());
2577310501fSNicolas Vasilache return VectorType();
2587310501fSNicolas Vasilache }
2597310501fSNicolas Vasilache
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const260c42dd5dbSRiver Riddle void VectorType::walkImmediateSubElements(
261c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
262c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
263c42dd5dbSRiver Riddle walkTypesFn(getElementType());
264c42dd5dbSRiver Riddle }
265c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const266*01eedbc7SRiver Riddle Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
267*01eedbc7SRiver Riddle ArrayRef<Type> replTypes) const {
268*01eedbc7SRiver Riddle return get(getShape(), replTypes.front(), getNumScalableDims());
269*01eedbc7SRiver Riddle }
270*01eedbc7SRiver Riddle
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const271676bfb2aSRiver Riddle VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
272159898d5SAdrian Kuegel Type elementType) const {
27330c67587SKazu Hirata return VectorType::get(shape.value_or(getShape()), elementType,
274676bfb2aSRiver Riddle getNumScalableDims());
275676bfb2aSRiver Riddle }
276676bfb2aSRiver Riddle
27709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
27809f7a55fSRiver Riddle // TensorType
27909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
28009f7a55fSRiver Riddle
getElementType() const281676bfb2aSRiver Riddle Type TensorType::getElementType() const {
282676bfb2aSRiver Riddle return llvm::TypeSwitch<TensorType, Type>(*this)
283676bfb2aSRiver Riddle .Case<RankedTensorType, UnrankedTensorType>(
284676bfb2aSRiver Riddle [](auto type) { return type.getElementType(); });
285676bfb2aSRiver Riddle }
286676bfb2aSRiver Riddle
hasRank() const287676bfb2aSRiver Riddle bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
288676bfb2aSRiver Riddle
getShape() const289676bfb2aSRiver Riddle ArrayRef<int64_t> TensorType::getShape() const {
290676bfb2aSRiver Riddle return cast<RankedTensorType>().getShape();
291676bfb2aSRiver Riddle }
292676bfb2aSRiver Riddle
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const293676bfb2aSRiver Riddle TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
294676bfb2aSRiver Riddle Type elementType) const {
295676bfb2aSRiver Riddle if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
296676bfb2aSRiver Riddle if (shape)
297676bfb2aSRiver Riddle return RankedTensorType::get(*shape, elementType);
298676bfb2aSRiver Riddle return UnrankedTensorType::get(elementType);
299676bfb2aSRiver Riddle }
300676bfb2aSRiver Riddle
301676bfb2aSRiver Riddle auto rankedTy = cast<RankedTensorType>();
302676bfb2aSRiver Riddle if (!shape)
303676bfb2aSRiver Riddle return RankedTensorType::get(rankedTy.getShape(), elementType,
304676bfb2aSRiver Riddle rankedTy.getEncoding());
30530c67587SKazu Hirata return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
30630c67587SKazu Hirata rankedTy.getEncoding());
307676bfb2aSRiver Riddle }
308676bfb2aSRiver Riddle
30906e25d56SRiver Riddle // Check if "elementType" can be an element type of a tensor.
31006e25d56SRiver Riddle static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic ()> emitError,Type elementType)31106e25d56SRiver Riddle checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
31209f7a55fSRiver Riddle Type elementType) {
31309f7a55fSRiver Riddle if (!TensorType::isValidElementType(elementType))
31406e25d56SRiver Riddle return emitError() << "invalid tensor element type: " << elementType;
31509f7a55fSRiver Riddle return success();
31609f7a55fSRiver Riddle }
31709f7a55fSRiver Riddle
31809f7a55fSRiver Riddle /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)31909f7a55fSRiver Riddle bool TensorType::isValidElementType(Type type) {
32009f7a55fSRiver Riddle // Note: Non standard/builtin types are allowed to exist within tensor
32109f7a55fSRiver Riddle // types. Dialects are expected to verify that tensor types have a valid
32209f7a55fSRiver Riddle // element type within that dialect.
32309f7a55fSRiver Riddle return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
32409f7a55fSRiver Riddle IndexType>() ||
325f8479d9dSRiver Riddle !llvm::isa<BuiltinDialect>(type.getDialect());
32609f7a55fSRiver Riddle }
32709f7a55fSRiver Riddle
32809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
32909f7a55fSRiver Riddle // RankedTensorType
33009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
33109f7a55fSRiver Riddle
33206e25d56SRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,Attribute encoding)33306e25d56SRiver Riddle RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
3347714b405SAart Bik ArrayRef<int64_t> shape, Type elementType,
3357714b405SAart Bik Attribute encoding) {
3360d01dfbcSRiver Riddle for (int64_t s : shape)
33709f7a55fSRiver Riddle if (s < -1)
33806e25d56SRiver Riddle return emitError() << "invalid tensor dimension size";
33923c9e8bcSAart Bik if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
34023c9e8bcSAart Bik if (failed(v.verifyEncoding(shape, elementType, emitError)))
34123c9e8bcSAart Bik return failure();
34206e25d56SRiver Riddle return checkTensorElementType(emitError, elementType);
34309f7a55fSRiver Riddle }
34409f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const345c42dd5dbSRiver Riddle void RankedTensorType::walkImmediateSubElements(
346c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
347c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
348c42dd5dbSRiver Riddle walkTypesFn(getElementType());
349eb6c63cbSVladislav Vinogradov if (Attribute encoding = getEncoding())
350eb6c63cbSVladislav Vinogradov walkAttrsFn(encoding);
351c42dd5dbSRiver Riddle }
352c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const353*01eedbc7SRiver Riddle Type RankedTensorType::replaceImmediateSubElements(
354*01eedbc7SRiver Riddle ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
355*01eedbc7SRiver Riddle return get(getShape(), replTypes.front(),
356*01eedbc7SRiver Riddle replAttrs.empty() ? Attribute() : replAttrs.back());
357*01eedbc7SRiver Riddle }
358*01eedbc7SRiver Riddle
35909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
36009f7a55fSRiver Riddle // UnrankedTensorType
36109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
36209f7a55fSRiver Riddle
36309f7a55fSRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType)36406e25d56SRiver Riddle UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
36509f7a55fSRiver Riddle Type elementType) {
36606e25d56SRiver Riddle return checkTensorElementType(emitError, elementType);
36709f7a55fSRiver Riddle }
36809f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const369c42dd5dbSRiver Riddle void UnrankedTensorType::walkImmediateSubElements(
370c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
371c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
372c42dd5dbSRiver Riddle walkTypesFn(getElementType());
373c42dd5dbSRiver Riddle }
374c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const375*01eedbc7SRiver Riddle Type UnrankedTensorType::replaceImmediateSubElements(
376*01eedbc7SRiver Riddle ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
377*01eedbc7SRiver Riddle return get(replTypes.front());
378*01eedbc7SRiver Riddle }
379*01eedbc7SRiver Riddle
38009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
38109f7a55fSRiver Riddle // BaseMemRefType
38209f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
38309f7a55fSRiver Riddle
getElementType() const384676bfb2aSRiver Riddle Type BaseMemRefType::getElementType() const {
385676bfb2aSRiver Riddle return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
386676bfb2aSRiver Riddle .Case<MemRefType, UnrankedMemRefType>(
387676bfb2aSRiver Riddle [](auto type) { return type.getElementType(); });
388676bfb2aSRiver Riddle }
389676bfb2aSRiver Riddle
hasRank() const390676bfb2aSRiver Riddle bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
391676bfb2aSRiver Riddle
getShape() const392676bfb2aSRiver Riddle ArrayRef<int64_t> BaseMemRefType::getShape() const {
393676bfb2aSRiver Riddle return cast<MemRefType>().getShape();
394676bfb2aSRiver Riddle }
395676bfb2aSRiver Riddle
cloneWith(Optional<ArrayRef<int64_t>> shape,Type elementType) const396676bfb2aSRiver Riddle BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
397676bfb2aSRiver Riddle Type elementType) const {
398676bfb2aSRiver Riddle if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
399676bfb2aSRiver Riddle if (!shape)
400676bfb2aSRiver Riddle return UnrankedMemRefType::get(elementType, getMemorySpace());
401676bfb2aSRiver Riddle MemRefType::Builder builder(*shape, elementType);
402676bfb2aSRiver Riddle builder.setMemorySpace(getMemorySpace());
403676bfb2aSRiver Riddle return builder;
404676bfb2aSRiver Riddle }
405676bfb2aSRiver Riddle
406676bfb2aSRiver Riddle MemRefType::Builder builder(cast<MemRefType>());
407676bfb2aSRiver Riddle if (shape)
408676bfb2aSRiver Riddle builder.setShape(*shape);
409676bfb2aSRiver Riddle builder.setElementType(elementType);
410676bfb2aSRiver Riddle return builder;
411676bfb2aSRiver Riddle }
412676bfb2aSRiver Riddle
getMemorySpace() const413f3bf5c05SVladislav Vinogradov Attribute BaseMemRefType::getMemorySpace() const {
414f3bf5c05SVladislav Vinogradov if (auto rankedMemRefTy = dyn_cast<MemRefType>())
415f3bf5c05SVladislav Vinogradov return rankedMemRefTy.getMemorySpace();
416f3bf5c05SVladislav Vinogradov return cast<UnrankedMemRefType>().getMemorySpace();
417f3bf5c05SVladislav Vinogradov }
418f3bf5c05SVladislav Vinogradov
getMemorySpaceAsInt() const41937eca08eSVladislav Vinogradov unsigned BaseMemRefType::getMemorySpaceAsInt() const {
4200d01dfbcSRiver Riddle if (auto rankedMemRefTy = dyn_cast<MemRefType>())
4210d01dfbcSRiver Riddle return rankedMemRefTy.getMemorySpaceAsInt();
4220d01dfbcSRiver Riddle return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
42309f7a55fSRiver Riddle }
42409f7a55fSRiver Riddle
42509f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
42609f7a55fSRiver Riddle // MemRefType
42709f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
42809f7a55fSRiver Riddle
4290fb4a201SAlex Zinenko /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
4300fb4a201SAlex Zinenko /// `originalShape` with some `1` entries erased, return the set of indices
4310fb4a201SAlex Zinenko /// that specifies which of the entries of `originalShape` are dropped to obtain
4320fb4a201SAlex Zinenko /// `reducedShape`. The returned mask can be applied as a projection to
4330fb4a201SAlex Zinenko /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
4340fb4a201SAlex Zinenko /// which dimensions must be kept when e.g. compute MemRef strides under
4350fb4a201SAlex Zinenko /// rank-reducing operations. Return None if reducedShape cannot be obtained
4360fb4a201SAlex Zinenko /// by dropping only `1` entries in `originalShape`.
4370fb4a201SAlex Zinenko llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> reducedShape)4380fb4a201SAlex Zinenko mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
4390fb4a201SAlex Zinenko ArrayRef<int64_t> reducedShape) {
4400fb4a201SAlex Zinenko size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
4410fb4a201SAlex Zinenko llvm::SmallDenseSet<unsigned> unusedDims;
4420fb4a201SAlex Zinenko unsigned reducedIdx = 0;
4430fb4a201SAlex Zinenko for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
444a08b750cSNicolas Vasilache // Greedily insert `originalIdx` if match.
4450fb4a201SAlex Zinenko if (reducedIdx < reducedRank &&
4460fb4a201SAlex Zinenko originalShape[originalIdx] == reducedShape[reducedIdx]) {
4470fb4a201SAlex Zinenko reducedIdx++;
4480fb4a201SAlex Zinenko continue;
4490fb4a201SAlex Zinenko }
4500fb4a201SAlex Zinenko
4510fb4a201SAlex Zinenko unusedDims.insert(originalIdx);
4520fb4a201SAlex Zinenko // If no match on `originalIdx`, the `originalShape` at this dimension
4530fb4a201SAlex Zinenko // must be 1, otherwise we bail.
4540fb4a201SAlex Zinenko if (originalShape[originalIdx] != 1)
4550fb4a201SAlex Zinenko return llvm::None;
4560fb4a201SAlex Zinenko }
4570fb4a201SAlex Zinenko // The whole reducedShape must be scanned, otherwise we bail.
4580fb4a201SAlex Zinenko if (reducedIdx != reducedRank)
4590fb4a201SAlex Zinenko return llvm::None;
4600fb4a201SAlex Zinenko return unusedDims;
4610fb4a201SAlex Zinenko }
4620fb4a201SAlex Zinenko
463a08b750cSNicolas Vasilache SliceVerificationResult
isRankReducedType(ShapedType originalType,ShapedType candidateReducedType)464a08b750cSNicolas Vasilache mlir::isRankReducedType(ShapedType originalType,
465a08b750cSNicolas Vasilache ShapedType candidateReducedType) {
466a08b750cSNicolas Vasilache if (originalType == candidateReducedType)
467a08b750cSNicolas Vasilache return SliceVerificationResult::Success;
468a08b750cSNicolas Vasilache
469a08b750cSNicolas Vasilache ShapedType originalShapedType = originalType.cast<ShapedType>();
470a08b750cSNicolas Vasilache ShapedType candidateReducedShapedType =
471a08b750cSNicolas Vasilache candidateReducedType.cast<ShapedType>();
472a08b750cSNicolas Vasilache
473a08b750cSNicolas Vasilache // Rank and size logic is valid for all ShapedTypes.
474a08b750cSNicolas Vasilache ArrayRef<int64_t> originalShape = originalShapedType.getShape();
475a08b750cSNicolas Vasilache ArrayRef<int64_t> candidateReducedShape =
476a08b750cSNicolas Vasilache candidateReducedShapedType.getShape();
477a08b750cSNicolas Vasilache unsigned originalRank = originalShape.size(),
478a08b750cSNicolas Vasilache candidateReducedRank = candidateReducedShape.size();
479a08b750cSNicolas Vasilache if (candidateReducedRank > originalRank)
480a08b750cSNicolas Vasilache return SliceVerificationResult::RankTooLarge;
481a08b750cSNicolas Vasilache
482a08b750cSNicolas Vasilache auto optionalUnusedDimsMask =
483a08b750cSNicolas Vasilache computeRankReductionMask(originalShape, candidateReducedShape);
484a08b750cSNicolas Vasilache
485a08b750cSNicolas Vasilache // Sizes cannot be matched in case empty vector is returned.
486037f0995SKazu Hirata if (!optionalUnusedDimsMask)
487a08b750cSNicolas Vasilache return SliceVerificationResult::SizeMismatch;
488a08b750cSNicolas Vasilache
489a08b750cSNicolas Vasilache if (originalShapedType.getElementType() !=
490a08b750cSNicolas Vasilache candidateReducedShapedType.getElementType())
491a08b750cSNicolas Vasilache return SliceVerificationResult::ElemTypeMismatch;
492a08b750cSNicolas Vasilache
493a08b750cSNicolas Vasilache return SliceVerificationResult::Success;
494a08b750cSNicolas Vasilache }
495a08b750cSNicolas Vasilache
isSupportedMemorySpace(Attribute memorySpace)496f3bf5c05SVladislav Vinogradov bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
497f3bf5c05SVladislav Vinogradov // Empty attribute is allowed as default memory space.
498f3bf5c05SVladislav Vinogradov if (!memorySpace)
499f3bf5c05SVladislav Vinogradov return true;
500f3bf5c05SVladislav Vinogradov
501f3bf5c05SVladislav Vinogradov // Supported built-in attributes.
502f3bf5c05SVladislav Vinogradov if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
503f3bf5c05SVladislav Vinogradov return true;
504f3bf5c05SVladislav Vinogradov
505f3bf5c05SVladislav Vinogradov // Allow custom dialect attributes.
506ee090870SMehdi Amini if (!isa<BuiltinDialect>(memorySpace.getDialect()))
507f3bf5c05SVladislav Vinogradov return true;
508f3bf5c05SVladislav Vinogradov
509f3bf5c05SVladislav Vinogradov return false;
510f3bf5c05SVladislav Vinogradov }
511f3bf5c05SVladislav Vinogradov
wrapIntegerMemorySpace(unsigned memorySpace,MLIRContext * ctx)512f3bf5c05SVladislav Vinogradov Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
513f3bf5c05SVladislav Vinogradov MLIRContext *ctx) {
514f3bf5c05SVladislav Vinogradov if (memorySpace == 0)
515f3bf5c05SVladislav Vinogradov return nullptr;
516f3bf5c05SVladislav Vinogradov
517f3bf5c05SVladislav Vinogradov return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
518f3bf5c05SVladislav Vinogradov }
519f3bf5c05SVladislav Vinogradov
skipDefaultMemorySpace(Attribute memorySpace)520f3bf5c05SVladislav Vinogradov Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
521f3bf5c05SVladislav Vinogradov IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
522f3bf5c05SVladislav Vinogradov if (intMemorySpace && intMemorySpace.getValue() == 0)
523f3bf5c05SVladislav Vinogradov return nullptr;
524f3bf5c05SVladislav Vinogradov
525f3bf5c05SVladislav Vinogradov return memorySpace;
526f3bf5c05SVladislav Vinogradov }
527f3bf5c05SVladislav Vinogradov
getMemorySpaceAsInt(Attribute memorySpace)528f3bf5c05SVladislav Vinogradov unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
529f3bf5c05SVladislav Vinogradov if (!memorySpace)
530f3bf5c05SVladislav Vinogradov return 0;
531f3bf5c05SVladislav Vinogradov
532f3bf5c05SVladislav Vinogradov assert(memorySpace.isa<IntegerAttr>() &&
533f3bf5c05SVladislav Vinogradov "Using `getMemorySpaceInteger` with non-Integer attribute");
534f3bf5c05SVladislav Vinogradov
535f3bf5c05SVladislav Vinogradov return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
536f3bf5c05SVladislav Vinogradov }
537f3bf5c05SVladislav Vinogradov
538f3bf5c05SVladislav Vinogradov MemRefType::Builder &
setMemorySpace(unsigned newMemorySpace)539f3bf5c05SVladislav Vinogradov MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
540f3bf5c05SVladislav Vinogradov memorySpace =
541f3bf5c05SVladislav Vinogradov wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
542f3bf5c05SVladislav Vinogradov return *this;
543f3bf5c05SVladislav Vinogradov }
544f3bf5c05SVladislav Vinogradov
getMemorySpaceAsInt() const545f3bf5c05SVladislav Vinogradov unsigned MemRefType::getMemorySpaceAsInt() const {
546f3bf5c05SVladislav Vinogradov return detail::getMemorySpaceAsInt(getMemorySpace());
547f3bf5c05SVladislav Vinogradov }
548f3bf5c05SVladislav Vinogradov
get(ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)549e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
550e41ebbecSVladislav Vinogradov MemRefLayoutAttrInterface layout,
551e41ebbecSVladislav Vinogradov Attribute memorySpace) {
552e41ebbecSVladislav Vinogradov // Use default layout for empty attribute.
553e41ebbecSVladislav Vinogradov if (!layout)
554e41ebbecSVladislav Vinogradov layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
555e41ebbecSVladislav Vinogradov shape.size(), elementType.getContext()));
556e41ebbecSVladislav Vinogradov
557e41ebbecSVladislav Vinogradov // Drop default memory space value and replace it with empty attribute.
558e41ebbecSVladislav Vinogradov memorySpace = skipDefaultMemorySpace(memorySpace);
559e41ebbecSVladislav Vinogradov
560e41ebbecSVladislav Vinogradov return Base::get(elementType.getContext(), shape, elementType, layout,
561e41ebbecSVladislav Vinogradov memorySpace);
562e41ebbecSVladislav Vinogradov }
563e41ebbecSVladislav Vinogradov
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)564e41ebbecSVladislav Vinogradov MemRefType MemRefType::getChecked(
565e41ebbecSVladislav Vinogradov function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
566e41ebbecSVladislav Vinogradov Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
567e41ebbecSVladislav Vinogradov
568e41ebbecSVladislav Vinogradov // Use default layout for empty attribute.
569e41ebbecSVladislav Vinogradov if (!layout)
570e41ebbecSVladislav Vinogradov layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
571e41ebbecSVladislav Vinogradov shape.size(), elementType.getContext()));
572e41ebbecSVladislav Vinogradov
573e41ebbecSVladislav Vinogradov // Drop default memory space value and replace it with empty attribute.
574e41ebbecSVladislav Vinogradov memorySpace = skipDefaultMemorySpace(memorySpace);
575e41ebbecSVladislav Vinogradov
576e41ebbecSVladislav Vinogradov return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
577e41ebbecSVladislav Vinogradov elementType, layout, memorySpace);
578e41ebbecSVladislav Vinogradov }
579e41ebbecSVladislav Vinogradov
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)580e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
581e41ebbecSVladislav Vinogradov AffineMap map, Attribute memorySpace) {
582e41ebbecSVladislav Vinogradov
583e41ebbecSVladislav Vinogradov // Use default layout for empty map.
584e41ebbecSVladislav Vinogradov if (!map)
585e41ebbecSVladislav Vinogradov map = AffineMap::getMultiDimIdentityMap(shape.size(),
586e41ebbecSVladislav Vinogradov elementType.getContext());
587e41ebbecSVladislav Vinogradov
588e41ebbecSVladislav Vinogradov // Wrap AffineMap into Attribute.
589e41ebbecSVladislav Vinogradov Attribute layout = AffineMapAttr::get(map);
590e41ebbecSVladislav Vinogradov
591e41ebbecSVladislav Vinogradov // Drop default memory space value and replace it with empty attribute.
592e41ebbecSVladislav Vinogradov memorySpace = skipDefaultMemorySpace(memorySpace);
593e41ebbecSVladislav Vinogradov
594e41ebbecSVladislav Vinogradov return Base::get(elementType.getContext(), shape, elementType, layout,
595e41ebbecSVladislav Vinogradov memorySpace);
596e41ebbecSVladislav Vinogradov }
597e41ebbecSVladislav Vinogradov
598e41ebbecSVladislav Vinogradov MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,Attribute memorySpace)599e41ebbecSVladislav Vinogradov MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
600e41ebbecSVladislav Vinogradov ArrayRef<int64_t> shape, Type elementType, AffineMap map,
601e41ebbecSVladislav Vinogradov Attribute memorySpace) {
602e41ebbecSVladislav Vinogradov
603e41ebbecSVladislav Vinogradov // Use default layout for empty map.
604e41ebbecSVladislav Vinogradov if (!map)
605e41ebbecSVladislav Vinogradov map = AffineMap::getMultiDimIdentityMap(shape.size(),
606e41ebbecSVladislav Vinogradov elementType.getContext());
607e41ebbecSVladislav Vinogradov
608e41ebbecSVladislav Vinogradov // Wrap AffineMap into Attribute.
609e41ebbecSVladislav Vinogradov Attribute layout = AffineMapAttr::get(map);
610e41ebbecSVladislav Vinogradov
611e41ebbecSVladislav Vinogradov // Drop default memory space value and replace it with empty attribute.
612e41ebbecSVladislav Vinogradov memorySpace = skipDefaultMemorySpace(memorySpace);
613e41ebbecSVladislav Vinogradov
614e41ebbecSVladislav Vinogradov return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
615e41ebbecSVladislav Vinogradov elementType, layout, memorySpace);
616e41ebbecSVladislav Vinogradov }
617e41ebbecSVladislav Vinogradov
get(ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)618e41ebbecSVladislav Vinogradov MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
619e41ebbecSVladislav Vinogradov AffineMap map, unsigned memorySpaceInd) {
620e41ebbecSVladislav Vinogradov
621e41ebbecSVladislav Vinogradov // Use default layout for empty map.
622e41ebbecSVladislav Vinogradov if (!map)
623e41ebbecSVladislav Vinogradov map = AffineMap::getMultiDimIdentityMap(shape.size(),
624e41ebbecSVladislav Vinogradov elementType.getContext());
625e41ebbecSVladislav Vinogradov
626e41ebbecSVladislav Vinogradov // Wrap AffineMap into Attribute.
627e41ebbecSVladislav Vinogradov Attribute layout = AffineMapAttr::get(map);
628e41ebbecSVladislav Vinogradov
629e41ebbecSVladislav Vinogradov // Convert deprecated integer-like memory space to Attribute.
630e41ebbecSVladislav Vinogradov Attribute memorySpace =
631e41ebbecSVladislav Vinogradov wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
632e41ebbecSVladislav Vinogradov
633e41ebbecSVladislav Vinogradov return Base::get(elementType.getContext(), shape, elementType, layout,
634e41ebbecSVladislav Vinogradov memorySpace);
635e41ebbecSVladislav Vinogradov }
636e41ebbecSVladislav Vinogradov
637e41ebbecSVladislav Vinogradov MemRefType
getChecked(function_ref<InFlightDiagnostic ()> emitErrorFn,ArrayRef<int64_t> shape,Type elementType,AffineMap map,unsigned memorySpaceInd)638e41ebbecSVladislav Vinogradov MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
639e41ebbecSVladislav Vinogradov ArrayRef<int64_t> shape, Type elementType, AffineMap map,
640e41ebbecSVladislav Vinogradov unsigned memorySpaceInd) {
641e41ebbecSVladislav Vinogradov
642e41ebbecSVladislav Vinogradov // Use default layout for empty map.
643e41ebbecSVladislav Vinogradov if (!map)
644e41ebbecSVladislav Vinogradov map = AffineMap::getMultiDimIdentityMap(shape.size(),
645e41ebbecSVladislav Vinogradov elementType.getContext());
646e41ebbecSVladislav Vinogradov
647e41ebbecSVladislav Vinogradov // Wrap AffineMap into Attribute.
648e41ebbecSVladislav Vinogradov Attribute layout = AffineMapAttr::get(map);
649e41ebbecSVladislav Vinogradov
650e41ebbecSVladislav Vinogradov // Convert deprecated integer-like memory space to Attribute.
651e41ebbecSVladislav Vinogradov Attribute memorySpace =
652e41ebbecSVladislav Vinogradov wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
653e41ebbecSVladislav Vinogradov
654e41ebbecSVladislav Vinogradov return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
655e41ebbecSVladislav Vinogradov elementType, layout, memorySpace);
656e41ebbecSVladislav Vinogradov }
657e41ebbecSVladislav Vinogradov
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<int64_t> shape,Type elementType,MemRefLayoutAttrInterface layout,Attribute memorySpace)6580d01dfbcSRiver Riddle LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
65906e25d56SRiver Riddle ArrayRef<int64_t> shape, Type elementType,
660e41ebbecSVladislav Vinogradov MemRefLayoutAttrInterface layout,
661f3bf5c05SVladislav Vinogradov Attribute memorySpace) {
66209f7a55fSRiver Riddle if (!BaseMemRefType::isValidElementType(elementType))
6630d01dfbcSRiver Riddle return emitError() << "invalid memref element type";
66409f7a55fSRiver Riddle
66509f7a55fSRiver Riddle // Negative sizes are not allowed except for `-1` that means dynamic size.
6660d01dfbcSRiver Riddle for (int64_t s : shape)
66709f7a55fSRiver Riddle if (s < -1)
6680d01dfbcSRiver Riddle return emitError() << "invalid memref size";
66909f7a55fSRiver Riddle
670e41ebbecSVladislav Vinogradov assert(layout && "missing layout specification");
671e41ebbecSVladislav Vinogradov if (failed(layout.verifyLayout(shape, emitError)))
672e41ebbecSVladislav Vinogradov return failure();
673f3bf5c05SVladislav Vinogradov
674e41ebbecSVladislav Vinogradov if (!isSupportedMemorySpace(memorySpace))
675f3bf5c05SVladislav Vinogradov return emitError() << "unsupported memory space Attribute";
676f3bf5c05SVladislav Vinogradov
6770d01dfbcSRiver Riddle return success();
67809f7a55fSRiver Riddle }
67909f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const680c42dd5dbSRiver Riddle void MemRefType::walkImmediateSubElements(
681c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
682c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
683c42dd5dbSRiver Riddle walkTypesFn(getElementType());
684e41ebbecSVladislav Vinogradov if (!getLayout().isIdentity())
685e41ebbecSVladislav Vinogradov walkAttrsFn(getLayout());
686c42dd5dbSRiver Riddle walkAttrsFn(getMemorySpace());
687c42dd5dbSRiver Riddle }
688c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const689*01eedbc7SRiver Riddle Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
690*01eedbc7SRiver Riddle ArrayRef<Type> replTypes) const {
691*01eedbc7SRiver Riddle bool hasLayout = replAttrs.size() > 1;
692*01eedbc7SRiver Riddle return get(getShape(), replTypes[0],
693*01eedbc7SRiver Riddle hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
694*01eedbc7SRiver Riddle : MemRefLayoutAttrInterface(),
695*01eedbc7SRiver Riddle hasLayout ? replAttrs[1] : replAttrs[0]);
696*01eedbc7SRiver Riddle }
697*01eedbc7SRiver Riddle
69809f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
69909f7a55fSRiver Riddle // UnrankedMemRefType
70009f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
70109f7a55fSRiver Riddle
getMemorySpaceAsInt() const702f3bf5c05SVladislav Vinogradov unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
703f3bf5c05SVladislav Vinogradov return detail::getMemorySpaceAsInt(getMemorySpace());
704f3bf5c05SVladislav Vinogradov }
705f3bf5c05SVladislav Vinogradov
70609f7a55fSRiver Riddle LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,Attribute memorySpace)70706e25d56SRiver Riddle UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
708f3bf5c05SVladislav Vinogradov Type elementType, Attribute memorySpace) {
70909f7a55fSRiver Riddle if (!BaseMemRefType::isValidElementType(elementType))
71006e25d56SRiver Riddle return emitError() << "invalid memref element type";
711f3bf5c05SVladislav Vinogradov
712f3bf5c05SVladislav Vinogradov if (!isSupportedMemorySpace(memorySpace))
713f3bf5c05SVladislav Vinogradov return emitError() << "unsupported memory space Attribute";
714f3bf5c05SVladislav Vinogradov
71509f7a55fSRiver Riddle return success();
71609f7a55fSRiver Riddle }
71709f7a55fSRiver Riddle
71809f7a55fSRiver Riddle // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
71909f7a55fSRiver Riddle // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)72009f7a55fSRiver Riddle static void extractStridesFromTerm(AffineExpr e,
72109f7a55fSRiver Riddle AffineExpr multiplicativeFactor,
72209f7a55fSRiver Riddle MutableArrayRef<AffineExpr> strides,
72309f7a55fSRiver Riddle AffineExpr &offset) {
72409f7a55fSRiver Riddle if (auto dim = e.dyn_cast<AffineDimExpr>())
72509f7a55fSRiver Riddle strides[dim.getPosition()] =
72609f7a55fSRiver Riddle strides[dim.getPosition()] + multiplicativeFactor;
72709f7a55fSRiver Riddle else
72809f7a55fSRiver Riddle offset = offset + e * multiplicativeFactor;
72909f7a55fSRiver Riddle }
73009f7a55fSRiver Riddle
73109f7a55fSRiver Riddle /// Takes a single AffineExpr `e` and populates the `strides` array with the
73209f7a55fSRiver Riddle /// strides expressions for each dim position.
73309f7a55fSRiver Riddle /// The convention is that the strides for dimensions d0, .. dn appear in
73409f7a55fSRiver Riddle /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)73509f7a55fSRiver Riddle static LogicalResult extractStrides(AffineExpr e,
73609f7a55fSRiver Riddle AffineExpr multiplicativeFactor,
73709f7a55fSRiver Riddle MutableArrayRef<AffineExpr> strides,
73809f7a55fSRiver Riddle AffineExpr &offset) {
73909f7a55fSRiver Riddle auto bin = e.dyn_cast<AffineBinaryOpExpr>();
74009f7a55fSRiver Riddle if (!bin) {
74109f7a55fSRiver Riddle extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
74209f7a55fSRiver Riddle return success();
74309f7a55fSRiver Riddle }
74409f7a55fSRiver Riddle
74509f7a55fSRiver Riddle if (bin.getKind() == AffineExprKind::CeilDiv ||
74609f7a55fSRiver Riddle bin.getKind() == AffineExprKind::FloorDiv ||
74709f7a55fSRiver Riddle bin.getKind() == AffineExprKind::Mod)
74809f7a55fSRiver Riddle return failure();
74909f7a55fSRiver Riddle
75009f7a55fSRiver Riddle if (bin.getKind() == AffineExprKind::Mul) {
75109f7a55fSRiver Riddle auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
75209f7a55fSRiver Riddle if (dim) {
75309f7a55fSRiver Riddle strides[dim.getPosition()] =
75409f7a55fSRiver Riddle strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
75509f7a55fSRiver Riddle return success();
75609f7a55fSRiver Riddle }
75709f7a55fSRiver Riddle // LHS and RHS may both contain complex expressions of dims. Try one path
75809f7a55fSRiver Riddle // and if it fails try the other. This is guaranteed to succeed because
75909f7a55fSRiver Riddle // only one path may have a `dim`, otherwise this is not an AffineExpr in
76009f7a55fSRiver Riddle // the first place.
76109f7a55fSRiver Riddle if (bin.getLHS().isSymbolicOrConstant())
76209f7a55fSRiver Riddle return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
76309f7a55fSRiver Riddle strides, offset);
76409f7a55fSRiver Riddle return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
76509f7a55fSRiver Riddle strides, offset);
76609f7a55fSRiver Riddle }
76709f7a55fSRiver Riddle
76809f7a55fSRiver Riddle if (bin.getKind() == AffineExprKind::Add) {
76909f7a55fSRiver Riddle auto res1 =
77009f7a55fSRiver Riddle extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
77109f7a55fSRiver Riddle auto res2 =
77209f7a55fSRiver Riddle extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
77309f7a55fSRiver Riddle return success(succeeded(res1) && succeeded(res2));
77409f7a55fSRiver Riddle }
77509f7a55fSRiver Riddle
77609f7a55fSRiver Riddle llvm_unreachable("unexpected binary operation");
77709f7a55fSRiver Riddle }
77809f7a55fSRiver Riddle
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)77909f7a55fSRiver Riddle LogicalResult mlir::getStridesAndOffset(MemRefType t,
78009f7a55fSRiver Riddle SmallVectorImpl<AffineExpr> &strides,
78109f7a55fSRiver Riddle AffineExpr &offset) {
782e41ebbecSVladislav Vinogradov AffineMap m = t.getLayout().getAffineMap();
78370b6f16eSVladislav Vinogradov
784e41ebbecSVladislav Vinogradov if (m.getNumResults() != 1 && !m.isIdentity())
785fab634b4SNicolas Vasilache return failure();
786fab634b4SNicolas Vasilache
78709f7a55fSRiver Riddle auto zero = getAffineConstantExpr(0, t.getContext());
78809f7a55fSRiver Riddle auto one = getAffineConstantExpr(1, t.getContext());
78909f7a55fSRiver Riddle offset = zero;
79009f7a55fSRiver Riddle strides.assign(t.getRank(), zero);
79109f7a55fSRiver Riddle
79209f7a55fSRiver Riddle // Canonical case for empty map.
793e41ebbecSVladislav Vinogradov if (m.isIdentity()) {
79409f7a55fSRiver Riddle // 0-D corner case, offset is already 0.
79509f7a55fSRiver Riddle if (t.getRank() == 0)
79609f7a55fSRiver Riddle return success();
79709f7a55fSRiver Riddle auto stridedExpr =
79809f7a55fSRiver Riddle makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
79909f7a55fSRiver Riddle if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
80009f7a55fSRiver Riddle return success();
80109f7a55fSRiver Riddle assert(false && "unexpected failure: extract strides in canonical layout");
80209f7a55fSRiver Riddle }
80309f7a55fSRiver Riddle
80409f7a55fSRiver Riddle // Non-canonical case requires more work.
80509f7a55fSRiver Riddle auto stridedExpr =
80609f7a55fSRiver Riddle simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
80709f7a55fSRiver Riddle if (failed(extractStrides(stridedExpr, one, strides, offset))) {
80809f7a55fSRiver Riddle offset = AffineExpr();
80909f7a55fSRiver Riddle strides.clear();
81009f7a55fSRiver Riddle return failure();
81109f7a55fSRiver Riddle }
81209f7a55fSRiver Riddle
81309f7a55fSRiver Riddle // Simplify results to allow folding to constants and simple checks.
81409f7a55fSRiver Riddle unsigned numDims = m.getNumDims();
81509f7a55fSRiver Riddle unsigned numSymbols = m.getNumSymbols();
81609f7a55fSRiver Riddle offset = simplifyAffineExpr(offset, numDims, numSymbols);
81709f7a55fSRiver Riddle for (auto &stride : strides)
81809f7a55fSRiver Riddle stride = simplifyAffineExpr(stride, numDims, numSymbols);
81909f7a55fSRiver Riddle
82009f7a55fSRiver Riddle /// In practice, a strided memref must be internally non-aliasing. Test
82109f7a55fSRiver Riddle /// against 0 as a proxy.
82209f7a55fSRiver Riddle /// TODO: static cases can have more advanced checks.
82309f7a55fSRiver Riddle /// TODO: dynamic cases would require a way to compare symbolic
82409f7a55fSRiver Riddle /// expressions and would probably need an affine set context propagated
82509f7a55fSRiver Riddle /// everywhere.
82609f7a55fSRiver Riddle if (llvm::any_of(strides, [](AffineExpr e) {
82709f7a55fSRiver Riddle return e == getAffineConstantExpr(0, e.getContext());
82809f7a55fSRiver Riddle })) {
82909f7a55fSRiver Riddle offset = AffineExpr();
83009f7a55fSRiver Riddle strides.clear();
83109f7a55fSRiver Riddle return failure();
83209f7a55fSRiver Riddle }
83309f7a55fSRiver Riddle
83409f7a55fSRiver Riddle return success();
83509f7a55fSRiver Riddle }
83609f7a55fSRiver Riddle
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)83709f7a55fSRiver Riddle LogicalResult mlir::getStridesAndOffset(MemRefType t,
83809f7a55fSRiver Riddle SmallVectorImpl<int64_t> &strides,
83909f7a55fSRiver Riddle int64_t &offset) {
84009f7a55fSRiver Riddle AffineExpr offsetExpr;
84109f7a55fSRiver Riddle SmallVector<AffineExpr, 4> strideExprs;
84209f7a55fSRiver Riddle if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
84309f7a55fSRiver Riddle return failure();
84409f7a55fSRiver Riddle if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
84509f7a55fSRiver Riddle offset = cst.getValue();
84609f7a55fSRiver Riddle else
84709f7a55fSRiver Riddle offset = ShapedType::kDynamicStrideOrOffset;
84809f7a55fSRiver Riddle for (auto e : strideExprs) {
84909f7a55fSRiver Riddle if (auto c = e.dyn_cast<AffineConstantExpr>())
85009f7a55fSRiver Riddle strides.push_back(c.getValue());
85109f7a55fSRiver Riddle else
85209f7a55fSRiver Riddle strides.push_back(ShapedType::kDynamicStrideOrOffset);
85309f7a55fSRiver Riddle }
85409f7a55fSRiver Riddle return success();
85509f7a55fSRiver Riddle }
85609f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const857c42dd5dbSRiver Riddle void UnrankedMemRefType::walkImmediateSubElements(
858c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
859c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
860c42dd5dbSRiver Riddle walkTypesFn(getElementType());
861c42dd5dbSRiver Riddle walkAttrsFn(getMemorySpace());
862c42dd5dbSRiver Riddle }
863c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const864*01eedbc7SRiver Riddle Type UnrankedMemRefType::replaceImmediateSubElements(
865*01eedbc7SRiver Riddle ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
866*01eedbc7SRiver Riddle return get(replTypes.front(), replAttrs.front());
867*01eedbc7SRiver Riddle }
868*01eedbc7SRiver Riddle
86909f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
87009f7a55fSRiver Riddle /// TupleType
87109f7a55fSRiver Riddle //===----------------------------------------------------------------------===//
87209f7a55fSRiver Riddle
87309f7a55fSRiver Riddle /// Return the elements types for this tuple.
getTypes() const87409f7a55fSRiver Riddle ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
87509f7a55fSRiver Riddle
87609f7a55fSRiver Riddle /// Accumulate the types contained in this tuple and tuples nested within it.
87709f7a55fSRiver Riddle /// Note that this only flattens nested tuples, not any other container type,
87809f7a55fSRiver Riddle /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
87909f7a55fSRiver Riddle /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)88009f7a55fSRiver Riddle void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
88109f7a55fSRiver Riddle for (Type type : getTypes()) {
88209f7a55fSRiver Riddle if (auto nestedTuple = type.dyn_cast<TupleType>())
88309f7a55fSRiver Riddle nestedTuple.getFlattenedTypes(types);
88409f7a55fSRiver Riddle else
88509f7a55fSRiver Riddle types.push_back(type);
88609f7a55fSRiver Riddle }
88709f7a55fSRiver Riddle }
88809f7a55fSRiver Riddle
88909f7a55fSRiver Riddle /// Return the number of element types.
size() const89009f7a55fSRiver Riddle size_t TupleType::size() const { return getImpl()->size(); }
89109f7a55fSRiver Riddle
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const892c42dd5dbSRiver Riddle void TupleType::walkImmediateSubElements(
893c42dd5dbSRiver Riddle function_ref<void(Attribute)> walkAttrsFn,
894c42dd5dbSRiver Riddle function_ref<void(Type)> walkTypesFn) const {
895c42dd5dbSRiver Riddle for (Type type : getTypes())
896c42dd5dbSRiver Riddle walkTypesFn(type);
897c42dd5dbSRiver Riddle }
898c42dd5dbSRiver Riddle
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const899*01eedbc7SRiver Riddle Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
900*01eedbc7SRiver Riddle ArrayRef<Type> replTypes) const {
901*01eedbc7SRiver Riddle return get(getContext(), replTypes);
902*01eedbc7SRiver Riddle }
903*01eedbc7SRiver Riddle
904d79642b3SRiver Riddle //===----------------------------------------------------------------------===//
905d79642b3SRiver Riddle // Type Utilities
906d79642b3SRiver Riddle //===----------------------------------------------------------------------===//
907d79642b3SRiver Riddle
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)90809f7a55fSRiver Riddle AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
90909f7a55fSRiver Riddle int64_t offset,
91009f7a55fSRiver Riddle MLIRContext *context) {
91109f7a55fSRiver Riddle AffineExpr expr;
91209f7a55fSRiver Riddle unsigned nSymbols = 0;
91309f7a55fSRiver Riddle
91409f7a55fSRiver Riddle // AffineExpr for offset.
91509f7a55fSRiver Riddle // Static case.
91609f7a55fSRiver Riddle if (offset != MemRefType::getDynamicStrideOrOffset()) {
91709f7a55fSRiver Riddle auto cst = getAffineConstantExpr(offset, context);
91809f7a55fSRiver Riddle expr = cst;
91909f7a55fSRiver Riddle } else {
92009f7a55fSRiver Riddle // Dynamic case, new symbol for the offset.
92109f7a55fSRiver Riddle auto sym = getAffineSymbolExpr(nSymbols++, context);
92209f7a55fSRiver Riddle expr = sym;
92309f7a55fSRiver Riddle }
92409f7a55fSRiver Riddle
92509f7a55fSRiver Riddle // AffineExpr for strides.
926e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(strides)) {
92709f7a55fSRiver Riddle auto dim = en.index();
92809f7a55fSRiver Riddle auto stride = en.value();
92909f7a55fSRiver Riddle assert(stride != 0 && "Invalid stride specification");
93009f7a55fSRiver Riddle auto d = getAffineDimExpr(dim, context);
93109f7a55fSRiver Riddle AffineExpr mult;
93209f7a55fSRiver Riddle // Static case.
93309f7a55fSRiver Riddle if (stride != MemRefType::getDynamicStrideOrOffset())
93409f7a55fSRiver Riddle mult = getAffineConstantExpr(stride, context);
93509f7a55fSRiver Riddle else
93609f7a55fSRiver Riddle // Dynamic case, new symbol for each new stride.
93709f7a55fSRiver Riddle mult = getAffineSymbolExpr(nSymbols++, context);
93809f7a55fSRiver Riddle expr = expr + d * mult;
93909f7a55fSRiver Riddle }
94009f7a55fSRiver Riddle
94109f7a55fSRiver Riddle return AffineMap::get(strides.size(), nSymbols, expr);
94209f7a55fSRiver Riddle }
94309f7a55fSRiver Riddle
94409f7a55fSRiver Riddle /// Return a version of `t` with identity layout if it can be determined
94509f7a55fSRiver Riddle /// statically that the layout is the canonical contiguous strided layout.
94609f7a55fSRiver Riddle /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
94709f7a55fSRiver Riddle /// `t` with simplified layout.
94809f7a55fSRiver Riddle /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)94909f7a55fSRiver Riddle MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
950e41ebbecSVladislav Vinogradov AffineMap m = t.getLayout().getAffineMap();
951e41ebbecSVladislav Vinogradov
95209f7a55fSRiver Riddle // Already in canonical form.
953e41ebbecSVladislav Vinogradov if (m.isIdentity())
95409f7a55fSRiver Riddle return t;
95509f7a55fSRiver Riddle
95609f7a55fSRiver Riddle // Can't reduce to canonical identity form, return in canonical form.
957e41ebbecSVladislav Vinogradov if (m.getNumResults() > 1)
95809f7a55fSRiver Riddle return t;
95909f7a55fSRiver Riddle
9607e6fe5c4SNicolas Vasilache // Corner-case for 0-D affine maps.
9617e6fe5c4SNicolas Vasilache if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
9627e6fe5c4SNicolas Vasilache if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
9637e6fe5c4SNicolas Vasilache if (cst.getValue() == 0)
964e41ebbecSVladislav Vinogradov return MemRefType::Builder(t).setLayout({});
9657e6fe5c4SNicolas Vasilache return t;
9667e6fe5c4SNicolas Vasilache }
9677e6fe5c4SNicolas Vasilache
968f4ac9f03SNicolas Vasilache // 0-D corner case for empty shape that still have an affine map. Example:
969f4ac9f03SNicolas Vasilache // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
970f4ac9f03SNicolas Vasilache // offset needs to remain, just return t.
971f4ac9f03SNicolas Vasilache if (t.getShape().empty())
972f4ac9f03SNicolas Vasilache return t;
973f4ac9f03SNicolas Vasilache
97409f7a55fSRiver Riddle // If the canonical strided layout for the sizes of `t` is equal to the
97509f7a55fSRiver Riddle // simplified layout of `t` we can just return an empty layout. Otherwise,
97609f7a55fSRiver Riddle // just simplify the existing layout.
97709f7a55fSRiver Riddle AffineExpr expr =
97809f7a55fSRiver Riddle makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
97909f7a55fSRiver Riddle auto simplifiedLayoutExpr =
98009f7a55fSRiver Riddle simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
98109f7a55fSRiver Riddle if (expr != simplifiedLayoutExpr)
982e41ebbecSVladislav Vinogradov return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
983e41ebbecSVladislav Vinogradov m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
984e41ebbecSVladislav Vinogradov return MemRefType::Builder(t).setLayout({});
98509f7a55fSRiver Riddle }
98609f7a55fSRiver Riddle
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)98709f7a55fSRiver Riddle AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
98809f7a55fSRiver Riddle ArrayRef<AffineExpr> exprs,
98909f7a55fSRiver Riddle MLIRContext *context) {
99009f7a55fSRiver Riddle // Size 0 corner case is useful for canonicalizations.
991dccb7331SBenjamin Kramer if (sizes.empty() || llvm::is_contained(sizes, 0))
99209f7a55fSRiver Riddle return getAffineConstantExpr(0, context);
99309f7a55fSRiver Riddle
994dccb7331SBenjamin Kramer assert(!exprs.empty() && "expected exprs");
99509f7a55fSRiver Riddle auto maps = AffineMap::inferFromExprList(exprs);
99609f7a55fSRiver Riddle assert(!maps.empty() && "Expected one non-empty map");
99709f7a55fSRiver Riddle unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
99809f7a55fSRiver Riddle
99909f7a55fSRiver Riddle AffineExpr expr;
100009f7a55fSRiver Riddle bool dynamicPoisonBit = false;
100109f7a55fSRiver Riddle int64_t runningSize = 1;
100209f7a55fSRiver Riddle for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
100309f7a55fSRiver Riddle int64_t size = std::get<1>(en);
100409f7a55fSRiver Riddle // Degenerate case, no size =-> no stride
100509f7a55fSRiver Riddle if (size == 0)
100609f7a55fSRiver Riddle continue;
100709f7a55fSRiver Riddle AffineExpr dimExpr = std::get<0>(en);
100809f7a55fSRiver Riddle AffineExpr stride = dynamicPoisonBit
100909f7a55fSRiver Riddle ? getAffineSymbolExpr(nSymbols++, context)
101009f7a55fSRiver Riddle : getAffineConstantExpr(runningSize, context);
101109f7a55fSRiver Riddle expr = expr ? expr + dimExpr * stride : dimExpr * stride;
10125f022ad6SAart Bik if (size > 0) {
101309f7a55fSRiver Riddle runningSize *= size;
10145f022ad6SAart Bik assert(runningSize > 0 && "integer overflow in size computation");
10155f022ad6SAart Bik } else {
101609f7a55fSRiver Riddle dynamicPoisonBit = true;
101709f7a55fSRiver Riddle }
10185f022ad6SAart Bik }
101909f7a55fSRiver Riddle return simplifyAffineExpr(expr, numDims, nSymbols);
102009f7a55fSRiver Riddle }
102109f7a55fSRiver Riddle
102209f7a55fSRiver Riddle /// Return a version of `t` with a layout that has all dynamic offset and
102309f7a55fSRiver Riddle /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)102409f7a55fSRiver Riddle MemRefType mlir::eraseStridedLayout(MemRefType t) {
102509f7a55fSRiver Riddle auto val = ShapedType::kDynamicStrideOrOffset;
1026e41ebbecSVladislav Vinogradov return MemRefType::Builder(t).setLayout(
1027e41ebbecSVladislav Vinogradov AffineMapAttr::get(makeStridedLinearLayoutMap(
1028e41ebbecSVladislav Vinogradov SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
102909f7a55fSRiver Riddle }
103009f7a55fSRiver Riddle
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)103109f7a55fSRiver Riddle AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
103209f7a55fSRiver Riddle MLIRContext *context) {
103309f7a55fSRiver Riddle SmallVector<AffineExpr, 4> exprs;
103409f7a55fSRiver Riddle exprs.reserve(sizes.size());
103509f7a55fSRiver Riddle for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
103609f7a55fSRiver Riddle exprs.push_back(getAffineDimExpr(dim, context));
103709f7a55fSRiver Riddle return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
103809f7a55fSRiver Riddle }
103909f7a55fSRiver Riddle
104009f7a55fSRiver Riddle /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)104109f7a55fSRiver Riddle bool mlir::isStrided(MemRefType t) {
104209f7a55fSRiver Riddle int64_t offset;
10435bc4f884SNicolas Vasilache SmallVector<int64_t, 4> strides;
10445bc4f884SNicolas Vasilache auto res = getStridesAndOffset(t, strides, offset);
104509f7a55fSRiver Riddle return succeeded(res);
104609f7a55fSRiver Riddle }
10475bc4f884SNicolas Vasilache
10485bc4f884SNicolas Vasilache /// Return the layout map in strided linear layout AffineMap form.
10495bc4f884SNicolas Vasilache /// Return null if the layout is not compatible with a strided layout.
getStridedLinearLayoutMap(MemRefType t)10505bc4f884SNicolas Vasilache AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
10515bc4f884SNicolas Vasilache int64_t offset;
10525bc4f884SNicolas Vasilache SmallVector<int64_t, 4> strides;
10535bc4f884SNicolas Vasilache if (failed(getStridesAndOffset(t, strides, offset)))
10545bc4f884SNicolas Vasilache return AffineMap();
10555bc4f884SNicolas Vasilache return makeStridedLinearLayoutMap(strides, offset, t.getContext());
10565bc4f884SNicolas Vasilache }
1057aba437ceSBenoit Jacob
1058aba437ceSBenoit Jacob /// Return the AffineExpr representation of the offset, assuming `memRefType`
1059aba437ceSBenoit Jacob /// is a strided memref.
getOffsetExpr(MemRefType memrefType)1060aba437ceSBenoit Jacob static AffineExpr getOffsetExpr(MemRefType memrefType) {
1061aba437ceSBenoit Jacob SmallVector<AffineExpr> strides;
1062aba437ceSBenoit Jacob AffineExpr offset;
1063aba437ceSBenoit Jacob if (failed(getStridesAndOffset(memrefType, strides, offset)))
1064aba437ceSBenoit Jacob assert(false && "expected strided memref");
1065aba437ceSBenoit Jacob return offset;
1066aba437ceSBenoit Jacob }
1067aba437ceSBenoit Jacob
1068aba437ceSBenoit Jacob /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1069aba437ceSBenoit Jacob /// `offset` AffineExpr.
makeContiguousRowMajorMemRefType(MLIRContext * context,ArrayRef<int64_t> shape,Type elementType,AffineExpr offset)1070aba437ceSBenoit Jacob static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1071aba437ceSBenoit Jacob ArrayRef<int64_t> shape,
1072aba437ceSBenoit Jacob Type elementType,
1073aba437ceSBenoit Jacob AffineExpr offset) {
1074aba437ceSBenoit Jacob AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1075aba437ceSBenoit Jacob AffineExpr contiguousRowMajor = canonical + offset;
1076aba437ceSBenoit Jacob AffineMap contiguousRowMajorMap =
1077aba437ceSBenoit Jacob AffineMap::inferFromExprList({contiguousRowMajor})[0];
1078aba437ceSBenoit Jacob return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1079aba437ceSBenoit Jacob }
1080aba437ceSBenoit Jacob
1081aba437ceSBenoit Jacob /// Helper determining if a memref is static-shape and contiguous-row-major
1082aba437ceSBenoit Jacob /// layout, while still allowing for an arbitrary offset (any static or
1083aba437ceSBenoit Jacob /// dynamic value).
isStaticShapeAndContiguousRowMajor(MemRefType memrefType)1084aba437ceSBenoit Jacob bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1085aba437ceSBenoit Jacob if (!memrefType.hasStaticShape())
1086aba437ceSBenoit Jacob return false;
1087aba437ceSBenoit Jacob AffineExpr offset = getOffsetExpr(memrefType);
1088aba437ceSBenoit Jacob MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1089aba437ceSBenoit Jacob memrefType.getContext(), memrefType.getShape(),
1090aba437ceSBenoit Jacob memrefType.getElementType(), offset);
1091aba437ceSBenoit Jacob return canonicalizeStridedLayout(memrefType) ==
1092aba437ceSBenoit Jacob canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1093aba437ceSBenoit Jacob }
1094