1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
19 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
20                                         ShapedType targetType) {
21   // Make the sourceType greater rank than the targetType. If they are same
22   // rank, then its an unsupported reshape op.
23   if (sourceType.getRank() == targetType.getRank())
24     return llvm::None;
25   if (sourceType.getRank() < targetType.getRank())
26     std::swap(sourceType, targetType);
27 
28   ArrayRef<int64_t> sourceShape = sourceType.getShape();
29   ArrayRef<int64_t> targetShape = targetType.getShape();
30   unsigned sourceDim = 0;
31   SmallVector<ReassociationIndices> reassociationMap;
32   reassociationMap.reserve(targetType.getRank());
33 
34   ReassociationIndices currIndices;
35   int64_t prodOfCollapsedDims = 1;
36   while (sourceDim < sourceShape.size()) {
37     unsigned targetDim = reassociationMap.size();
38 
39     // If all the dimensions of the targetShape are exhausted, then the
40     // remaining dims in the source shape must be all 1s. So for such cases, set
41     // 1 as the target shape. The actual reassociation indices will be handled
42     // later.
43     int64_t currTargetShape =
44         (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
45     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
46            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
47            sourceDim < sourceShape.size()) {
48       prodOfCollapsedDims *= sourceShape[sourceDim];
49       currIndices.push_back(sourceDim++);
50     }
51 
52     // If the current expanded dimension is dynamic, then the collapsed
53     // dimensions should also be dynamic and product of all previous unprocessed
54     // dimensions of the expanded shape should be 1.
55     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
56         (currTargetShape != ShapedType::kDynamicSize ||
57          prodOfCollapsedDims != 1))
58       return llvm::None;
59 
60     // If the collapsed dim is dynamic, the current expanded dim should also
61     // be dynamic.
62     if (currTargetShape == ShapedType::kDynamicSize &&
63         sourceShape[sourceDim] != ShapedType::kDynamicSize)
64       return llvm::None;
65 
66     // For static shapes, if the product of dimensions of the expanded shape
67     // should match the collapsed dimension shape.
68     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
69       return llvm::None;
70 
71     currIndices.push_back(sourceDim++);
72     // If the reassociation is empty but the currIndices is not, this by
73     // definition is folding unit-dimensions with the result being scalar type.
74     // So only append the `currIndices` if reassociation map is not empty.
75     if (targetDim == targetShape.size()) {
76       while (sourceDim < sourceShape.size())
77         currIndices.push_back(sourceDim++);
78       if (!reassociationMap.empty() && !currIndices.empty())
79         reassociationMap.back().append(currIndices.begin(), currIndices.end());
80       // Break out of the loops. We should be done here.
81       break;
82     }
83     reassociationMap.emplace_back(ReassociationIndices{});
84     std::swap(reassociationMap.back(), currIndices);
85     prodOfCollapsedDims = 1;
86   }
87   // All the dimensions in the two shapes must have been processed.
88   if (reassociationMap.size() != targetShape.size() ||
89       sourceDim != sourceShape.size())
90     return llvm::None;
91   return reassociationMap;
92 }
93 
94 ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
95                                      OperationState &result) {
96   // Parse the operand.
97   OpAsmParser::OperandType src;
98   if (parser.parseOperand(src))
99     return failure();
100 
101   // Parse reassociation indices.
102   Builder &b = parser.getBuilder();
103   SmallVector<Attribute, 4> reassociation;
104   if (parser.parseLSquare())
105     return failure();
106 
107   while (true) {
108     if (succeeded(parser.parseOptionalRSquare()))
109       break;
110     if (parser.parseLSquare())
111       return failure();
112     SmallVector<int64_t> indices;
113     while (true) {
114       int64_t index;
115       if (parser.parseInteger(index))
116         return failure();
117       indices.push_back(index);
118 
119       if (succeeded(parser.parseOptionalComma()))
120         continue;
121       if (failed(parser.parseRSquare()))
122         return failure();
123       break;
124     }
125     reassociation.push_back(b.getI64ArrayAttr(indices));
126     if (succeeded(parser.parseOptionalComma()))
127       continue;
128     if (failed(parser.parseRSquare()))
129       return failure();
130     break;
131   }
132 
133   result.addAttribute(getReassociationAttrName(),
134                       b.getArrayAttr(reassociation));
135 
136   // Parse optional attributes.
137   parser.parseOptionalAttrDict(result.attributes);
138 
139   // Parse types.
140   Type srcType;
141   Type resultType;
142   if (parser.parseColon() || parser.parseType(srcType) ||
143       parser.resolveOperand(src, srcType, result.operands) ||
144       parser.parseKeyword("into") || parser.parseType(resultType))
145     return failure();
146   result.addTypes(resultType);
147   return success();
148 }
149 
150 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
151     ArrayRef<ReassociationIndices> producerReassociations,
152     ArrayRef<ReassociationIndices> consumerReassociations,
153     MLIRContext *context) {
154   SmallVector<ReassociationIndices> composedIndices;
155   // Make the producer the larger sized vector. If they are of same size, the
156   // resulting reshape is not a supported reshape op.
157   if (producerReassociations.size() == consumerReassociations.size())
158     return llvm::None;
159   if (producerReassociations.size() < consumerReassociations.size())
160     std::swap(producerReassociations, consumerReassociations);
161 
162   // Handle the corner case of the result being a rank 0 shaped type. Return an
163   // empty reassociation.
164   if (consumerReassociations.empty())
165     return composedIndices;
166 
167   size_t consumerDims = std::accumulate(
168       consumerReassociations.begin(), consumerReassociations.end(), 0,
169       [](size_t all, ReassociationIndicesRef indices) {
170         return all + indices.size();
171       });
172   if (producerReassociations.size() != consumerDims)
173     return llvm::None;
174 
175   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
176     ReassociationIndices reassociations;
177     for (int64_t consumerIndex : consumerIndices) {
178       for (int64_t producerIndex : producerReassociations[consumerIndex])
179         reassociations.push_back(producerIndex);
180     }
181     composedIndices.push_back(std::move(reassociations));
182   }
183   return composedIndices;
184 }
185 
186 SmallVector<SmallVector<AffineExpr, 2>, 2>
187 mlir::convertReassociationIndicesToExprs(
188     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
189   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
190   for (const auto &indices : reassociationIndices) {
191     SmallVector<AffineExpr, 2> reassociationMap;
192     reassociationMap.reserve(indices.size());
193     for (int64_t index : indices)
194       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
195     reassociationMaps.push_back(std::move(reassociationMap));
196   }
197   return reassociationMaps;
198 }
199 
200 template <typename AffineExprTy>
201 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
202   unsigned pos = 0;
203   for (const auto &exprs : exprArrays) {
204     for (auto expr : exprs) {
205       expr.walk([&pos](AffineExpr e) {
206         if (auto d = e.dyn_cast<AffineExprTy>())
207           pos = std::max(pos, d.getPosition());
208       });
209     }
210   }
211   return pos;
212 }
213 
214 ArrayAttr mlir::getReassociationIndicesAttribute(
215     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
216   SmallVector<Attribute, 4> reassociationAttr =
217       llvm::to_vector<4>(llvm::map_range(
218           reassociation, [&](ReassociationIndices indices) -> Attribute {
219             return b.getI64ArrayAttr(indices).cast<Attribute>();
220           }));
221   return b.getArrayAttr(reassociationAttr);
222 }
223 
224 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
225     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
226   SmallVector<ReassociationIndices, 2> reassociationIndices;
227   for (const auto &exprs : reassociationExprs) {
228     ReassociationIndices indices;
229     indices.reserve(exprs.size());
230     for (const auto &expr : exprs)
231       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
232     reassociationIndices.push_back(indices);
233   }
234   return reassociationIndices;
235 }
236 
237 SmallVector<AffineMap, 4>
238 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
239   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
240   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
241          "Expected symbol-less expressions");
242   SmallVector<AffineMap, 4> maps;
243   maps.reserve(reassociation.size());
244   for (const auto &exprs : reassociation) {
245     assert(!exprs.empty());
246     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
247   }
248   return maps;
249 }
250 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
251                                 int *invalidIndex) {
252   if (reassociation.empty())
253     return true;
254   unsigned nDims = reassociation[0].getNumDims();
255   unsigned nextExpectedDim = 0;
256   for (auto it : llvm::enumerate(reassociation)) {
257     auto m = it.value();
258     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
259       if (invalidIndex)
260         *invalidIndex = it.index();
261       return false;
262     }
263     for (auto e : m.getResults()) {
264       auto d = e.dyn_cast<AffineDimExpr>();
265       if (!d || d.getPosition() != nextExpectedDim++) {
266         if (invalidIndex)
267           *invalidIndex = it.index();
268         return false;
269       }
270     }
271   }
272   if (nextExpectedDim != nDims) {
273     if (invalidIndex)
274       *invalidIndex = reassociation.size() - 1;
275     return false;
276   }
277   return true;
278 }
279