1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType,ShapedType targetType)19 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
20                                         ShapedType targetType) {
21   if (sourceType.getRank() > targetType.getRank())
22     return getReassociationIndicesForCollapse(sourceType.getShape(),
23                                               targetType.getShape());
24   if (sourceType.getRank() < targetType.getRank())
25     return getReassociationIndicesForCollapse(targetType.getShape(),
26                                               sourceType.getShape());
27   return llvm::None;
28 }
29 
30 Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,ArrayRef<int64_t> targetShape)31 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
32                                          ArrayRef<int64_t> targetShape) {
33   if (sourceShape.size() <= targetShape.size())
34     return llvm::None;
35   unsigned sourceDim = 0;
36   SmallVector<ReassociationIndices> reassociationMap;
37   reassociationMap.reserve(targetShape.size());
38 
39   ReassociationIndices currIndices;
40   int64_t prodOfCollapsedDims = 1;
41   while (sourceDim < sourceShape.size()) {
42     unsigned targetDim = reassociationMap.size();
43     // If we have mapped all the target dimensions stop and handle the remaining
44     // tail of size-1 dimensions explictly.
45     if (targetDim == targetShape.size())
46       break;
47 
48     int64_t currTargetShape = targetShape[targetDim];
49     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
50            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
51            sourceDim < sourceShape.size()) {
52       prodOfCollapsedDims *= sourceShape[sourceDim];
53       currIndices.push_back(sourceDim++);
54     }
55 
56     // If the current expanded dimension is dynamic, then the collapsed
57     // dimensions should also be dynamic and product of all previous unprocessed
58     // dimensions of the expanded shape should be 1.
59     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
60         (currTargetShape != ShapedType::kDynamicSize ||
61          prodOfCollapsedDims != 1))
62       return llvm::None;
63 
64     // If the collapsed dim is dynamic, the current expanded dim should also
65     // be dynamic.
66     if (currTargetShape == ShapedType::kDynamicSize &&
67         sourceShape[sourceDim] != ShapedType::kDynamicSize)
68       return llvm::None;
69 
70     // For static shapes, if the product of dimensions of the expanded shape
71     // should match the collapsed dimension shape.
72     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73       return llvm::None;
74 
75     currIndices.push_back(sourceDim++);
76     reassociationMap.emplace_back(ReassociationIndices{});
77     std::swap(reassociationMap.back(), currIndices);
78     prodOfCollapsedDims = 1;
79   }
80   // All the dimensions in the target must have been processed.
81   if (reassociationMap.size() != targetShape.size())
82     return llvm::None;
83   // Process any remaining entries in the source shape. They all need to be
84   // 1 or dynamic.
85   for (; sourceDim < sourceShape.size(); sourceDim++) {
86     if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
87         sourceShape[sourceDim] != 1)
88       return llvm::None;
89     // The map is empty when the target type is a scalar.
90     if (!reassociationMap.empty())
91       reassociationMap.back().push_back(sourceDim);
92   }
93   return reassociationMap;
94 }
95 
composeReassociationIndices(ArrayRef<ReassociationIndices> producerReassociations,ArrayRef<ReassociationIndices> consumerReassociations,MLIRContext * context)96 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
97     ArrayRef<ReassociationIndices> producerReassociations,
98     ArrayRef<ReassociationIndices> consumerReassociations,
99     MLIRContext *context) {
100   SmallVector<ReassociationIndices> composedIndices;
101   // Make the producer the larger sized vector. If they are of same size, the
102   // resulting reshape is not a supported reshape op.
103   if (producerReassociations.size() == consumerReassociations.size())
104     return llvm::None;
105   if (producerReassociations.size() < consumerReassociations.size())
106     std::swap(producerReassociations, consumerReassociations);
107 
108   // Handle the corner case of the result being a rank 0 shaped type. Return an
109   // empty reassociation.
110   if (consumerReassociations.empty())
111     return composedIndices;
112 
113   size_t consumerDims = std::accumulate(
114       consumerReassociations.begin(), consumerReassociations.end(), 0,
115       [](size_t all, ReassociationIndicesRef indices) {
116         return all + indices.size();
117       });
118   if (producerReassociations.size() != consumerDims)
119     return llvm::None;
120 
121   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
122     ReassociationIndices reassociations;
123     for (int64_t consumerIndex : consumerIndices) {
124       llvm::append_range(reassociations, producerReassociations[consumerIndex]);
125     }
126     composedIndices.push_back(std::move(reassociations));
127   }
128   return composedIndices;
129 }
130 
131 SmallVector<SmallVector<AffineExpr, 2>, 2>
convertReassociationIndicesToExprs(MLIRContext * context,ArrayRef<ReassociationIndices> reassociationIndices)132 mlir::convertReassociationIndicesToExprs(
133     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
134   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
135   for (const auto &indices : reassociationIndices) {
136     SmallVector<AffineExpr, 2> reassociationMap;
137     reassociationMap.reserve(indices.size());
138     for (int64_t index : indices)
139       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
140     reassociationMaps.push_back(std::move(reassociationMap));
141   }
142   return reassociationMaps;
143 }
144 
145 template <typename AffineExprTy>
getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays)146 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
147   unsigned pos = 0;
148   for (const auto &exprs : exprArrays) {
149     for (auto expr : exprs) {
150       expr.walk([&pos](AffineExpr e) {
151         if (auto d = e.dyn_cast<AffineExprTy>())
152           pos = std::max(pos, d.getPosition());
153       });
154     }
155   }
156   return pos;
157 }
158 
getReassociationIndicesAttribute(OpBuilder & b,ArrayRef<ReassociationIndices> reassociation)159 ArrayAttr mlir::getReassociationIndicesAttribute(
160     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
161   SmallVector<Attribute, 4> reassociationAttr =
162       llvm::to_vector<4>(llvm::map_range(
163           reassociation, [&](const ReassociationIndices &indices) -> Attribute {
164             return b.getI64ArrayAttr(indices).cast<Attribute>();
165           }));
166   return b.getArrayAttr(reassociationAttr);
167 }
168 
convertReassociationMapsToIndices(OpBuilder & b,ArrayRef<ReassociationExprs> reassociationExprs)169 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
170     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
171   SmallVector<ReassociationIndices, 2> reassociationIndices;
172   for (const auto &exprs : reassociationExprs) {
173     ReassociationIndices indices;
174     indices.reserve(exprs.size());
175     for (const auto &expr : exprs)
176       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
177     reassociationIndices.push_back(indices);
178   }
179   return reassociationIndices;
180 }
181 
182 SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation)183 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
184   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
185   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
186          "Expected symbol-less expressions");
187   SmallVector<AffineMap, 4> maps;
188   maps.reserve(reassociation.size());
189   for (const auto &exprs : reassociation) {
190     assert(!exprs.empty());
191     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
192   }
193   return maps;
194 }
195 
isReassociationValid(ArrayRef<AffineMap> reassociation,int * invalidIndex)196 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
197                                 int *invalidIndex) {
198   if (reassociation.empty())
199     return true;
200   unsigned nDims = reassociation[0].getNumDims();
201   unsigned nextExpectedDim = 0;
202   for (const auto &it : llvm::enumerate(reassociation)) {
203     auto m = it.value();
204     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
205       if (invalidIndex)
206         *invalidIndex = it.index();
207       return false;
208     }
209     for (auto e : m.getResults()) {
210       auto d = e.dyn_cast<AffineDimExpr>();
211       if (!d || d.getPosition() != nextExpectedDim++) {
212         if (invalidIndex)
213           *invalidIndex = it.index();
214         return false;
215       }
216     }
217   }
218   if (nextExpectedDim != nDims) {
219     if (invalidIndex)
220       *invalidIndex = reassociation.size() - 1;
221     return false;
222   }
223   return true;
224 }
225 
reshapeLikeShapesAreCompatible(function_ref<LogicalResult (const Twine &)> emitError,ArrayRef<int64_t> collapsedShape,ArrayRef<int64_t> expandedShape,ArrayRef<ReassociationIndices> reassociationMaps,bool isExpandingReshape)226 LogicalResult mlir::reshapeLikeShapesAreCompatible(
227     function_ref<LogicalResult(const Twine &)> emitError,
228     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
229     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
230   unsigned expandedDimStart = 0;
231   for (const auto &map : llvm::enumerate(reassociationMaps)) {
232     Optional<int64_t> dynamicShape;
233     int64_t linearizedStaticShape = 1;
234     for (const auto &dim : llvm::enumerate(
235              expandedShape.slice(expandedDimStart, map.value().size()))) {
236       if (ShapedType::isDynamic(dim.value())) {
237         if (isExpandingReshape && dynamicShape) {
238           return emitError("invalid to have a single dimension (" +
239                            Twine(map.index()) +
240                            ") expanded into multiple dynamic dims (" +
241                            Twine(expandedDimStart + dynamicShape.value()) +
242                            "," + Twine(expandedDimStart + dim.index()) + ")");
243         }
244         dynamicShape = dim.index();
245       } else {
246         linearizedStaticShape *= dim.value();
247       }
248     }
249     if (dynamicShape) {
250       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
251         return emitError(
252             "expected dimension " + Twine(map.index()) +
253             " of collapsed type to be dynamic since one or more of the "
254             "corresponding dimensions in the expanded type is dynamic");
255       }
256     } else {
257       if (collapsedShape[map.index()] != linearizedStaticShape) {
258         return emitError("expected dimension " + Twine(map.index()) +
259                          " of collapsed type to be static value of " +
260                          Twine(linearizedStaticShape));
261       }
262     }
263     expandedDimStart += map.value().size();
264   }
265   return success();
266 }
267 
hasNonIdentityLayout(Type type)268 bool mlir::hasNonIdentityLayout(Type type) {
269   if (auto memrefType = type.dyn_cast<MemRefType>())
270     return !memrefType.getLayout().isIdentity();
271   return false;
272 }
273