1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file defines generic type utilities.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/IR/Value.h"
27 
28 using namespace mlir;
29 
30 Type mlir::getElementTypeOrSelf(Type type) {
31   if (auto st = type.dyn_cast<ShapedType>())
32     return st.getElementType();
33   return type;
34 }
35 
36 Type mlir::getElementTypeOrSelf(Value *val) {
37   return getElementTypeOrSelf(val->getType());
38 }
39 
40 Type mlir::getElementTypeOrSelf(Value &val) {
41   return getElementTypeOrSelf(val.getType());
42 }
43 
44 Type mlir::getElementTypeOrSelf(Attribute attr) {
45   return getElementTypeOrSelf(attr.getType());
46 }
47 
48 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
49   SmallVector<Type, 10> fTypes;
50   t.getFlattenedTypes(fTypes);
51   return fTypes;
52 }
53 
54 /// Return true if the specified type is an opaque type with the specified
55 /// dialect and typeData.
56 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
57                                 StringRef typeData) {
58   if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
59     return opaque.getDialectNamespace().is(dialect) &&
60            opaque.getTypeData() == typeData;
61   return false;
62 }
63 
64 OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
65     : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
66 
67 Type OperandElementTypeIterator::unwrap(Value *value) {
68   return value->getType().cast<ShapedType>().getElementType();
69 }
70 
71 ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
72     : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
73 
74 Type ResultElementTypeIterator::unwrap(Value *value) {
75   return value->getType().cast<ShapedType>().getElementType();
76 }
77