1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines generic type utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/IR/TypeUtilities.h" 14 15 #include <numeric> 16 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Types.h" 20 #include "mlir/IR/Value.h" 21 22 using namespace mlir; 23 24 Type mlir::getElementTypeOrSelf(Type type) { 25 if (auto st = type.dyn_cast<ShapedType>()) 26 return st.getElementType(); 27 return type; 28 } 29 30 Type mlir::getElementTypeOrSelf(Value val) { 31 return getElementTypeOrSelf(val.getType()); 32 } 33 34 Type mlir::getElementTypeOrSelf(Attribute attr) { 35 return getElementTypeOrSelf(attr.getType()); 36 } 37 38 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { 39 SmallVector<Type, 10> fTypes; 40 t.getFlattenedTypes(fTypes); 41 return fTypes; 42 } 43 44 /// Return true if the specified type is an opaque type with the specified 45 /// dialect and typeData. 46 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, 47 StringRef typeData) { 48 if (auto opaque = type.dyn_cast<mlir::OpaqueType>()) 49 return opaque.getDialectNamespace() == dialect && 50 opaque.getTypeData() == typeData; 51 return false; 52 } 53 54 /// Returns success if the given two shapes are compatible. That is, they have 55 /// the same size and each pair of the elements are equal or one of them is 56 /// dynamic. 57 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, 58 ArrayRef<int64_t> shape2) { 59 if (shape1.size() != shape2.size()) 60 return failure(); 61 for (auto dims : llvm::zip(shape1, shape2)) { 62 int64_t dim1 = std::get<0>(dims); 63 int64_t dim2 = std::get<1>(dims); 64 if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && 65 dim1 != dim2) 66 return failure(); 67 } 68 return success(); 69 } 70 71 /// Returns success if the given two types have compatible shape. That is, 72 /// they are both scalars (not shaped), or they are both shaped types and at 73 /// least one is unranked or they have compatible dimensions. Dimensions are 74 /// compatible if at least one is dynamic or both are equal. The element type 75 /// does not matter. 76 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { 77 auto sType1 = type1.dyn_cast<ShapedType>(); 78 auto sType2 = type2.dyn_cast<ShapedType>(); 79 80 // Either both or neither type should be shaped. 81 if (!sType1) 82 return success(!sType2); 83 if (!sType2) 84 return failure(); 85 86 if (!sType1.hasRank() || !sType2.hasRank()) 87 return success(); 88 89 return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); 90 } 91 92 /// Returns success if the given two arrays have the same number of elements and 93 /// each pair wise entries have compatible shape. 94 LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { 95 if (types1.size() != types2.size()) 96 return failure(); 97 for (auto it : llvm::zip_first(types1, types2)) 98 if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it)))) 99 return failure(); 100 return success(); 101 } 102 103 LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { 104 if (dims.empty()) 105 return success(); 106 auto staticDim = std::accumulate( 107 dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { 108 return ShapedType::isDynamic(dim) ? fold : dim; 109 }); 110 return success(llvm::all_of(dims, [&](auto dim) { 111 return ShapedType::isDynamic(dim) || dim == staticDim; 112 })); 113 } 114 115 /// Returns success if all given types have compatible shapes. That is, they are 116 /// all scalars (not shaped), or they are all shaped types and any ranked shapes 117 /// have compatible dimensions. Dimensions are compatible if all non-dynamic 118 /// dims are equal. The element type does not matter. 119 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { 120 auto shapedTypes = llvm::to_vector<8>(llvm::map_range( 121 types, [](auto type) { return type.template dyn_cast<ShapedType>(); })); 122 // Return failure if some, but not all are not shaped. Return early if none 123 // are shaped also. 124 if (llvm::none_of(shapedTypes, [](auto t) { return t; })) 125 return success(); 126 if (!llvm::all_of(shapedTypes, [](auto t) { return t; })) 127 return failure(); 128 129 // Remove all unranked shapes 130 auto shapes = llvm::to_vector<8>(llvm::make_filter_range( 131 shapedTypes, [](auto shapedType) { return shapedType.hasRank(); })); 132 if (shapes.empty()) 133 return success(); 134 135 // All ranks should be equal 136 auto firstRank = shapes.front().getRank(); 137 if (llvm::any_of(shapes, 138 [&](auto shape) { return firstRank != shape.getRank(); })) 139 return failure(); 140 141 for (unsigned i = 0; i < firstRank; ++i) { 142 // Retrieve all ranked dimensions 143 auto dims = llvm::to_vector<8>(llvm::map_range( 144 llvm::make_filter_range( 145 shapes, [&](auto shape) { return shape.getRank() >= i; }), 146 [&](auto shape) { return shape.getDimSize(i); })); 147 if (verifyCompatibleDims(dims).failed()) 148 return failure(); 149 } 150 151 return success(); 152 } 153 154 OperandElementTypeIterator::OperandElementTypeIterator( 155 Operation::operand_iterator it) 156 : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>( 157 it, &unwrap) {} 158 159 Type OperandElementTypeIterator::unwrap(Value value) { 160 return value.getType().cast<ShapedType>().getElementType(); 161 } 162 163 ResultElementTypeIterator::ResultElementTypeIterator( 164 Operation::result_iterator it) 165 : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>( 166 it, &unwrap) {} 167 168 Type ResultElementTypeIterator::unwrap(Value value) { 169 return value.getType().cast<ShapedType>().getElementType(); 170 } 171