1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Traits.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 16 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, 17 ArrayRef<int64_t> shape2) { 18 // Two dimensions are compatible when 19 // 1. they are defined and equal, or 20 // 2. one of them is 1 21 return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)), 22 [](auto dimensions) { 23 auto dim1 = std::get<0>(dimensions); 24 auto dim2 = std::get<1>(dimensions); 25 if (dim1 == 1 || dim2 == 1) 26 return true; 27 if (dim1 == dim2 && !ShapedType::isDynamic(dim1)) 28 return true; 29 return false; 30 }); 31 } 32 33 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 34 ArrayRef<int64_t> shape2, 35 SmallVectorImpl<int64_t> &resultShape) { 36 // To compute the result broadcasted shape, we compare operand shapes 37 // element-wise: starting with the trailing dimensions, and working the 38 // way backward. Two dimensions are compatible when 39 // 1. they are equal, or 40 // 2. one of them is 1 41 // The result shape has the maximum among the two inputs at every 42 // dimension index. 43 44 resultShape.clear(); 45 if (shape1.size() > shape2.size()) { 46 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 47 } else { 48 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 49 } 50 51 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 52 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 53 auto iR = resultShape.rbegin(); 54 55 // Check each dimension is consistent. 56 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 57 if (*i1 == -1 || *i2 == -1) { 58 // One or both dimensions is unknown. Follow TensorFlow behavior: 59 // - If either dimension is greater than 1, we assume that the program is 60 // correct, and the other dimension will be broadcast to match it. 61 // - If either dimension is 1, the other dimension is the output. 62 if (*i1 > 1) { 63 *iR = *i1; 64 } else if (*i2 > 1) { 65 *iR = *i2; 66 } else if (*i1 == 1) { 67 *iR = *i2; 68 } else if (*i2 == 1) { 69 *iR = *i1; 70 } else { 71 *iR = -1; 72 } 73 } else { 74 if (*i1 == *i2 || *i2 == 1) { 75 *iR = *i1; 76 } else if (*i1 == 1) { 77 *iR = *i2; 78 } else { 79 // This dimension of the two operand types is incompatible. 80 resultShape.clear(); 81 return false; 82 } 83 } 84 } 85 86 return true; 87 } 88 89 /// Returns the shape of the given type. Scalars will be considered as having a 90 /// shape with zero dimensions. 91 static ArrayRef<int64_t> getShape(Type type) { 92 if (auto sType = type.dyn_cast<ShapedType>()) 93 return sType.getShape(); 94 return {}; 95 } 96 97 /// Returns the result broadcast composition type from the two given types by 98 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 99 /// either of the input types has dynamic shape. Returns null type if the two 100 /// given types are not broadcast-compatible. 101 /// 102 /// elementType, if specified, will be used as the element type of the 103 /// broadcasted result type. Otherwise it is required that the element type of 104 /// type1 and type2 is the same and this element type will be used as the 105 /// resultant element type. 106 Type OpTrait::util::getBroadcastedType(Type type1, Type type2, 107 Type elementType) { 108 // If the elementType is not specified, then the use the common element type 109 // of the inputs or fail if there is no common element type. 110 if (!elementType) { 111 elementType = getElementTypeOrSelf(type1); 112 if (elementType != getElementTypeOrSelf(type2)) 113 return {}; 114 } 115 116 // If one of the types is unranked tensor, then the other type shouldn't be 117 // vector and the result should have unranked tensor type. 118 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { 119 if (type1.isa<VectorType>() || type2.isa<VectorType>()) 120 return {}; 121 return UnrankedTensorType::get(elementType); 122 } 123 124 // Returns the type kind if the given type is a vector or ranked tensor type. 125 // Returns llvm::None otherwise. 126 auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> { 127 if (type.isa<VectorType, RankedTensorType>()) 128 return type.getTypeID(); 129 return llvm::None; 130 }; 131 132 // Make sure the composite type, if has, is consistent. 133 Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1); 134 Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2); 135 Optional<TypeID> resultCompositeKind; 136 137 if (compositeKind1 && compositeKind2) { 138 // Disallow mixing vector and tensor. 139 if (compositeKind1 != compositeKind2) 140 return {}; 141 resultCompositeKind = compositeKind1; 142 } else if (compositeKind1) { 143 resultCompositeKind = compositeKind1; 144 } else if (compositeKind2) { 145 resultCompositeKind = compositeKind2; 146 } 147 148 // Get the shape of each type. 149 SmallVector<int64_t, 4> resultShape; 150 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 151 return {}; 152 153 // Compose the final broadcasted type 154 if (resultCompositeKind == VectorType::getTypeID()) 155 return VectorType::get(resultShape, elementType); 156 if (resultCompositeKind == RankedTensorType::getTypeID()) 157 return RankedTensorType::get(resultShape, elementType); 158 return elementType; 159 } 160 161 /// Returns a tuple corresponding to whether range has tensor or vector type. 162 template <typename iterator_range> 163 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { 164 return std::make_tuple( 165 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }), 166 llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); })); 167 } 168 169 static bool areCompatibleShapes(ArrayRef<int64_t> shape1, 170 ArrayRef<int64_t> shape2) { 171 auto isCompatible = [](int64_t dim1, int64_t dim2) { 172 return dim1 == dim2 || dim1 == -1 || dim2 == -1; 173 }; 174 if (shape1.size() != shape2.size()) 175 return false; 176 for (auto p : llvm::zip(shape1, shape2)) 177 if (!isCompatible(std::get<0>(p), std::get<1>(p))) 178 return false; 179 return true; 180 } 181 182 static std::string getShapeString(ArrayRef<int64_t> shape) { 183 // TODO: should replace with printing shape more uniformly across here and 184 // when in type. 185 return std::string( 186 formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()))); 187 } 188 189 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 190 // Ensure broadcasting only tensor or only vector types. 191 auto operandsHasTensorVectorType = 192 hasTensorOrVectorType(op->getOperandTypes()); 193 auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes()); 194 if ((std::get<0>(operandsHasTensorVectorType) || 195 std::get<0>(resultsHasTensorVectorType)) && 196 (std::get<1>(operandsHasTensorVectorType) || 197 std::get<1>(resultsHasTensorVectorType))) 198 return op->emitError("cannot broadcast vector with tensor"); 199 200 auto rankedOperands = make_filter_range( 201 op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 202 203 // If all operands are unranked, then all result shapes are possible. 204 if (rankedOperands.empty()) 205 return success(); 206 207 // Compute broadcasted shape of operands (which requires that operands are 208 // broadcast compatible). The results need to be broadcast compatible with 209 // this result shape. 210 SmallVector<int64_t, 4> resultShape; 211 (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {}, 212 resultShape); 213 for (auto other : make_early_inc_range(rankedOperands)) { 214 SmallVector<int64_t, 4> temp = resultShape; 215 if (!util::getBroadcastedShape(temp, getShape(other), resultShape)) 216 return op->emitOpError("operands don't have broadcast-compatible shapes"); 217 } 218 219 auto rankedResults = make_filter_range( 220 op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 221 222 // If all of the results are unranked then no further verification. 223 if (rankedResults.empty()) 224 return success(); 225 226 for (auto type : rankedResults) { 227 ArrayRef<int64_t> actualSuffix = 228 getShape(type).take_back(resultShape.size()); 229 if (!areCompatibleShapes(actualSuffix, resultShape)) 230 return op->emitOpError() 231 << "result type " << getShapeString(getShape(type)) 232 << " not broadcast compatible with broadcasted operands's shapes " 233 << getShapeString(resultShape); 234 } 235 return success(); 236 } 237