//===- TypeUtilities.cpp - Helper functions for type queries --------------===//
26b436eacSAlex Zinenko //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66b436eacSAlex Zinenko //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
86b436eacSAlex Zinenko //
96b436eacSAlex Zinenko // This file defines generic type utilities.
106b436eacSAlex Zinenko //
116b436eacSAlex Zinenko //===----------------------------------------------------------------------===//
126b436eacSAlex Zinenko
136b436eacSAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1425a20b8aSTres Popp
1525a20b8aSTres Popp #include <numeric>
1625a20b8aSTres Popp
176b436eacSAlex Zinenko #include "mlir/IR/Attributes.h"
1809f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
196b436eacSAlex Zinenko #include "mlir/IR/Types.h"
206b436eacSAlex Zinenko #include "mlir/IR/Value.h"
216b436eacSAlex Zinenko
226b436eacSAlex Zinenko using namespace mlir;
236b436eacSAlex Zinenko
getElementTypeOrSelf(Type type)246b436eacSAlex Zinenko Type mlir::getElementTypeOrSelf(Type type) {
256b436eacSAlex Zinenko if (auto st = type.dyn_cast<ShapedType>())
266b436eacSAlex Zinenko return st.getElementType();
276b436eacSAlex Zinenko return type;
286b436eacSAlex Zinenko }
296b436eacSAlex Zinenko
getElementTypeOrSelf(Value val)30e62a6956SRiver Riddle Type mlir::getElementTypeOrSelf(Value val) {
312bdf33ccSRiver Riddle return getElementTypeOrSelf(val.getType());
326b436eacSAlex Zinenko }
336b436eacSAlex Zinenko
getElementTypeOrSelf(Attribute attr)346b436eacSAlex Zinenko Type mlir::getElementTypeOrSelf(Attribute attr) {
356b436eacSAlex Zinenko return getElementTypeOrSelf(attr.getType());
366b436eacSAlex Zinenko }
376b436eacSAlex Zinenko
getFlattenedTypes(TupleType t)386b436eacSAlex Zinenko SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
396b436eacSAlex Zinenko SmallVector<Type, 10> fTypes;
406b436eacSAlex Zinenko t.getFlattenedTypes(fTypes);
416b436eacSAlex Zinenko return fTypes;
426b436eacSAlex Zinenko }
436b436eacSAlex Zinenko
44a477fbafSChris Lattner /// Return true if the specified type is an opaque type with the specified
45a477fbafSChris Lattner /// dialect and typeData.
isOpaqueTypeWithName(Type type,StringRef dialect,StringRef typeData)46a477fbafSChris Lattner bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
47a477fbafSChris Lattner StringRef typeData) {
48a477fbafSChris Lattner if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
4974e6a5b2SChris Lattner return opaque.getDialectNamespace() == dialect &&
50a477fbafSChris Lattner opaque.getTypeData() == typeData;
51a477fbafSChris Lattner return false;
52a477fbafSChris Lattner }
53a477fbafSChris Lattner
542d22b1e0SSmit Hinsu /// Returns success if the given two shapes are compatible. That is, they have
552d22b1e0SSmit Hinsu /// the same size and each pair of the elements are equal or one of them is
562d22b1e0SSmit Hinsu /// dynamic.
verifyCompatibleShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)572d22b1e0SSmit Hinsu LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
582d22b1e0SSmit Hinsu ArrayRef<int64_t> shape2) {
592d22b1e0SSmit Hinsu if (shape1.size() != shape2.size())
602d22b1e0SSmit Hinsu return failure();
61eeef50b1SFangrui Song for (auto dims : llvm::zip(shape1, shape2)) {
622d22b1e0SSmit Hinsu int64_t dim1 = std::get<0>(dims);
632d22b1e0SSmit Hinsu int64_t dim2 = std::get<1>(dims);
642d22b1e0SSmit Hinsu if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
652d22b1e0SSmit Hinsu dim1 != dim2)
662d22b1e0SSmit Hinsu return failure();
672d22b1e0SSmit Hinsu }
682d22b1e0SSmit Hinsu return success();
692d22b1e0SSmit Hinsu }
702d22b1e0SSmit Hinsu
7185b46314SSmit Hinsu /// Returns success if the given two types have compatible shape. That is,
7285b46314SSmit Hinsu /// they are both scalars (not shaped), or they are both shaped types and at
7385b46314SSmit Hinsu /// least one is unranked or they have compatible dimensions. Dimensions are
7485b46314SSmit Hinsu /// compatible if at least one is dynamic or both are equal. The element type
7585b46314SSmit Hinsu /// does not matter.
verifyCompatibleShape(Type type1,Type type2)7685b46314SSmit Hinsu LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
7785b46314SSmit Hinsu auto sType1 = type1.dyn_cast<ShapedType>();
7885b46314SSmit Hinsu auto sType2 = type2.dyn_cast<ShapedType>();
7985b46314SSmit Hinsu
8085b46314SSmit Hinsu // Either both or neither type should be shaped.
8185b46314SSmit Hinsu if (!sType1)
8285b46314SSmit Hinsu return success(!sType2);
8385b46314SSmit Hinsu if (!sType2)
8485b46314SSmit Hinsu return failure();
8585b46314SSmit Hinsu
8685b46314SSmit Hinsu if (!sType1.hasRank() || !sType2.hasRank())
8785b46314SSmit Hinsu return success();
8885b46314SSmit Hinsu
892d22b1e0SSmit Hinsu return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
9085b46314SSmit Hinsu }
9185b46314SSmit Hinsu
92acaf85f7SJacques Pienaar /// Returns success if the given two arrays have the same number of elements and
93acaf85f7SJacques Pienaar /// each pair wise entries have compatible shape.
verifyCompatibleShapes(TypeRange types1,TypeRange types2)943dfa8614SRiver Riddle LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
95acaf85f7SJacques Pienaar if (types1.size() != types2.size())
96acaf85f7SJacques Pienaar return failure();
973dfa8614SRiver Riddle for (auto it : llvm::zip_first(types1, types2))
98acaf85f7SJacques Pienaar if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
99acaf85f7SJacques Pienaar return failure();
100acaf85f7SJacques Pienaar return success();
101acaf85f7SJacques Pienaar }
102acaf85f7SJacques Pienaar
verifyCompatibleDims(ArrayRef<int64_t> dims)10325a20b8aSTres Popp LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
10425a20b8aSTres Popp if (dims.empty())
10525a20b8aSTres Popp return success();
10625a20b8aSTres Popp auto staticDim = std::accumulate(
10725a20b8aSTres Popp dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
10825a20b8aSTres Popp return ShapedType::isDynamic(dim) ? fold : dim;
10925a20b8aSTres Popp });
11025a20b8aSTres Popp return success(llvm::all_of(dims, [&](auto dim) {
11125a20b8aSTres Popp return ShapedType::isDynamic(dim) || dim == staticDim;
11225a20b8aSTres Popp }));
11325a20b8aSTres Popp }
11425a20b8aSTres Popp
11525a20b8aSTres Popp /// Returns success if all given types have compatible shapes. That is, they are
11625a20b8aSTres Popp /// all scalars (not shaped), or they are all shaped types and any ranked shapes
11725a20b8aSTres Popp /// have compatible dimensions. Dimensions are compatible if all non-dynamic
11825a20b8aSTres Popp /// dims are equal. The element type does not matter.
verifyCompatibleShapes(TypeRange types)11925a20b8aSTres Popp LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
12025a20b8aSTres Popp auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
12125a20b8aSTres Popp types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
12225a20b8aSTres Popp // Return failure if some, but not all are not shaped. Return early if none
12325a20b8aSTres Popp // are shaped also.
12425a20b8aSTres Popp if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
12525a20b8aSTres Popp return success();
12625a20b8aSTres Popp if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
12725a20b8aSTres Popp return failure();
12825a20b8aSTres Popp
129*71705f53SJavier Setoain // Return failure if some, but not all, are scalable vectors.
130*71705f53SJavier Setoain bool hasScalableVecTypes = false;
131*71705f53SJavier Setoain bool hasNonScalableVecTypes = false;
132*71705f53SJavier Setoain for (Type t : types) {
133*71705f53SJavier Setoain auto vType = t.dyn_cast<VectorType>();
134*71705f53SJavier Setoain if (vType && vType.isScalable())
135*71705f53SJavier Setoain hasScalableVecTypes = true;
136*71705f53SJavier Setoain else
137*71705f53SJavier Setoain hasNonScalableVecTypes = true;
138*71705f53SJavier Setoain if (hasScalableVecTypes && hasNonScalableVecTypes)
139*71705f53SJavier Setoain return failure();
140*71705f53SJavier Setoain }
141*71705f53SJavier Setoain
14225a20b8aSTres Popp // Remove all unranked shapes
14325a20b8aSTres Popp auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
14425a20b8aSTres Popp shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
14525a20b8aSTres Popp if (shapes.empty())
14625a20b8aSTres Popp return success();
14725a20b8aSTres Popp
14825a20b8aSTres Popp // All ranks should be equal
14925a20b8aSTres Popp auto firstRank = shapes.front().getRank();
15025a20b8aSTres Popp if (llvm::any_of(shapes,
15125a20b8aSTres Popp [&](auto shape) { return firstRank != shape.getRank(); }))
15225a20b8aSTres Popp return failure();
15325a20b8aSTres Popp
15425a20b8aSTres Popp for (unsigned i = 0; i < firstRank; ++i) {
15525a20b8aSTres Popp // Retrieve all ranked dimensions
15625a20b8aSTres Popp auto dims = llvm::to_vector<8>(llvm::map_range(
15725a20b8aSTres Popp llvm::make_filter_range(
15825a20b8aSTres Popp shapes, [&](auto shape) { return shape.getRank() >= i; }),
15925a20b8aSTres Popp [&](auto shape) { return shape.getDimSize(i); }));
16025a20b8aSTres Popp if (verifyCompatibleDims(dims).failed())
16125a20b8aSTres Popp return failure();
16225a20b8aSTres Popp }
16325a20b8aSTres Popp
16425a20b8aSTres Popp return success();
16525a20b8aSTres Popp }
16625a20b8aSTres Popp
mapElement(Value value) const1676de6131fSRiver Riddle Type OperandElementTypeIterator::mapElement(Value value) const {
1682bdf33ccSRiver Riddle return value.getType().cast<ShapedType>().getElementType();
1696b436eacSAlex Zinenko }
1706b436eacSAlex Zinenko
mapElement(Value value) const1716de6131fSRiver Riddle Type ResultElementTypeIterator::mapElement(Value value) const {
1722bdf33ccSRiver Riddle return value.getType().cast<ShapedType>().getElementType();
1736b436eacSAlex Zinenko }
174