1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines generic type utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/IR/TypeUtilities.h" 14 #include "mlir/IR/Attributes.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "mlir/IR/Types.h" 17 #include "mlir/IR/Value.h" 18 19 using namespace mlir; 20 21 Type mlir::getElementTypeOrSelf(Type type) { 22 if (auto st = type.dyn_cast<ShapedType>()) 23 return st.getElementType(); 24 return type; 25 } 26 27 Type mlir::getElementTypeOrSelf(Value val) { 28 return getElementTypeOrSelf(val.getType()); 29 } 30 31 Type mlir::getElementTypeOrSelf(Attribute attr) { 32 return getElementTypeOrSelf(attr.getType()); 33 } 34 35 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { 36 SmallVector<Type, 10> fTypes; 37 t.getFlattenedTypes(fTypes); 38 return fTypes; 39 } 40 41 /// Return true if the specified type is an opaque type with the specified 42 /// dialect and typeData. 43 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, 44 StringRef typeData) { 45 if (auto opaque = type.dyn_cast<mlir::OpaqueType>()) 46 return opaque.getDialectNamespace() == dialect && 47 opaque.getTypeData() == typeData; 48 return false; 49 } 50 51 /// Returns success if the given two shapes are compatible. That is, they have 52 /// the same size and each pair of the elements are equal or one of them is 53 /// dynamic. 54 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, 55 ArrayRef<int64_t> shape2) { 56 if (shape1.size() != shape2.size()) 57 return failure(); 58 for (auto dims : llvm::zip(shape1, shape2)) { 59 int64_t dim1 = std::get<0>(dims); 60 int64_t dim2 = std::get<1>(dims); 61 if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && 62 dim1 != dim2) 63 return failure(); 64 } 65 return success(); 66 } 67 68 /// Returns success if the given two types have compatible shape. That is, 69 /// they are both scalars (not shaped), or they are both shaped types and at 70 /// least one is unranked or they have compatible dimensions. Dimensions are 71 /// compatible if at least one is dynamic or both are equal. The element type 72 /// does not matter. 73 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { 74 auto sType1 = type1.dyn_cast<ShapedType>(); 75 auto sType2 = type2.dyn_cast<ShapedType>(); 76 77 // Either both or neither type should be shaped. 78 if (!sType1) 79 return success(!sType2); 80 if (!sType2) 81 return failure(); 82 83 if (!sType1.hasRank() || !sType2.hasRank()) 84 return success(); 85 86 return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); 87 } 88 89 OperandElementTypeIterator::OperandElementTypeIterator( 90 Operation::operand_iterator it) 91 : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>( 92 it, &unwrap) {} 93 94 Type OperandElementTypeIterator::unwrap(Value value) { 95 return value.getType().cast<ShapedType>().getElementType(); 96 } 97 98 ResultElementTypeIterator::ResultElementTypeIterator( 99 Operation::result_iterator it) 100 : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>( 101 it, &unwrap) {} 102 103 Type ResultElementTypeIterator::unwrap(Value value) { 104 return value.getType().cast<ShapedType>().getElementType(); 105 } 106