16b436eacSAlex Zinenko //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
26b436eacSAlex Zinenko //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66b436eacSAlex Zinenko //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
86b436eacSAlex Zinenko //
96b436eacSAlex Zinenko // This file defines generic type utilities.
106b436eacSAlex Zinenko //
116b436eacSAlex Zinenko //===----------------------------------------------------------------------===//
126b436eacSAlex Zinenko 
136b436eacSAlex Zinenko #include "mlir/IR/TypeUtilities.h"
1425a20b8aSTres Popp 
1525a20b8aSTres Popp #include <numeric>
1625a20b8aSTres Popp 
176b436eacSAlex Zinenko #include "mlir/IR/Attributes.h"
1809f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
196b436eacSAlex Zinenko #include "mlir/IR/Types.h"
206b436eacSAlex Zinenko #include "mlir/IR/Value.h"
216b436eacSAlex Zinenko 
226b436eacSAlex Zinenko using namespace mlir;
236b436eacSAlex Zinenko 
getElementTypeOrSelf(Type type)246b436eacSAlex Zinenko Type mlir::getElementTypeOrSelf(Type type) {
256b436eacSAlex Zinenko   if (auto st = type.dyn_cast<ShapedType>())
266b436eacSAlex Zinenko     return st.getElementType();
276b436eacSAlex Zinenko   return type;
286b436eacSAlex Zinenko }
296b436eacSAlex Zinenko 
getElementTypeOrSelf(Value val)30e62a6956SRiver Riddle Type mlir::getElementTypeOrSelf(Value val) {
312bdf33ccSRiver Riddle   return getElementTypeOrSelf(val.getType());
326b436eacSAlex Zinenko }
336b436eacSAlex Zinenko 
getElementTypeOrSelf(Attribute attr)346b436eacSAlex Zinenko Type mlir::getElementTypeOrSelf(Attribute attr) {
356b436eacSAlex Zinenko   return getElementTypeOrSelf(attr.getType());
366b436eacSAlex Zinenko }
376b436eacSAlex Zinenko 
getFlattenedTypes(TupleType t)386b436eacSAlex Zinenko SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
396b436eacSAlex Zinenko   SmallVector<Type, 10> fTypes;
406b436eacSAlex Zinenko   t.getFlattenedTypes(fTypes);
416b436eacSAlex Zinenko   return fTypes;
426b436eacSAlex Zinenko }
436b436eacSAlex Zinenko 
44a477fbafSChris Lattner /// Return true if the specified type is an opaque type with the specified
45a477fbafSChris Lattner /// dialect and typeData.
isOpaqueTypeWithName(Type type,StringRef dialect,StringRef typeData)46a477fbafSChris Lattner bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
47a477fbafSChris Lattner                                 StringRef typeData) {
48a477fbafSChris Lattner   if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
4974e6a5b2SChris Lattner     return opaque.getDialectNamespace() == dialect &&
50a477fbafSChris Lattner            opaque.getTypeData() == typeData;
51a477fbafSChris Lattner   return false;
52a477fbafSChris Lattner }
53a477fbafSChris Lattner 
542d22b1e0SSmit Hinsu /// Returns success if the given two shapes are compatible. That is, they have
552d22b1e0SSmit Hinsu /// the same size and each pair of the elements are equal or one of them is
562d22b1e0SSmit Hinsu /// dynamic.
verifyCompatibleShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)572d22b1e0SSmit Hinsu LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
582d22b1e0SSmit Hinsu                                           ArrayRef<int64_t> shape2) {
592d22b1e0SSmit Hinsu   if (shape1.size() != shape2.size())
602d22b1e0SSmit Hinsu     return failure();
61eeef50b1SFangrui Song   for (auto dims : llvm::zip(shape1, shape2)) {
622d22b1e0SSmit Hinsu     int64_t dim1 = std::get<0>(dims);
632d22b1e0SSmit Hinsu     int64_t dim2 = std::get<1>(dims);
642d22b1e0SSmit Hinsu     if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
652d22b1e0SSmit Hinsu         dim1 != dim2)
662d22b1e0SSmit Hinsu       return failure();
672d22b1e0SSmit Hinsu   }
682d22b1e0SSmit Hinsu   return success();
692d22b1e0SSmit Hinsu }
702d22b1e0SSmit Hinsu 
7185b46314SSmit Hinsu /// Returns success if the given two types have compatible shape. That is,
7285b46314SSmit Hinsu /// they are both scalars (not shaped), or they are both shaped types and at
7385b46314SSmit Hinsu /// least one is unranked or they have compatible dimensions. Dimensions are
7485b46314SSmit Hinsu /// compatible if at least one is dynamic or both are equal. The element type
7585b46314SSmit Hinsu /// does not matter.
verifyCompatibleShape(Type type1,Type type2)7685b46314SSmit Hinsu LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
7785b46314SSmit Hinsu   auto sType1 = type1.dyn_cast<ShapedType>();
7885b46314SSmit Hinsu   auto sType2 = type2.dyn_cast<ShapedType>();
7985b46314SSmit Hinsu 
8085b46314SSmit Hinsu   // Either both or neither type should be shaped.
8185b46314SSmit Hinsu   if (!sType1)
8285b46314SSmit Hinsu     return success(!sType2);
8385b46314SSmit Hinsu   if (!sType2)
8485b46314SSmit Hinsu     return failure();
8585b46314SSmit Hinsu 
8685b46314SSmit Hinsu   if (!sType1.hasRank() || !sType2.hasRank())
8785b46314SSmit Hinsu     return success();
8885b46314SSmit Hinsu 
892d22b1e0SSmit Hinsu   return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
9085b46314SSmit Hinsu }
9185b46314SSmit Hinsu 
92acaf85f7SJacques Pienaar /// Returns success if the given two arrays have the same number of elements and
93acaf85f7SJacques Pienaar /// each pair wise entries have compatible shape.
verifyCompatibleShapes(TypeRange types1,TypeRange types2)943dfa8614SRiver Riddle LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
95acaf85f7SJacques Pienaar   if (types1.size() != types2.size())
96acaf85f7SJacques Pienaar     return failure();
973dfa8614SRiver Riddle   for (auto it : llvm::zip_first(types1, types2))
98acaf85f7SJacques Pienaar     if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
99acaf85f7SJacques Pienaar       return failure();
100acaf85f7SJacques Pienaar   return success();
101acaf85f7SJacques Pienaar }
102acaf85f7SJacques Pienaar 
verifyCompatibleDims(ArrayRef<int64_t> dims)10325a20b8aSTres Popp LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
10425a20b8aSTres Popp   if (dims.empty())
10525a20b8aSTres Popp     return success();
10625a20b8aSTres Popp   auto staticDim = std::accumulate(
10725a20b8aSTres Popp       dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
10825a20b8aSTres Popp         return ShapedType::isDynamic(dim) ? fold : dim;
10925a20b8aSTres Popp       });
11025a20b8aSTres Popp   return success(llvm::all_of(dims, [&](auto dim) {
11125a20b8aSTres Popp     return ShapedType::isDynamic(dim) || dim == staticDim;
11225a20b8aSTres Popp   }));
11325a20b8aSTres Popp }
11425a20b8aSTres Popp 
11525a20b8aSTres Popp /// Returns success if all given types have compatible shapes. That is, they are
11625a20b8aSTres Popp /// all scalars (not shaped), or they are all shaped types and any ranked shapes
11725a20b8aSTres Popp /// have compatible dimensions. Dimensions are compatible if all non-dynamic
11825a20b8aSTres Popp /// dims are equal. The element type does not matter.
verifyCompatibleShapes(TypeRange types)11925a20b8aSTres Popp LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
12025a20b8aSTres Popp   auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
12125a20b8aSTres Popp       types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
12225a20b8aSTres Popp   // Return failure if some, but not all are not shaped. Return early if none
12325a20b8aSTres Popp   // are shaped also.
12425a20b8aSTres Popp   if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
12525a20b8aSTres Popp     return success();
12625a20b8aSTres Popp   if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
12725a20b8aSTres Popp     return failure();
12825a20b8aSTres Popp 
129*71705f53SJavier Setoain   // Return failure if some, but not all, are scalable vectors.
130*71705f53SJavier Setoain   bool hasScalableVecTypes = false;
131*71705f53SJavier Setoain   bool hasNonScalableVecTypes = false;
132*71705f53SJavier Setoain   for (Type t : types) {
133*71705f53SJavier Setoain     auto vType = t.dyn_cast<VectorType>();
134*71705f53SJavier Setoain     if (vType && vType.isScalable())
135*71705f53SJavier Setoain       hasScalableVecTypes = true;
136*71705f53SJavier Setoain     else
137*71705f53SJavier Setoain       hasNonScalableVecTypes = true;
138*71705f53SJavier Setoain     if (hasScalableVecTypes && hasNonScalableVecTypes)
139*71705f53SJavier Setoain       return failure();
140*71705f53SJavier Setoain   }
141*71705f53SJavier Setoain 
14225a20b8aSTres Popp   // Remove all unranked shapes
14325a20b8aSTres Popp   auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
14425a20b8aSTres Popp       shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
14525a20b8aSTres Popp   if (shapes.empty())
14625a20b8aSTres Popp     return success();
14725a20b8aSTres Popp 
14825a20b8aSTres Popp   // All ranks should be equal
14925a20b8aSTres Popp   auto firstRank = shapes.front().getRank();
15025a20b8aSTres Popp   if (llvm::any_of(shapes,
15125a20b8aSTres Popp                    [&](auto shape) { return firstRank != shape.getRank(); }))
15225a20b8aSTres Popp     return failure();
15325a20b8aSTres Popp 
15425a20b8aSTres Popp   for (unsigned i = 0; i < firstRank; ++i) {
15525a20b8aSTres Popp     // Retrieve all ranked dimensions
15625a20b8aSTres Popp     auto dims = llvm::to_vector<8>(llvm::map_range(
15725a20b8aSTres Popp         llvm::make_filter_range(
15825a20b8aSTres Popp             shapes, [&](auto shape) { return shape.getRank() >= i; }),
15925a20b8aSTres Popp         [&](auto shape) { return shape.getDimSize(i); }));
16025a20b8aSTres Popp     if (verifyCompatibleDims(dims).failed())
16125a20b8aSTres Popp       return failure();
16225a20b8aSTres Popp   }
16325a20b8aSTres Popp 
16425a20b8aSTres Popp   return success();
16525a20b8aSTres Popp }
16625a20b8aSTres Popp 
mapElement(Value value) const1676de6131fSRiver Riddle Type OperandElementTypeIterator::mapElement(Value value) const {
1682bdf33ccSRiver Riddle   return value.getType().cast<ShapedType>().getElementType();
1696b436eacSAlex Zinenko }
1706b436eacSAlex Zinenko 
mapElement(Value value) const1716de6131fSRiver Riddle Type ResultElementTypeIterator::mapElement(Value value) const {
1722bdf33ccSRiver Riddle   return value.getType().cast<ShapedType>().getElementType();
1736b436eacSAlex Zinenko }
174