1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
19 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
20                                         ShapedType targetType) {
21   // Make the sourceType greater rank than the targetType. If they are same
22   // rank, then its an unsupported reshape op.
23   if (sourceType.getRank() == targetType.getRank())
24     return llvm::None;
25   if (sourceType.getRank() < targetType.getRank())
26     std::swap(sourceType, targetType);
27 
28   ArrayRef<int64_t> sourceShape = sourceType.getShape();
29   ArrayRef<int64_t> targetShape = targetType.getShape();
30   unsigned sourceDim = 0;
31   SmallVector<ReassociationIndices> reassociationMap;
32   reassociationMap.reserve(targetType.getRank());
33 
34   ReassociationIndices currIndices;
35   int64_t prodOfCollapsedDims = 1;
36   while (sourceDim < sourceShape.size()) {
37     unsigned targetDim = reassociationMap.size();
38     // If we have mapped all the target dimensions stop and handle the remaining
39     // tail of size-1 dimensions explictly.
40     if (targetDim == targetType.getRank())
41       break;
42 
43     int64_t currTargetShape = targetShape[targetDim];
44     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
45            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
46            sourceDim < sourceShape.size()) {
47       prodOfCollapsedDims *= sourceShape[sourceDim];
48       currIndices.push_back(sourceDim++);
49     }
50 
51     // If the current expanded dimension is dynamic, then the collapsed
52     // dimensions should also be dynamic and product of all previous unprocessed
53     // dimensions of the expanded shape should be 1.
54     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
55         (currTargetShape != ShapedType::kDynamicSize ||
56          prodOfCollapsedDims != 1))
57       return llvm::None;
58 
59     // If the collapsed dim is dynamic, the current expanded dim should also
60     // be dynamic.
61     if (currTargetShape == ShapedType::kDynamicSize &&
62         sourceShape[sourceDim] != ShapedType::kDynamicSize)
63       return llvm::None;
64 
65     // For static shapes, if the product of dimensions of the expanded shape
66     // should match the collapsed dimension shape.
67     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
68       return llvm::None;
69 
70     currIndices.push_back(sourceDim++);
71     reassociationMap.emplace_back(ReassociationIndices{});
72     std::swap(reassociationMap.back(), currIndices);
73     prodOfCollapsedDims = 1;
74   }
75   // All the dimensions in the target must have been processed.
76   if (reassociationMap.size() != targetShape.size())
77     return llvm::None;
78   // Process any remaining entries in the source shape. They all need to be
79   // 1 or dynamic.
80   for (; sourceDim < sourceShape.size(); sourceDim++) {
81     if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
82         sourceShape[sourceDim] != 1)
83       return llvm::None;
84     // The map is empty when the target type is a scalar.
85     if (!reassociationMap.empty())
86       reassociationMap.back().push_back(sourceDim);
87   }
88   return reassociationMap;
89 }
90 
91 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
92     ArrayRef<ReassociationIndices> producerReassociations,
93     ArrayRef<ReassociationIndices> consumerReassociations,
94     MLIRContext *context) {
95   SmallVector<ReassociationIndices> composedIndices;
96   // Make the producer the larger sized vector. If they are of same size, the
97   // resulting reshape is not a supported reshape op.
98   if (producerReassociations.size() == consumerReassociations.size())
99     return llvm::None;
100   if (producerReassociations.size() < consumerReassociations.size())
101     std::swap(producerReassociations, consumerReassociations);
102 
103   // Handle the corner case of the result being a rank 0 shaped type. Return an
104   // empty reassociation.
105   if (consumerReassociations.empty())
106     return composedIndices;
107 
108   size_t consumerDims = std::accumulate(
109       consumerReassociations.begin(), consumerReassociations.end(), 0,
110       [](size_t all, ReassociationIndicesRef indices) {
111         return all + indices.size();
112       });
113   if (producerReassociations.size() != consumerDims)
114     return llvm::None;
115 
116   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
117     ReassociationIndices reassociations;
118     for (int64_t consumerIndex : consumerIndices) {
119       for (int64_t producerIndex : producerReassociations[consumerIndex])
120         reassociations.push_back(producerIndex);
121     }
122     composedIndices.push_back(std::move(reassociations));
123   }
124   return composedIndices;
125 }
126 
127 SmallVector<SmallVector<AffineExpr, 2>, 2>
128 mlir::convertReassociationIndicesToExprs(
129     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
130   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
131   for (const auto &indices : reassociationIndices) {
132     SmallVector<AffineExpr, 2> reassociationMap;
133     reassociationMap.reserve(indices.size());
134     for (int64_t index : indices)
135       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
136     reassociationMaps.push_back(std::move(reassociationMap));
137   }
138   return reassociationMaps;
139 }
140 
141 template <typename AffineExprTy>
142 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
143   unsigned pos = 0;
144   for (const auto &exprs : exprArrays) {
145     for (auto expr : exprs) {
146       expr.walk([&pos](AffineExpr e) {
147         if (auto d = e.dyn_cast<AffineExprTy>())
148           pos = std::max(pos, d.getPosition());
149       });
150     }
151   }
152   return pos;
153 }
154 
155 ArrayAttr mlir::getReassociationIndicesAttribute(
156     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
157   SmallVector<Attribute, 4> reassociationAttr =
158       llvm::to_vector<4>(llvm::map_range(
159           reassociation, [&](const ReassociationIndices &indices) -> Attribute {
160             return b.getI64ArrayAttr(indices).cast<Attribute>();
161           }));
162   return b.getArrayAttr(reassociationAttr);
163 }
164 
165 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
166     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
167   SmallVector<ReassociationIndices, 2> reassociationIndices;
168   for (const auto &exprs : reassociationExprs) {
169     ReassociationIndices indices;
170     indices.reserve(exprs.size());
171     for (const auto &expr : exprs)
172       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
173     reassociationIndices.push_back(indices);
174   }
175   return reassociationIndices;
176 }
177 
178 SmallVector<AffineMap, 4>
179 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
180   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
181   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
182          "Expected symbol-less expressions");
183   SmallVector<AffineMap, 4> maps;
184   maps.reserve(reassociation.size());
185   for (const auto &exprs : reassociation) {
186     assert(!exprs.empty());
187     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
188   }
189   return maps;
190 }
191 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
192                                 int *invalidIndex) {
193   if (reassociation.empty())
194     return true;
195   unsigned nDims = reassociation[0].getNumDims();
196   unsigned nextExpectedDim = 0;
197   for (const auto &it : llvm::enumerate(reassociation)) {
198     auto m = it.value();
199     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
200       if (invalidIndex)
201         *invalidIndex = it.index();
202       return false;
203     }
204     for (auto e : m.getResults()) {
205       auto d = e.dyn_cast<AffineDimExpr>();
206       if (!d || d.getPosition() != nextExpectedDim++) {
207         if (invalidIndex)
208           *invalidIndex = it.index();
209         return false;
210       }
211     }
212   }
213   if (nextExpectedDim != nDims) {
214     if (invalidIndex)
215       *invalidIndex = reassociation.size() - 1;
216     return false;
217   }
218   return true;
219 }
220 
221 LogicalResult mlir::reshapeLikeShapesAreCompatible(
222     function_ref<LogicalResult(const Twine &)> emitError,
223     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
224     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
225   unsigned expandedDimStart = 0;
226   for (const auto &map : llvm::enumerate(reassociationMaps)) {
227     Optional<int64_t> dynamicShape;
228     int64_t linearizedStaticShape = 1;
229     for (const auto &dim : llvm::enumerate(
230              expandedShape.slice(expandedDimStart, map.value().size()))) {
231       if (ShapedType::isDynamic(dim.value())) {
232         if (isExpandingReshape && dynamicShape) {
233           return emitError("invalid to have a single dimension (" +
234                            Twine(map.index()) +
235                            ") expanded into multiple dynamic dims (" +
236                            Twine(expandedDimStart + dynamicShape.getValue()) +
237                            "," + Twine(expandedDimStart + dim.index()) + ")");
238         }
239         dynamicShape = dim.index();
240       } else {
241         linearizedStaticShape *= dim.value();
242       }
243     }
244     if (dynamicShape) {
245       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
246         return emitError(
247             "expected dimension " + Twine(map.index()) +
248             " of collapsed type to be dynamic since one or more of the "
249             "corresponding dimensions in the expanded type is dynamic");
250       }
251     } else {
252       if (collapsedShape[map.index()] != linearizedStaticShape) {
253         return emitError("expected dimension " + Twine(map.index()) +
254                          " of collapsed type to be static value of " +
255                          Twine(linearizedStaticShape));
256       }
257     }
258     expandedDimStart += map.value().size();
259   }
260   return success();
261 }
262