1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file defines generic type utilities.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/IR/Value.h"
27 
28 using namespace mlir;
29 
30 Type mlir::getElementTypeOrSelf(Type type) {
31   if (auto st = type.dyn_cast<ShapedType>())
32     return st.getElementType();
33   return type;
34 }
35 
36 Type mlir::getElementTypeOrSelf(Value *val) {
37   return getElementTypeOrSelf(val->getType());
38 }
39 
40 Type mlir::getElementTypeOrSelf(Value &val) {
41   return getElementTypeOrSelf(val.getType());
42 }
43 
44 Type mlir::getElementTypeOrSelf(Attribute attr) {
45   return getElementTypeOrSelf(attr.getType());
46 }
47 
48 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
49   SmallVector<Type, 10> fTypes;
50   t.getFlattenedTypes(fTypes);
51   return fTypes;
52 }
53 
54 /// Return true if the specified type is an opaque type with the specified
55 /// dialect and typeData.
56 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
57                                 StringRef typeData) {
58   if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
59     return opaque.getDialectNamespace().is(dialect) &&
60            opaque.getTypeData() == typeData;
61   return false;
62 }
63 
64 /// Returns success if the given two types have compatible shape. That is,
65 /// they are both scalars (not shaped), or they are both shaped types and at
66 /// least one is unranked or they have compatible dimensions. Dimensions are
67 /// compatible if at least one is dynamic or both are equal. The element type
68 /// does not matter.
69 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
70   auto sType1 = type1.dyn_cast<ShapedType>();
71   auto sType2 = type2.dyn_cast<ShapedType>();
72 
73   // Either both or neither type should be shaped.
74   if (!sType1)
75     return success(!sType2);
76   if (!sType2)
77     return failure();
78 
79   if (!sType1.hasRank() || !sType2.hasRank())
80     return success();
81 
82   if (sType1.getRank() != sType2.getRank())
83     return failure();
84 
85   for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) {
86     int64_t dim1 = std::get<0>(dims);
87     int64_t dim2 = std::get<1>(dims);
88     if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
89         dim1 != dim2)
90       return failure();
91   }
92   return success();
93 }
94 
95 OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
96     : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
97 
98 Type OperandElementTypeIterator::unwrap(Value *value) {
99   return value->getType().cast<ShapedType>().getElementType();
100 }
101 
102 ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
103     : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
104 
105 Type ResultElementTypeIterator::unwrap(Value *value) {
106   return value->getType().cast<ShapedType>().getElementType();
107 }
108