1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Traits.h" 10 #include "mlir/IR/StandardTypes.h" 11 #include "llvm/Support/FormatVariadic.h" 12 13 using namespace mlir; 14 15 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 16 ArrayRef<int64_t> shape2, 17 SmallVectorImpl<int64_t> &resultShape) { 18 // To compute the result broadcasted shape, we compare operand shapes 19 // element-wise: starting with the trailing dimensions, and working the 20 // way backward. Two dimensions are compatible when 21 // 1. they are equal, or 22 // 2. one of them is 1 23 // The result shape has the maximum among the two inputs at every 24 // dimension index. 25 26 resultShape.clear(); 27 if (shape1.size() > shape2.size()) { 28 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 29 } else { 30 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 31 } 32 33 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 34 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 35 auto iR = resultShape.rbegin(); 36 37 // Check each dimension is consistent. 38 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 39 if (*i1 == -1 || *i2 == -1) { 40 // One or both dimensions is unknown. Follow TensorFlow behavior: 41 // - If either dimension is greater than 1, we assume that the program is 42 // correct, and the other dimension will be broadcast to match it. 43 // - If either dimension is 1, the other dimension is the output. 44 if (*i1 > 1) { 45 *iR = *i1; 46 } else if (*i2 > 1) { 47 *iR = *i2; 48 } else if (*i1 == 1) { 49 *iR = *i2; 50 } else if (*i2 == 1) { 51 *iR = *i1; 52 } else { 53 *iR = -1; 54 } 55 } else { 56 if (*i1 == *i2 || *i2 == 1) { 57 *iR = *i1; 58 } else if (*i1 == 1) { 59 *iR = *i2; 60 } else { 61 // This dimension of the two operand types is incompatible. 62 resultShape.clear(); 63 return false; 64 } 65 } 66 } 67 68 return true; 69 } 70 71 /// Returns the shape of the given type. Scalars will be considered as having a 72 /// shape with zero dimensions. 73 static ArrayRef<int64_t> getShape(Type type) { 74 if (auto sType = type.dyn_cast<ShapedType>()) 75 return sType.getShape(); 76 return {}; 77 } 78 79 /// Returns the result broadcast composition type from the two given types by 80 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 81 /// either of the input types has dynamic shape. Returns null type if the two 82 /// given types are not broadcast-compatible. 83 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { 84 // Returns the scalar type out of the given type. 85 auto getScalarType = [](Type type) -> Type { 86 if (auto shapedType = type.dyn_cast<ShapedType>()) 87 return shapedType.getElementType(); 88 return type; 89 }; 90 91 // Make sure underlying scalar type is the same. 92 auto scalarType = getScalarType(type1); 93 if (scalarType != getScalarType(type2)) 94 return {}; 95 96 // If one of the types is unranked tensor, then the other type shouldn't be 97 // vector and the result should have unranked tensor type. 98 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { 99 if (type1.isa<VectorType>() || type2.isa<VectorType>()) 100 return {}; 101 return UnrankedTensorType::get(scalarType); 102 } 103 104 // Returns the type kind if the given type is a vector or ranked tensor type. 105 // Returns llvm::None otherwise. 106 auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> { 107 if (type.isa<VectorType>() || type.isa<RankedTensorType>()) 108 return static_cast<StandardTypes::Kind>(type.getKind()); 109 return llvm::None; 110 }; 111 112 // Make sure the composite type, if has, is consistent. 113 auto compositeKind1 = getCompositeTypeKind(type1); 114 auto compositeKind2 = getCompositeTypeKind(type2); 115 Optional<StandardTypes::Kind> resultCompositeKind; 116 117 if (compositeKind1 && compositeKind2) { 118 // Disallow mixing vector and tensor. 119 if (compositeKind1 != compositeKind2) 120 return {}; 121 resultCompositeKind = compositeKind1; 122 } else if (compositeKind1) { 123 resultCompositeKind = compositeKind1; 124 } else if (compositeKind2) { 125 resultCompositeKind = compositeKind2; 126 } 127 128 // Get the shape of each type. 129 SmallVector<int64_t, 4> resultShape; 130 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 131 return {}; 132 133 // Compose the final broadcasted type 134 if (resultCompositeKind == StandardTypes::Vector) 135 return VectorType::get(resultShape, scalarType); 136 if (resultCompositeKind == StandardTypes::RankedTensor) 137 return RankedTensorType::get(resultShape, scalarType); 138 return scalarType; 139 } 140 141 /// Returns true if the given types has both vector types and tensor types. 142 static bool hasBothVectorAndTensorType(ArrayRef<Type> types) { 143 return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) && 144 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }); 145 } 146 147 static bool areCompatibleShapes(ArrayRef<int64_t> shape1, 148 ArrayRef<int64_t> shape2) { 149 auto isCompatible = [](int64_t dim1, int64_t dim2) { 150 return dim1 == dim2 || dim1 == -1 || dim2 == -1; 151 }; 152 if (shape1.size() != shape2.size()) 153 return false; 154 for (auto p : llvm::zip(shape1, shape2)) 155 if (!isCompatible(std::get<0>(p), std::get<1>(p))) 156 return false; 157 return true; 158 } 159 160 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 161 assert(op->getNumOperands() == 2 && 162 "only support broadcast check on two operands"); 163 assert(op->getNumResults() == 1 && 164 "only support broadcast check on one result"); 165 166 auto type1 = op->getOperand(0).getType(); 167 auto type2 = op->getOperand(1).getType(); 168 auto retType = op->getResult(0).getType(); 169 170 // We forbid broadcasting vector and tensor. 171 if (hasBothVectorAndTensorType({type1, type2, retType})) 172 return op->emitError("cannot broadcast vector with tensor"); 173 174 if (retType.isa<UnrankedTensorType>()) 175 return success(); 176 177 bool isUnranked1 = type1.isa<UnrankedTensorType>(); 178 bool isUnranked2 = type2.isa<UnrankedTensorType>(); 179 180 // If both operands are unranked, then all result shapes are possible. 181 if (isUnranked1 && isUnranked2) 182 return success(); 183 184 // If one of the operands is unranked, then the known dimensions in the result 185 // should be compatible with the other shaped operand. 186 if (isUnranked1 || isUnranked2) { 187 // Result should have higher rank than the shaped operand's rank and then 188 // the result's trailing dimensions should be compatible with the operand 189 // shape. 190 ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2); 191 ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size()); 192 if (!areCompatibleShapes(actualSuffix, shape)) 193 return op->emitOpError() 194 << "result type " << retType 195 << " has shape incompatible with a ranked operand type"; 196 return success(); 197 } 198 199 // If both operands are shaped, then the computed broadcasted shape should be 200 // compatible with the result shape. 201 SmallVector<int64_t, 4> resultShape; 202 if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 203 return op->emitOpError("operands don't have broadcast-compatible shapes"); 204 205 if (!areCompatibleShapes(resultShape, getShape(retType))) 206 return op->emitOpError() << "result type " << retType 207 << " does not have shape compatible with the one " 208 "computed from the operand types"; 209 210 return success(); 211 } 212