1 //===- Traits.cpp - Common op traits shared by dialects -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Traits.h"
10 #include "mlir/IR/StandardTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
16 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
17                                         ArrayRef<int64_t> shape2,
18                                         SmallVectorImpl<int64_t> &resultShape) {
19   // To compute the result broadcasted shape, we compare operand shapes
20   // element-wise: starting with the trailing dimensions, and working the
21   // way backward. Two dimensions are compatible when
22   //   1. they are equal, or
23   //   2. one of them is 1
24   // The result shape has the maximum among the two inputs at every
25   // dimension index.
26 
27   resultShape.clear();
28   if (shape1.size() > shape2.size()) {
29     std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
30   } else {
31     std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
32   }
33 
34   auto i1 = shape1.rbegin(), e1 = shape1.rend();
35   auto i2 = shape2.rbegin(), e2 = shape2.rend();
36   auto iR = resultShape.rbegin();
37 
38   // Check each dimension is consistent.
39   for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
40     if (*i1 == -1 || *i2 == -1) {
41       // One or both dimensions is unknown. Follow TensorFlow behavior:
42       // - If either dimension is greater than 1, we assume that the program is
43       //   correct, and the other dimension will be broadcast to match it.
44       // - If either dimension is 1, the other dimension is the output.
45       if (*i1 > 1) {
46         *iR = *i1;
47       } else if (*i2 > 1) {
48         *iR = *i2;
49       } else if (*i1 == 1) {
50         *iR = *i2;
51       } else if (*i2 == 1) {
52         *iR = *i1;
53       } else {
54         *iR = -1;
55       }
56     } else {
57       if (*i1 == *i2 || *i2 == 1) {
58         *iR = *i1;
59       } else if (*i1 == 1) {
60         *iR = *i2;
61       } else {
62         // This dimension of the two operand types is incompatible.
63         resultShape.clear();
64         return false;
65       }
66     }
67   }
68 
69   return true;
70 }
71 
72 /// Returns the shape of the given type. Scalars will be considered as having a
73 /// shape with zero dimensions.
74 static ArrayRef<int64_t> getShape(Type type) {
75   if (auto sType = type.dyn_cast<ShapedType>())
76     return sType.getShape();
77   return {};
78 }
79 
80 /// Returns the result broadcast composition type from the two given types by
81 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
82 /// either of the input types has dynamic shape. Returns null type if the two
83 /// given types are not broadcast-compatible.
84 ///
85 /// elementType, if specified, will be used as the element type of the
86 /// broadcasted result type. Otherwise it is required that the element type of
87 /// type1 and type2 is the same and this element type will be used as the
88 /// resultant element type.
89 Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
90                                        Type elementType) {
91   // If the elementType is not specified, then the use the common element type
92   // of the inputs or fail if there is no common element type.
93   if (!elementType) {
94     elementType = getElementTypeOrSelf(type1);
95     if (elementType != getElementTypeOrSelf(type2))
96       return {};
97   }
98 
99   // If one of the types is unranked tensor, then the other type shouldn't be
100   // vector and the result should have unranked tensor type.
101   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
102     if (type1.isa<VectorType>() || type2.isa<VectorType>())
103       return {};
104     return UnrankedTensorType::get(elementType);
105   }
106 
107   // Returns the type kind if the given type is a vector or ranked tensor type.
108   // Returns llvm::None otherwise.
109   auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
110     if (type.isa<VectorType>() || type.isa<RankedTensorType>())
111       return static_cast<StandardTypes::Kind>(type.getKind());
112     return llvm::None;
113   };
114 
115   // Make sure the composite type, if has, is consistent.
116   auto compositeKind1 = getCompositeTypeKind(type1);
117   auto compositeKind2 = getCompositeTypeKind(type2);
118   Optional<StandardTypes::Kind> resultCompositeKind;
119 
120   if (compositeKind1 && compositeKind2) {
121     // Disallow mixing vector and tensor.
122     if (compositeKind1 != compositeKind2)
123       return {};
124     resultCompositeKind = compositeKind1;
125   } else if (compositeKind1) {
126     resultCompositeKind = compositeKind1;
127   } else if (compositeKind2) {
128     resultCompositeKind = compositeKind2;
129   }
130 
131   // Get the shape of each type.
132   SmallVector<int64_t, 4> resultShape;
133   if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
134     return {};
135 
136   // Compose the final broadcasted type
137   if (resultCompositeKind == StandardTypes::Vector)
138     return VectorType::get(resultShape, elementType);
139   if (resultCompositeKind == StandardTypes::RankedTensor)
140     return RankedTensorType::get(resultShape, elementType);
141   return elementType;
142 }
143 
144 /// Returns a tuple corresponding to whether range has tensor or vector type.
145 template <typename iterator_range>
146 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
147   return std::make_tuple(
148       llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
149       llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
150 }
151 
152 static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
153                                 ArrayRef<int64_t> shape2) {
154   auto isCompatible = [](int64_t dim1, int64_t dim2) {
155     return dim1 == dim2 || dim1 == -1 || dim2 == -1;
156   };
157   if (shape1.size() != shape2.size())
158     return false;
159   for (auto p : llvm::zip(shape1, shape2))
160     if (!isCompatible(std::get<0>(p), std::get<1>(p)))
161       return false;
162   return true;
163 }
164 
165 static std::string getShapeString(ArrayRef<int64_t> shape) {
166   // TODO: should replace with printing shape more uniformly across here and
167   // when in type.
168   return std::string(
169       formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end())));
170 }
171 
172 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
173   // Ensure broadcasting only tensor or only vector types.
174   auto operandsHasTensorVectorType =
175       hasTensorOrVectorType(op->getOperandTypes());
176   auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
177   if ((std::get<0>(operandsHasTensorVectorType) ||
178        std::get<0>(resultsHasTensorVectorType)) &&
179       (std::get<1>(operandsHasTensorVectorType) ||
180        std::get<1>(resultsHasTensorVectorType)))
181     return op->emitError("cannot broadcast vector with tensor");
182 
183   auto rankedOperands = make_filter_range(
184       op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
185 
186   // If all operands are unranked, then all result shapes are possible.
187   if (rankedOperands.empty())
188     return success();
189 
190   // Compute broadcasted shape of operands (which requires that operands are
191   // broadcast compatible). The results need to be broadcast compatible with
192   // this result shape.
193   SmallVector<int64_t, 4> resultShape;
194   (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
195                                   resultShape);
196   for (auto other : make_early_inc_range(rankedOperands)) {
197     SmallVector<int64_t, 4> temp = resultShape;
198     if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
199       return op->emitOpError("operands don't have broadcast-compatible shapes");
200   }
201 
202   auto rankedResults = make_filter_range(
203       op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
204 
205   // If all of the results are unranked then no further verification.
206   if (rankedResults.empty())
207     return success();
208 
209   for (auto type : rankedResults) {
210     ArrayRef<int64_t> actualSuffix =
211         getShape(type).take_back(resultShape.size());
212     if (!areCompatibleShapes(actualSuffix, resultShape))
213       return op->emitOpError()
214              << "result type " << getShapeString(getShape(type))
215              << " not broadcast compatible with broadcasted operands's shapes "
216              << getShapeString(resultShape);
217   }
218   return success();
219 }
220