1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #include <numeric>
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/IR/Value.h"
21 
22 using namespace mlir;
23 
getElementTypeOrSelf(Type type)24 Type mlir::getElementTypeOrSelf(Type type) {
25   if (auto st = type.dyn_cast<ShapedType>())
26     return st.getElementType();
27   return type;
28 }
29 
getElementTypeOrSelf(Value val)30 Type mlir::getElementTypeOrSelf(Value val) {
31   return getElementTypeOrSelf(val.getType());
32 }
33 
getElementTypeOrSelf(Attribute attr)34 Type mlir::getElementTypeOrSelf(Attribute attr) {
35   return getElementTypeOrSelf(attr.getType());
36 }
37 
getFlattenedTypes(TupleType t)38 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
39   SmallVector<Type, 10> fTypes;
40   t.getFlattenedTypes(fTypes);
41   return fTypes;
42 }
43 
44 /// Return true if the specified type is an opaque type with the specified
45 /// dialect and typeData.
isOpaqueTypeWithName(Type type,StringRef dialect,StringRef typeData)46 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
47                                 StringRef typeData) {
48   if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
49     return opaque.getDialectNamespace() == dialect &&
50            opaque.getTypeData() == typeData;
51   return false;
52 }
53 
54 /// Returns success if the given two shapes are compatible. That is, they have
55 /// the same size and each pair of the elements are equal or one of them is
56 /// dynamic.
verifyCompatibleShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)57 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
58                                           ArrayRef<int64_t> shape2) {
59   if (shape1.size() != shape2.size())
60     return failure();
61   for (auto dims : llvm::zip(shape1, shape2)) {
62     int64_t dim1 = std::get<0>(dims);
63     int64_t dim2 = std::get<1>(dims);
64     if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
65         dim1 != dim2)
66       return failure();
67   }
68   return success();
69 }
70 
71 /// Returns success if the given two types have compatible shape. That is,
72 /// they are both scalars (not shaped), or they are both shaped types and at
73 /// least one is unranked or they have compatible dimensions. Dimensions are
74 /// compatible if at least one is dynamic or both are equal. The element type
75 /// does not matter.
verifyCompatibleShape(Type type1,Type type2)76 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
77   auto sType1 = type1.dyn_cast<ShapedType>();
78   auto sType2 = type2.dyn_cast<ShapedType>();
79 
80   // Either both or neither type should be shaped.
81   if (!sType1)
82     return success(!sType2);
83   if (!sType2)
84     return failure();
85 
86   if (!sType1.hasRank() || !sType2.hasRank())
87     return success();
88 
89   return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
90 }
91 
92 /// Returns success if the given two arrays have the same number of elements and
93 /// each pair wise entries have compatible shape.
verifyCompatibleShapes(TypeRange types1,TypeRange types2)94 LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
95   if (types1.size() != types2.size())
96     return failure();
97   for (auto it : llvm::zip_first(types1, types2))
98     if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
99       return failure();
100   return success();
101 }
102 
verifyCompatibleDims(ArrayRef<int64_t> dims)103 LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
104   if (dims.empty())
105     return success();
106   auto staticDim = std::accumulate(
107       dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
108         return ShapedType::isDynamic(dim) ? fold : dim;
109       });
110   return success(llvm::all_of(dims, [&](auto dim) {
111     return ShapedType::isDynamic(dim) || dim == staticDim;
112   }));
113 }
114 
115 /// Returns success if all given types have compatible shapes. That is, they are
116 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
117 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
118 /// dims are equal. The element type does not matter.
verifyCompatibleShapes(TypeRange types)119 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
120   auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
121       types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
122   // Return failure if some, but not all are not shaped. Return early if none
123   // are shaped also.
124   if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
125     return success();
126   if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
127     return failure();
128 
129   // Return failure if some, but not all, are scalable vectors.
130   bool hasScalableVecTypes = false;
131   bool hasNonScalableVecTypes = false;
132   for (Type t : types) {
133     auto vType = t.dyn_cast<VectorType>();
134     if (vType && vType.isScalable())
135       hasScalableVecTypes = true;
136     else
137       hasNonScalableVecTypes = true;
138     if (hasScalableVecTypes && hasNonScalableVecTypes)
139       return failure();
140   }
141 
142   // Remove all unranked shapes
143   auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
144       shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
145   if (shapes.empty())
146     return success();
147 
148   // All ranks should be equal
149   auto firstRank = shapes.front().getRank();
150   if (llvm::any_of(shapes,
151                    [&](auto shape) { return firstRank != shape.getRank(); }))
152     return failure();
153 
154   for (unsigned i = 0; i < firstRank; ++i) {
155     // Retrieve all ranked dimensions
156     auto dims = llvm::to_vector<8>(llvm::map_range(
157         llvm::make_filter_range(
158             shapes, [&](auto shape) { return shape.getRank() >= i; }),
159         [&](auto shape) { return shape.getDimSize(i); }));
160     if (verifyCompatibleDims(dims).failed())
161       return failure();
162   }
163 
164   return success();
165 }
166 
mapElement(Value value) const167 Type OperandElementTypeIterator::mapElement(Value value) const {
168   return value.getType().cast<ShapedType>().getElementType();
169 }
170 
mapElement(Value value) const171 Type ResultElementTypeIterator::mapElement(Value value) const {
172   return value.getType().cast<ShapedType>().getElementType();
173 }
174