//===- Traits.cpp - Common op traits shared by dialects ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // We look backwards through every column of `shapes`.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    Optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

      if (dim == 1)
        continue;

      // Dimensions are compatible when
      //   1. One is dynamic, the rest are 1.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      //   2. All are 1 or a specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;

      nonOneDim = dim;
    }
  }
  return true;
}
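// A few illustrative cases (a sketch of the rule above; shapes are written
// innermost-last, and -1 is assumed to be the encoding of a dynamic extent
// recognized by ShapedType::isDynamic):
//   staticallyKnownBroadcastable({1, 2}, {4, 2})  -> true  (1 broadcasts to 4)
//   staticallyKnownBroadcastable({8, 1, 2}, {2})  -> true  (missing leading
//                                                    dims behave like 1)
//   staticallyKnownBroadcastable({3, 2}, {4, 2})  -> false (3 vs. 4 mismatch)
//   staticallyKnownBroadcastable({-1, 2}, {4, 2}) -> false (the dynamic extent
//                                                    may not equal 4 at runtime)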

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the broadcasted result shape, we compare operand shapes
  // element-wise, starting with the trailing dimensions and working our way
  // backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape takes the maximum of the two inputs at every dimension
  // index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both of the dimensions is unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}
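// Illustrative cases (a sketch; -1 denotes an unknown extent):
//   getBroadcastedShape({5, 4}, {1}, r)      -> true,  r = {5, 4}
//   getBroadcastedShape({-1, 2}, {4, 2}, r)  -> true,  r = {4, 2}  (assumes the
//                                               unknown extent matches 4)
//   getBroadcastedShape({-1, 2}, {1, 2}, r)  -> true,  r = {-1, 2}
//   getBroadcastedShape({3, 2}, {4, 2}, r)   -> false, r is cleared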

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}
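// For example, getShape(tensor<4x2xf32>) yields {4, 2}, while getShape(f32)
// yields the empty shape (MLIR type notation used for illustration).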

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise it is required that the element type of
/// type1 and type2 is the same and this element type will be used as the
/// resultant element type.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If the elementType is not specified, then use the common element type of
  // the inputs or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type shouldn't
  // be a vector and the result should have unranked tensor type.
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns llvm::None otherwise.
  auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
    if (type.isa<VectorType, RankedTensorType>())
      return type.getTypeID();
    return llvm::None;
  };

  // Make sure the composite type, if present, is consistent.
  Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  Optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Compute the broadcasted shape of the two operands.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}
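// Illustrative results (a sketch; MLIR type notation used for brevity):
//   getBroadcastedType(tensor<1x2xf32>, tensor<4x2xf32>) -> tensor<4x2xf32>
//   getBroadcastedType(tensor<*xf32>, tensor<4x2xf32>)   -> tensor<*xf32>
//   getBroadcastedType(f32, tensor<4x2xf32>)             -> tensor<4x2xf32>
//   getBroadcastedType(vector<2xf32>, tensor<2xf32>)     -> null (vector and
//                                                           tensor can't mix)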

/// Returns a (hasTensorType, hasVectorType) tuple indicating whether the given
/// type range contains any tensor or vector types.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return std::make_tuple(
      llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
      llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    // The inferred and existing dims are compatible if they are equal, if
    // either of them is unknown, or if the inferred dim is 1 (and thus
    // broadcastable to the existing dim). An existing dim of 1 with an
    // inferred dim greater than 1 is flagged as incompatible.
    return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto p : llvm::zip(inferred, existing))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}
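// For instance (a sketch; -1 denotes an unknown extent):
//   inferred {4, -1} vs. existing {4, 2} -> compatible (unknown matches any)
//   inferred {1, 2}  vs. existing {4, 2} -> compatible (1 broadcasts to 4)
//   inferred {4, 2}  vs. existing {1, 2} -> incompatible (existing dim is 1,
//                                           inferred dim is greater than 1)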

static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: Replace this with a more uniform way of printing shapes, shared with
  // type printing.
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ss.str();
}
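// For example, getShapeString({4, -1, 2}) produces "'4x?x2'" (assuming the
// usual -1 encoding of dynamic dimensions).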

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands = make_filter_range(
      op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands are broadcast compatible). The results need to be broadcast
  // compatible with this result shape.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults = make_filter_range(
      op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

  // If all of the results are unranked then no further verification is needed.
  if (rankedResults.empty())
    return success();

  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}
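// As an illustration (hypothetical op and types, sketching what this verifier
// accepts): an op with operands tensor<1x2xi32> and tensor<4x2xi32> has the
// broadcasted operand shape 4x2, so a result of tensor<4x2xi32> (or one whose
// trailing dimensions are broadcast compatible with 4x2, e.g. tensor<?x2xi32>)
// verifies, while a result of tensor<3x2xi32> is rejected.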