1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Traits.h" 10 #include "mlir/IR/StandardTypes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 16 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 17 ArrayRef<int64_t> shape2, 18 SmallVectorImpl<int64_t> &resultShape) { 19 // To compute the result broadcasted shape, we compare operand shapes 20 // element-wise: starting with the trailing dimensions, and working the 21 // way backward. Two dimensions are compatible when 22 // 1. they are equal, or 23 // 2. one of them is 1 24 // The result shape has the maximum among the two inputs at every 25 // dimension index. 26 27 resultShape.clear(); 28 if (shape1.size() > shape2.size()) { 29 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 30 } else { 31 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 32 } 33 34 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 35 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 36 auto iR = resultShape.rbegin(); 37 38 // Check each dimension is consistent. 39 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 40 if (*i1 == -1 || *i2 == -1) { 41 // One or both dimensions is unknown. Follow TensorFlow behavior: 42 // - If either dimension is greater than 1, we assume that the program is 43 // correct, and the other dimension will be broadcast to match it. 44 // - If either dimension is 1, the other dimension is the output. 45 if (*i1 > 1) { 46 *iR = *i1; 47 } else if (*i2 > 1) { 48 *iR = *i2; 49 } else if (*i1 == 1) { 50 *iR = *i2; 51 } else if (*i2 == 1) { 52 *iR = *i1; 53 } else { 54 *iR = -1; 55 } 56 } else { 57 if (*i1 == *i2 || *i2 == 1) { 58 *iR = *i1; 59 } else if (*i1 == 1) { 60 *iR = *i2; 61 } else { 62 // This dimension of the two operand types is incompatible. 63 resultShape.clear(); 64 return false; 65 } 66 } 67 } 68 69 return true; 70 } 71 72 /// Returns the shape of the given type. Scalars will be considered as having a 73 /// shape with zero dimensions. 74 static ArrayRef<int64_t> getShape(Type type) { 75 if (auto sType = type.dyn_cast<ShapedType>()) 76 return sType.getShape(); 77 return {}; 78 } 79 80 /// Returns the result broadcast composition type from the two given types by 81 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 82 /// either of the input types has dynamic shape. Returns null type if the two 83 /// given types are not broadcast-compatible. 84 /// 85 /// elementType, if specified, will be used as the element type of the 86 /// broadcasted result type. Otherwise it is required that the element type of 87 /// type1 and type2 is the same and this element type will be used as the 88 /// resultant element type. 89 Type OpTrait::util::getBroadcastedType(Type type1, Type type2, 90 Type elementType) { 91 // If the elementType is not specified, then the use the common element type 92 // of the inputs or fail if there is no common element type. 93 if (!elementType) { 94 elementType = getElementTypeOrSelf(type1); 95 if (elementType != getElementTypeOrSelf(type2)) 96 return {}; 97 } 98 99 // If one of the types is unranked tensor, then the other type shouldn't be 100 // vector and the result should have unranked tensor type. 101 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { 102 if (type1.isa<VectorType>() || type2.isa<VectorType>()) 103 return {}; 104 return UnrankedTensorType::get(elementType); 105 } 106 107 // Returns the type kind if the given type is a vector or ranked tensor type. 108 // Returns llvm::None otherwise. 109 auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> { 110 if (type.isa<VectorType>() || type.isa<RankedTensorType>()) 111 return static_cast<StandardTypes::Kind>(type.getKind()); 112 return llvm::None; 113 }; 114 115 // Make sure the composite type, if has, is consistent. 116 auto compositeKind1 = getCompositeTypeKind(type1); 117 auto compositeKind2 = getCompositeTypeKind(type2); 118 Optional<StandardTypes::Kind> resultCompositeKind; 119 120 if (compositeKind1 && compositeKind2) { 121 // Disallow mixing vector and tensor. 122 if (compositeKind1 != compositeKind2) 123 return {}; 124 resultCompositeKind = compositeKind1; 125 } else if (compositeKind1) { 126 resultCompositeKind = compositeKind1; 127 } else if (compositeKind2) { 128 resultCompositeKind = compositeKind2; 129 } 130 131 // Get the shape of each type. 132 SmallVector<int64_t, 4> resultShape; 133 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 134 return {}; 135 136 // Compose the final broadcasted type 137 if (resultCompositeKind == StandardTypes::Vector) 138 return VectorType::get(resultShape, elementType); 139 if (resultCompositeKind == StandardTypes::RankedTensor) 140 return RankedTensorType::get(resultShape, elementType); 141 return elementType; 142 } 143 144 /// Returns a tuple corresponding to whether range has tensor or vector type. 145 template <typename iterator_range> 146 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { 147 return std::make_tuple( 148 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }), 149 llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); })); 150 } 151 152 static bool areCompatibleShapes(ArrayRef<int64_t> shape1, 153 ArrayRef<int64_t> shape2) { 154 auto isCompatible = [](int64_t dim1, int64_t dim2) { 155 return dim1 == dim2 || dim1 == -1 || dim2 == -1; 156 }; 157 if (shape1.size() != shape2.size()) 158 return false; 159 for (auto p : llvm::zip(shape1, shape2)) 160 if (!isCompatible(std::get<0>(p), std::get<1>(p))) 161 return false; 162 return true; 163 } 164 165 static std::string getShapeString(ArrayRef<int64_t> shape) { 166 // TODO: should replace with printing shape more uniformly across here and 167 // when in type. 168 return std::string( 169 formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()))); 170 } 171 172 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 173 // Ensure broadcasting only tensor or only vector types. 174 auto operandsHasTensorVectorType = 175 hasTensorOrVectorType(op->getOperandTypes()); 176 auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes()); 177 if ((std::get<0>(operandsHasTensorVectorType) || 178 std::get<0>(resultsHasTensorVectorType)) && 179 (std::get<1>(operandsHasTensorVectorType) || 180 std::get<1>(resultsHasTensorVectorType))) 181 return op->emitError("cannot broadcast vector with tensor"); 182 183 auto rankedOperands = make_filter_range( 184 op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 185 186 // If all operands are unranked, then all result shapes are possible. 187 if (rankedOperands.empty()) 188 return success(); 189 190 // Compute broadcasted shape of operands (which requires that operands are 191 // broadcast compatible). The results need to be broadcast compatible with 192 // this result shape. 193 SmallVector<int64_t, 4> resultShape; 194 (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {}, 195 resultShape); 196 for (auto other : make_early_inc_range(rankedOperands)) { 197 SmallVector<int64_t, 4> temp = resultShape; 198 if (!util::getBroadcastedShape(temp, getShape(other), resultShape)) 199 return op->emitOpError("operands don't have broadcast-compatible shapes"); 200 } 201 202 auto rankedResults = make_filter_range( 203 op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 204 205 // If all of the results are unranked then no further verification. 206 if (rankedResults.empty()) 207 return success(); 208 209 for (auto type : rankedResults) { 210 ArrayRef<int64_t> actualSuffix = 211 getShape(type).take_back(resultShape.size()); 212 if (!areCompatibleShapes(actualSuffix, resultShape)) 213 return op->emitOpError() 214 << "result type " << getShapeString(getShape(type)) 215 << " not broadcast compatible with broadcasted operands's shapes " 216 << getShapeString(resultShape); 217 } 218 return success(); 219 } 220