1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Traits.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 16 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, 17 ArrayRef<int64_t> shape2) { 18 SmallVector<SmallVector<int64_t, 6>, 2> extents; 19 extents.emplace_back(shape1.begin(), shape1.end()); 20 extents.emplace_back(shape2.begin(), shape2.end()); 21 return staticallyKnownBroadcastable(extents); 22 } 23 24 bool OpTrait::util::staticallyKnownBroadcastable( 25 ArrayRef<SmallVector<int64_t, 6>> shapes) { 26 assert(!shapes.empty() && "Expected at least one shape"); 27 size_t maxRank = shapes[0].size(); 28 for (size_t i = 1; i != shapes.size(); ++i) 29 maxRank = std::max(maxRank, shapes[i].size()); 30 31 // We look backwards through every column of `shapes`. 32 for (size_t i = 0; i != maxRank; ++i) { 33 bool seenDynamic = false; 34 Optional<int64_t> nonOneDim; 35 for (ArrayRef<int64_t> extent : shapes) { 36 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1]; 37 38 if (dim == 1) 39 continue; 40 41 // Dimensions are compatible when 42 //. 1. One is dynamic, the rest are 1 43 if (ShapedType::isDynamic(dim)) { 44 if (seenDynamic || nonOneDim) 45 return false; 46 seenDynamic = true; 47 } 48 49 // 2. All are 1 or a specific constant. 50 if (nonOneDim && dim != *nonOneDim) 51 return false; 52 53 nonOneDim = dim; 54 } 55 } 56 return true; 57 } 58 59 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 60 ArrayRef<int64_t> shape2, 61 SmallVectorImpl<int64_t> &resultShape) { 62 // To compute the result broadcasted shape, we compare operand shapes 63 // element-wise: starting with the trailing dimensions, and working the 64 // way backward. Two dimensions are compatible when 65 // 1. they are equal, or 66 // 2. one of them is 1 67 // The result shape has the maximum among the two inputs at every 68 // dimension index. 69 70 resultShape.clear(); 71 if (shape1.size() > shape2.size()) { 72 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 73 } else { 74 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 75 } 76 77 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 78 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 79 auto iR = resultShape.rbegin(); 80 81 // Check each dimension is consistent. 82 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 83 if (*i1 == -1 || *i2 == -1) { 84 // One or both dimensions is unknown. Follow TensorFlow behavior: 85 // - If either dimension is greater than 1, we assume that the program is 86 // correct, and the other dimension will be broadcast to match it. 87 // - If either dimension is 1, the other dimension is the output. 88 if (*i1 > 1) { 89 *iR = *i1; 90 } else if (*i2 > 1) { 91 *iR = *i2; 92 } else if (*i1 == 1) { 93 *iR = *i2; 94 } else if (*i2 == 1) { 95 *iR = *i1; 96 } else { 97 *iR = -1; 98 } 99 } else { 100 if (*i1 == *i2 || *i2 == 1) { 101 *iR = *i1; 102 } else if (*i1 == 1) { 103 *iR = *i2; 104 } else { 105 // This dimension of the two operand types is incompatible. 106 resultShape.clear(); 107 return false; 108 } 109 } 110 } 111 112 return true; 113 } 114 115 /// Returns the shape of the given type. Scalars will be considered as having a 116 /// shape with zero dimensions. 117 static ArrayRef<int64_t> getShape(Type type) { 118 if (auto sType = type.dyn_cast<ShapedType>()) 119 return sType.getShape(); 120 return {}; 121 } 122 123 /// Returns the result broadcast composition type from the two given types by 124 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 125 /// either of the input types has dynamic shape. Returns null type if the two 126 /// given types are not broadcast-compatible. 127 /// 128 /// elementType, if specified, will be used as the element type of the 129 /// broadcasted result type. Otherwise it is required that the element type of 130 /// type1 and type2 is the same and this element type will be used as the 131 /// resultant element type. 132 Type OpTrait::util::getBroadcastedType(Type type1, Type type2, 133 Type elementType) { 134 // If the elementType is not specified, then the use the common element type 135 // of the inputs or fail if there is no common element type. 136 if (!elementType) { 137 elementType = getElementTypeOrSelf(type1); 138 if (elementType != getElementTypeOrSelf(type2)) 139 return {}; 140 } 141 142 // If one of the types is unranked tensor, then the other type shouldn't be 143 // vector and the result should have unranked tensor type. 144 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { 145 if (type1.isa<VectorType>() || type2.isa<VectorType>()) 146 return {}; 147 return UnrankedTensorType::get(elementType); 148 } 149 150 // Returns the type kind if the given type is a vector or ranked tensor type. 151 // Returns llvm::None otherwise. 152 auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> { 153 if (type.isa<VectorType, RankedTensorType>()) 154 return type.getTypeID(); 155 return llvm::None; 156 }; 157 158 // Make sure the composite type, if has, is consistent. 159 Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1); 160 Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2); 161 Optional<TypeID> resultCompositeKind; 162 163 if (compositeKind1 && compositeKind2) { 164 // Disallow mixing vector and tensor. 165 if (compositeKind1 != compositeKind2) 166 return {}; 167 resultCompositeKind = compositeKind1; 168 } else if (compositeKind1) { 169 resultCompositeKind = compositeKind1; 170 } else if (compositeKind2) { 171 resultCompositeKind = compositeKind2; 172 } 173 174 // Get the shape of each type. 175 SmallVector<int64_t, 4> resultShape; 176 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 177 return {}; 178 179 // Compose the final broadcasted type 180 if (resultCompositeKind == VectorType::getTypeID()) 181 return VectorType::get(resultShape, elementType); 182 if (resultCompositeKind == RankedTensorType::getTypeID()) 183 return RankedTensorType::get(resultShape, elementType); 184 return elementType; 185 } 186 187 /// Returns a tuple corresponding to whether range has tensor or vector type. 188 template <typename iterator_range> 189 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) { 190 return std::make_tuple( 191 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }), 192 llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); })); 193 } 194 195 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred, 196 ArrayRef<int64_t> existing) { 197 auto isCompatible = [](int64_t dim1, int64_t dim2) { 198 // If the inferred and existing dim is the same, or one of them is unknown 199 // then it is compatible, else if the inferred dim is 1 then it is also 200 // compatible. But if the existing dim is 1 and the inferred is greater than 201 // 1 then flag. 202 return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1; 203 }; 204 if (inferred.size() != existing.size()) 205 return false; 206 for (auto p : llvm::zip(inferred, existing)) 207 if (!isCompatible(std::get<0>(p), std::get<1>(p))) 208 return false; 209 return true; 210 } 211 212 static std::string getShapeString(ArrayRef<int64_t> shape) { 213 // TODO: should replace with printing shape more uniformly across here and 214 // when in type. 215 std::string ret; 216 llvm::raw_string_ostream ss(ret); 217 ss << '\''; 218 llvm::interleave( 219 shape, ss, 220 [&](int64_t dim) { 221 if (ShapedType::isDynamic(dim)) 222 ss << '?'; 223 else 224 ss << dim; 225 }, 226 "x"); 227 ss << '\''; 228 return ss.str(); 229 } 230 231 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 232 // Ensure broadcasting only tensor or only vector types. 233 auto operandsHasTensorVectorType = 234 hasTensorOrVectorType(op->getOperandTypes()); 235 auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes()); 236 if ((std::get<0>(operandsHasTensorVectorType) || 237 std::get<0>(resultsHasTensorVectorType)) && 238 (std::get<1>(operandsHasTensorVectorType) || 239 std::get<1>(resultsHasTensorVectorType))) 240 return op->emitError("cannot broadcast vector with tensor"); 241 242 auto rankedOperands = make_filter_range( 243 op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 244 245 // If all operands are unranked, then all result shapes are possible. 246 if (rankedOperands.empty()) 247 return success(); 248 249 // Compute broadcasted shape of operands (which requires that operands are 250 // broadcast compatible). The results need to be broadcast compatible with 251 // this result shape. 252 SmallVector<int64_t, 4> resultShape; 253 (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {}, 254 resultShape); 255 for (auto other : make_early_inc_range(rankedOperands)) { 256 SmallVector<int64_t, 4> temp = resultShape; 257 if (!util::getBroadcastedShape(temp, getShape(other), resultShape)) 258 return op->emitOpError("operands don't have broadcast-compatible shapes"); 259 } 260 261 auto rankedResults = make_filter_range( 262 op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); }); 263 264 // If all of the results are unranked then no further verification. 265 if (rankedResults.empty()) 266 return success(); 267 268 for (auto type : rankedResults) { 269 ArrayRef<int64_t> actualSuffix = 270 getShape(type).take_back(resultShape.size()); 271 if (!isCompatibleInferredReturnShape(resultShape, actualSuffix)) 272 return op->emitOpError() 273 << "result type " << getShapeString(getShape(type)) 274 << " not broadcast compatible with broadcasted operands's shapes " 275 << getShapeString(resultShape); 276 } 277 return success(); 278 } 279