1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
19 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
20                                         ShapedType targetType) {
21   // Make the sourceType greater rank than the targetType. If they are same
22   // rank, then its an unsupported reshape op.
23   if (sourceType.getRank() == targetType.getRank())
24     return llvm::None;
25   if (sourceType.getRank() < targetType.getRank())
26     std::swap(sourceType, targetType);
27 
28   ArrayRef<int64_t> sourceShape = sourceType.getShape();
29   ArrayRef<int64_t> targetShape = targetType.getShape();
30   unsigned sourceDim = 0;
31   SmallVector<ReassociationIndices> reassociationMap;
32   reassociationMap.reserve(targetType.getRank());
33 
34   ReassociationIndices currIndices;
35   int64_t prodOfCollapsedDims = 1;
36   while (sourceDim < sourceShape.size()) {
37     unsigned targetDim = reassociationMap.size();
38     // If we have mapped all the target dimensions stop and handle the remaining
39     // tail of size-1 dimensions explictly.
40     if (targetDim == targetType.getRank())
41       break;
42 
43     int64_t currTargetShape = targetShape[targetDim];
44     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
45            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
46            sourceDim < sourceShape.size()) {
47       prodOfCollapsedDims *= sourceShape[sourceDim];
48       currIndices.push_back(sourceDim++);
49     }
50 
51     // If the current expanded dimension is dynamic, then the collapsed
52     // dimensions should also be dynamic and product of all previous unprocessed
53     // dimensions of the expanded shape should be 1.
54     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
55         (currTargetShape != ShapedType::kDynamicSize ||
56          prodOfCollapsedDims != 1))
57       return llvm::None;
58 
59     // If the collapsed dim is dynamic, the current expanded dim should also
60     // be dynamic.
61     if (currTargetShape == ShapedType::kDynamicSize &&
62         sourceShape[sourceDim] != ShapedType::kDynamicSize)
63       return llvm::None;
64 
65     // For static shapes, if the product of dimensions of the expanded shape
66     // should match the collapsed dimension shape.
67     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
68       return llvm::None;
69 
70     currIndices.push_back(sourceDim++);
71     reassociationMap.emplace_back(ReassociationIndices{});
72     std::swap(reassociationMap.back(), currIndices);
73     prodOfCollapsedDims = 1;
74   }
75   // All the dimensions in the target must have been processed.
76   if (reassociationMap.size() != targetShape.size())
77     return llvm::None;
78   // Process any remaining entries in the source shape. They all need to be
79   // 1 or dynamic.
80   for (; sourceDim < sourceShape.size(); sourceDim++) {
81     if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
82         sourceShape[sourceDim] != 1)
83       return llvm::None;
84     // The map is empty when the target type is a scalar.
85     if (!reassociationMap.empty())
86       reassociationMap.back().push_back(sourceDim);
87   }
88   return reassociationMap;
89 }
90 
91 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
92     ArrayRef<ReassociationIndices> producerReassociations,
93     ArrayRef<ReassociationIndices> consumerReassociations,
94     MLIRContext *context) {
95   SmallVector<ReassociationIndices> composedIndices;
96   // Make the producer the larger sized vector. If they are of same size, the
97   // resulting reshape is not a supported reshape op.
98   if (producerReassociations.size() == consumerReassociations.size())
99     return llvm::None;
100   if (producerReassociations.size() < consumerReassociations.size())
101     std::swap(producerReassociations, consumerReassociations);
102 
103   // Handle the corner case of the result being a rank 0 shaped type. Return an
104   // empty reassociation.
105   if (consumerReassociations.empty())
106     return composedIndices;
107 
108   size_t consumerDims = std::accumulate(
109       consumerReassociations.begin(), consumerReassociations.end(), 0,
110       [](size_t all, ReassociationIndicesRef indices) {
111         return all + indices.size();
112       });
113   if (producerReassociations.size() != consumerDims)
114     return llvm::None;
115 
116   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
117     ReassociationIndices reassociations;
118     for (int64_t consumerIndex : consumerIndices) {
119       llvm::append_range(reassociations, producerReassociations[consumerIndex]);
120     }
121     composedIndices.push_back(std::move(reassociations));
122   }
123   return composedIndices;
124 }
125 
126 SmallVector<SmallVector<AffineExpr, 2>, 2>
127 mlir::convertReassociationIndicesToExprs(
128     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
129   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
130   for (const auto &indices : reassociationIndices) {
131     SmallVector<AffineExpr, 2> reassociationMap;
132     reassociationMap.reserve(indices.size());
133     for (int64_t index : indices)
134       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
135     reassociationMaps.push_back(std::move(reassociationMap));
136   }
137   return reassociationMaps;
138 }
139 
140 template <typename AffineExprTy>
141 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
142   unsigned pos = 0;
143   for (const auto &exprs : exprArrays) {
144     for (auto expr : exprs) {
145       expr.walk([&pos](AffineExpr e) {
146         if (auto d = e.dyn_cast<AffineExprTy>())
147           pos = std::max(pos, d.getPosition());
148       });
149     }
150   }
151   return pos;
152 }
153 
154 ArrayAttr mlir::getReassociationIndicesAttribute(
155     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
156   SmallVector<Attribute, 4> reassociationAttr =
157       llvm::to_vector<4>(llvm::map_range(
158           reassociation, [&](const ReassociationIndices &indices) -> Attribute {
159             return b.getI64ArrayAttr(indices).cast<Attribute>();
160           }));
161   return b.getArrayAttr(reassociationAttr);
162 }
163 
164 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
165     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
166   SmallVector<ReassociationIndices, 2> reassociationIndices;
167   for (const auto &exprs : reassociationExprs) {
168     ReassociationIndices indices;
169     indices.reserve(exprs.size());
170     for (const auto &expr : exprs)
171       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
172     reassociationIndices.push_back(indices);
173   }
174   return reassociationIndices;
175 }
176 
177 SmallVector<AffineMap, 4>
178 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
179   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
180   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
181          "Expected symbol-less expressions");
182   SmallVector<AffineMap, 4> maps;
183   maps.reserve(reassociation.size());
184   for (const auto &exprs : reassociation) {
185     assert(!exprs.empty());
186     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
187   }
188   return maps;
189 }
190 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
191                                 int *invalidIndex) {
192   if (reassociation.empty())
193     return true;
194   unsigned nDims = reassociation[0].getNumDims();
195   unsigned nextExpectedDim = 0;
196   for (const auto &it : llvm::enumerate(reassociation)) {
197     auto m = it.value();
198     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
199       if (invalidIndex)
200         *invalidIndex = it.index();
201       return false;
202     }
203     for (auto e : m.getResults()) {
204       auto d = e.dyn_cast<AffineDimExpr>();
205       if (!d || d.getPosition() != nextExpectedDim++) {
206         if (invalidIndex)
207           *invalidIndex = it.index();
208         return false;
209       }
210     }
211   }
212   if (nextExpectedDim != nDims) {
213     if (invalidIndex)
214       *invalidIndex = reassociation.size() - 1;
215     return false;
216   }
217   return true;
218 }
219 
220 LogicalResult mlir::reshapeLikeShapesAreCompatible(
221     function_ref<LogicalResult(const Twine &)> emitError,
222     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
223     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
224   unsigned expandedDimStart = 0;
225   for (const auto &map : llvm::enumerate(reassociationMaps)) {
226     Optional<int64_t> dynamicShape;
227     int64_t linearizedStaticShape = 1;
228     for (const auto &dim : llvm::enumerate(
229              expandedShape.slice(expandedDimStart, map.value().size()))) {
230       if (ShapedType::isDynamic(dim.value())) {
231         if (isExpandingReshape && dynamicShape) {
232           return emitError("invalid to have a single dimension (" +
233                            Twine(map.index()) +
234                            ") expanded into multiple dynamic dims (" +
235                            Twine(expandedDimStart + dynamicShape.getValue()) +
236                            "," + Twine(expandedDimStart + dim.index()) + ")");
237         }
238         dynamicShape = dim.index();
239       } else {
240         linearizedStaticShape *= dim.value();
241       }
242     }
243     if (dynamicShape) {
244       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
245         return emitError(
246             "expected dimension " + Twine(map.index()) +
247             " of collapsed type to be dynamic since one or more of the "
248             "corresponding dimensions in the expanded type is dynamic");
249       }
250     } else {
251       if (collapsedShape[map.index()] != linearizedStaticShape) {
252         return emitError("expected dimension " + Twine(map.index()) +
253                          " of collapsed type to be static value of " +
254                          Twine(linearizedStaticShape));
255       }
256     }
257     expandedDimStart += map.value().size();
258   }
259   return success();
260 }
261