//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include <numeric>

using namespace mlir;

Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  if (sourceType.getRank() > targetType.getRank())
    return getReassociationIndicesForCollapse(sourceType.getShape(),
                                              targetType.getShape());
  if (sourceType.getRank() < targetType.getRank())
    return getReassociationIndicesForCollapse(targetType.getShape(),
                                              sourceType.getShape());
  return llvm::None;
}

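// Illustrative sketch of the expected behavior (the shapes below are
// hypothetical examples, not taken from existing documentation or tests):
//
//   getReassociationIndicesForCollapse({2, 3, 4, 5}, {6, 20})
//       -> {{0, 1}, {2, 3}}
//   getReassociationIndicesForCollapse({2, 3, 4, 5}, {6, 21})
//       -> llvm::None   (no contiguous grouping of source dims multiplies to 21)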
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  if (sourceShape.size() <= targetShape.size())
    return llvm::None;
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetShape.size());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();
    // If we have mapped all the target dimensions, stop and handle the
    // remaining tail of size-1 dimensions explicitly.
    if (targetDim == targetShape.size())
      break;

    int64_t currTargetShape = targetShape[targetDim];
    // Stop before the last source dimension so that the accesses to
    // sourceShape[sourceDim] after this loop stay in bounds.
    while (sourceDim < sourceShape.size() - 1 &&
           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // If the current source dimension is dynamic, the corresponding target
    // dimension must also be dynamic, and no static source dimensions may have
    // been folded into this group yet (i.e. the running product must be 1).
    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
        (currTargetShape != ShapedType::kDynamicSize ||
         prodOfCollapsedDims != 1))
      return llvm::None;

    // If the target dimension is dynamic, the current source dimension must
    // also be dynamic.
    if (currTargetShape == ShapedType::kDynamicSize &&
        sourceShape[sourceDim] != ShapedType::kDynamicSize)
      return llvm::None;

    // For static shapes, the product of the grouped source dimensions must
    // match the target dimension.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return llvm::None;

    currIndices.push_back(sourceDim++);
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the target must have been processed.
  if (reassociationMap.size() != targetShape.size())
    return llvm::None;
  // Process any remaining entries in the source shape. They all need to be
  // 1 or dynamic.
  for (; sourceDim < sourceShape.size(); sourceDim++) {
    if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
        sourceShape[sourceDim] != 1)
      return llvm::None;
    // The map is empty when the target type is a scalar.
    if (!reassociationMap.empty())
      reassociationMap.back().push_back(sourceDim);
  }
  return reassociationMap;
}

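// Illustrative sketch (hypothetical reassociations, not taken from existing
// tests): composing a rank-4 -> rank-3 collapse described by
// {{0}, {1}, {2, 3}} with a rank-3 -> rank-2 collapse described by
// {{0, 1}, {2}} yields {{0, 1}, {2, 3}}, i.e. a single rank-4 -> rank-2
// collapse.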
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of the same size,
  // the resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return llvm::None;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(), 0,
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  // The number of producer groups, i.e. the rank of the intermediate type,
  // must match the total number of dimensions reassociated by the consumer.
  if (producerReassociations.size() != consumerDims)
    return llvm::None;

  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      llvm::append_range(reassociations, producerReassociations[consumerIndex]);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}

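// Illustrative sketch (hypothetical input): the reassociation indices
// {{0, 1}, {2}} are converted to the affine-expression groups
// {{d0, d1}, {d2}}, where dN is the dimension expression for position N.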
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
  for (const auto &indices : reassociationIndices) {
    SmallVector<AffineExpr, 2> reassociationMap;
    reassociationMap.reserve(indices.size());
    for (int64_t index : indices)
      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
    reassociationMaps.push_back(std::move(reassociationMap));
  }
  return reassociationMaps;
}

template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned pos = 0;
  for (const auto &exprs : exprArrays) {
    for (auto expr : exprs) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = e.dyn_cast<AffineExprTy>())
          pos = std::max(pos, d.getPosition());
      });
    }
  }
  return pos;
}

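// Illustrative sketch (hypothetical input): the reassociation indices
// {{0, 1}, {2}} are encoded as the ArrayAttr [[0, 1], [2]], with one
// I64ArrayAttr per group.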
ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr =
      llvm::to_vector<4>(llvm::map_range(
          reassociation, [&](const ReassociationIndices &indices) -> Attribute {
            return b.getI64ArrayAttr(indices).cast<Attribute>();
          }));
  return b.getArrayAttr(reassociationAttr);
}

SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(expr.cast<AffineDimExpr>().getPosition());
    reassociationIndices.push_back(indices);
  }
  return reassociationIndices;
}

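// Illustrative sketch (hypothetical input): for the expression groups
// {{d0, d1}, {d2}}, the result is the two maps
// (d0, d1, d2) -> (d0, d1) and (d0, d1, d2) -> (d2); every map uses the same
// total dimension count, derived from the largest dimension position seen.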
SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}

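// Illustrative sketch (hypothetical maps): with nDims == 3, the reassociation
// [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] is valid because the
// results, concatenated, are exactly d0, d1, d2 in order. Swapping the two
// maps, or dropping d2, makes the reassociation invalid.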
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;
  unsigned nDims = reassociation[0].getNumDims();
  unsigned nextExpectedDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    auto m = it.value();
    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
      if (invalidIndex)
        *invalidIndex = it.index();
      return false;
    }
    for (auto e : m.getResults()) {
      auto d = e.dyn_cast<AffineDimExpr>();
      if (!d || d.getPosition() != nextExpectedDim++) {
        if (invalidIndex)
          *invalidIndex = it.index();
        return false;
      }
    }
  }
  if (nextExpectedDim != nDims) {
    if (invalidIndex)
      *invalidIndex = reassociation.size() - 1;
    return false;
  }
  return true;
}

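// Illustrative sketch (hypothetical shapes): with reassociation
// {{0, 1}, {2, 3}}, the collapsed shape [6, ?] is compatible with the expanded
// shape [2, 3, ?, 5]: the first group multiplies to 6, and the second group
// contains a dynamic dimension, so the corresponding collapsed dimension must
// be dynamic. Expanding [6, ?] into [2, 3, ?, ?] is rejected because a single
// dimension may not be expanded into multiple dynamic dimensions.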
LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned expandedDimStart = 0;
  for (const auto &map : llvm::enumerate(reassociationMaps)) {
    Optional<int64_t> dynamicShape;
    int64_t linearizedStaticShape = 1;
    for (const auto &dim : llvm::enumerate(
             expandedShape.slice(expandedDimStart, map.value().size()))) {
      if (ShapedType::isDynamic(dim.value())) {
        if (isExpandingReshape && dynamicShape) {
          return emitError("invalid to have a single dimension (" +
                           Twine(map.index()) +
                           ") expanded into multiple dynamic dims (" +
                           Twine(expandedDimStart + dynamicShape.value()) +
                           "," + Twine(expandedDimStart + dim.index()) + ")");
        }
        dynamicShape = dim.index();
      } else {
        linearizedStaticShape *= dim.value();
      }
    }
    if (dynamicShape) {
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    } else {
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return emitError("expected dimension " + Twine(map.index()) +
                         " of collapsed type to be static value of " +
                         Twine(linearizedStaticShape));
      }
    }
    expandedDimStart += map.value().size();
  }
  return success();
}

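// Illustrative sketch (hypothetical types): memref<4x8xf32> has the identity
// layout and returns false, while a memref with a non-trivial strided or
// affine-map layout returns true. Non-memref types (e.g. tensors) always
// return false.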
bool mlir::hasNonIdentityLayout(Type type) {
  if (auto memrefType = type.dyn_cast<MemRefType>())
    return !memrefType.getLayout().isIdentity();
  return false;
}