1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
19 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
20                                         ShapedType targetType) {
21   // Make the sourceType greater rank than the targetType. If they are same
22   // rank, then its an unsupported reshape op.
23   if (sourceType.getRank() == targetType.getRank())
24     return llvm::None;
25   if (sourceType.getRank() < targetType.getRank())
26     std::swap(sourceType, targetType);
27 
28   ArrayRef<int64_t> sourceShape = sourceType.getShape();
29   ArrayRef<int64_t> targetShape = targetType.getShape();
30   unsigned sourceDim = 0;
31   SmallVector<ReassociationIndices> reassociationMap;
32   reassociationMap.reserve(targetType.getRank());
33 
34   ReassociationIndices currIndices;
35   int64_t prodOfCollapsedDims = 1;
36   while (sourceDim < sourceShape.size()) {
37     unsigned targetDim = reassociationMap.size();
38 
39     // If all the dimensions of the targetShape are exhausted, then the
40     // remaining dims in the source shape must be all 1s. So for such cases, set
41     // 1 as the target shape. The actual reassociation indices will be handled
42     // later.
43     int64_t currTargetShape =
44         (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
45     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
46            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
47            sourceDim < sourceShape.size()) {
48       prodOfCollapsedDims *= sourceShape[sourceDim];
49       currIndices.push_back(sourceDim++);
50     }
51 
52     // If the current expanded dimension is dynamic, then the collapsed
53     // dimensions should also be dynamic and product of all previous unprocessed
54     // dimensions of the expanded shape should be 1.
55     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
56         (currTargetShape != ShapedType::kDynamicSize ||
57          prodOfCollapsedDims != 1))
58       return llvm::None;
59 
60     // If the collapsed dim is dynamic, the current expanded dim should also
61     // be dynamic.
62     if (currTargetShape == ShapedType::kDynamicSize &&
63         sourceShape[sourceDim] != ShapedType::kDynamicSize)
64       return llvm::None;
65 
66     // For static shapes, if the product of dimensions of the expanded shape
67     // should match the collapsed dimension shape.
68     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
69       return llvm::None;
70 
71     currIndices.push_back(sourceDim++);
72     // If the reassociation is empty but the currIndices is not, this by
73     // definition is folding unit-dimensions with the result being scalar type.
74     // So only append the `currIndices` if reassociation map is not empty.
75     if (targetDim == targetShape.size()) {
76       if (!reassociationMap.empty() && !currIndices.empty())
77         reassociationMap.back().append(currIndices.begin(), currIndices.end());
78       // Break out of the loops. We should be done here.
79       break;
80     }
81     reassociationMap.emplace_back(ReassociationIndices{});
82     std::swap(reassociationMap.back(), currIndices);
83     prodOfCollapsedDims = 1;
84   }
85   // All the dimensions in the two shapes must have been processed.
86   if (reassociationMap.size() != targetShape.size() ||
87       sourceDim != sourceShape.size())
88     return llvm::None;
89   return reassociationMap;
90 }
91 
92 ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
93                                      OperationState &result) {
94   // Parse the operand.
95   OpAsmParser::OperandType src;
96   if (parser.parseOperand(src))
97     return failure();
98 
99   // Parse reassociation indices.
100   Builder &b = parser.getBuilder();
101   SmallVector<Attribute, 4> reassociation;
102   if (parser.parseLSquare())
103     return failure();
104 
105   while (true) {
106     if (succeeded(parser.parseOptionalRSquare()))
107       break;
108     if (parser.parseLSquare())
109       return failure();
110     SmallVector<int64_t> indices;
111     while (true) {
112       int64_t index;
113       if (parser.parseInteger(index))
114         return failure();
115       indices.push_back(index);
116 
117       if (succeeded(parser.parseOptionalComma()))
118         continue;
119       if (failed(parser.parseRSquare()))
120         return failure();
121       break;
122     }
123     reassociation.push_back(b.getI64ArrayAttr(indices));
124     if (succeeded(parser.parseOptionalComma()))
125       continue;
126     if (failed(parser.parseRSquare()))
127       return failure();
128     break;
129   }
130 
131   result.addAttribute(getReassociationAttrName(),
132                       b.getArrayAttr(reassociation));
133 
134   // Parse optional attributes.
135   parser.parseOptionalAttrDict(result.attributes);
136 
137   // Parse types.
138   Type srcType;
139   Type resultType;
140   if (parser.parseColon() || parser.parseType(srcType) ||
141       parser.resolveOperand(src, srcType, result.operands) ||
142       parser.parseKeyword("into") || parser.parseType(resultType))
143     return failure();
144   result.addTypes(resultType);
145   return success();
146 }
147 
148 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
149     ArrayRef<ReassociationIndices> producerReassociations,
150     ArrayRef<ReassociationIndices> consumerReassociations,
151     MLIRContext *context) {
152   SmallVector<ReassociationIndices> composedIndices;
153   // Make the producer the larger sized vector. If they are of same size, the
154   // resulting reshape is not a supported reshape op.
155   if (producerReassociations.size() == consumerReassociations.size())
156     return llvm::None;
157   if (producerReassociations.size() < consumerReassociations.size())
158     std::swap(producerReassociations, consumerReassociations);
159 
160   // Handle the corner case of the result being a rank 0 shaped type. Return an
161   // empty reassociation.
162   if (consumerReassociations.empty())
163     return composedIndices;
164 
165   size_t consumerDims = std::accumulate(
166       consumerReassociations.begin(), consumerReassociations.end(), 0,
167       [](size_t all, ReassociationIndicesRef indices) {
168         return all + indices.size();
169       });
170   if (producerReassociations.size() != consumerDims)
171     return llvm::None;
172 
173   for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
174     ReassociationIndices reassociations;
175     for (int64_t consumerIndex : consumerIndices) {
176       for (int64_t producerIndex : producerReassociations[consumerIndex])
177         reassociations.push_back(producerIndex);
178     }
179     composedIndices.push_back(std::move(reassociations));
180   }
181   return composedIndices;
182 }
183 
184 SmallVector<SmallVector<AffineExpr, 2>, 2>
185 mlir::convertReassociationIndicesToExprs(
186     MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
187   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
188   for (const auto &indices : reassociationIndices) {
189     SmallVector<AffineExpr, 2> reassociationMap;
190     reassociationMap.reserve(indices.size());
191     for (int64_t index : indices)
192       reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
193     reassociationMaps.push_back(std::move(reassociationMap));
194   }
195   return reassociationMaps;
196 }
197 
198 template <typename AffineExprTy>
199 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
200   unsigned pos = 0;
201   for (const auto &exprs : exprArrays) {
202     for (auto expr : exprs) {
203       expr.walk([&pos](AffineExpr e) {
204         if (auto d = e.dyn_cast<AffineExprTy>())
205           pos = std::max(pos, d.getPosition());
206       });
207     }
208   }
209   return pos;
210 }
211 
212 ArrayAttr mlir::getReassociationIndicesAttribute(
213     OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
214   SmallVector<Attribute, 4> reassociationAttr =
215       llvm::to_vector<4>(llvm::map_range(
216           reassociation, [&](ReassociationIndices indices) -> Attribute {
217             return b.getI64ArrayAttr(indices).cast<Attribute>();
218           }));
219   return b.getArrayAttr(reassociationAttr);
220 }
221 
222 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
223     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
224   SmallVector<ReassociationIndices, 2> reassociationIndices;
225   for (const auto &exprs : reassociationExprs) {
226     ReassociationIndices indices;
227     indices.reserve(exprs.size());
228     for (const auto &expr : exprs)
229       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
230     reassociationIndices.push_back(indices);
231   }
232   return reassociationIndices;
233 }
234 
235 SmallVector<AffineMap, 4>
236 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
237   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
238   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
239          "Expected symbol-less expressions");
240   SmallVector<AffineMap, 4> maps;
241   maps.reserve(reassociation.size());
242   for (const auto &exprs : reassociation) {
243     assert(!exprs.empty());
244     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
245   }
246   return maps;
247 }
248 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
249                                 int *invalidIndex) {
250   if (reassociation.empty())
251     return true;
252   unsigned nDims = reassociation[0].getNumDims();
253   unsigned nextExpectedDim = 0;
254   for (auto it : llvm::enumerate(reassociation)) {
255     auto m = it.value();
256     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
257       if (invalidIndex)
258         *invalidIndex = it.index();
259       return false;
260     }
261     for (auto e : m.getResults()) {
262       auto d = e.dyn_cast<AffineDimExpr>();
263       if (!d || d.getPosition() != nextExpectedDim++) {
264         if (invalidIndex)
265           *invalidIndex = it.index();
266         return false;
267       }
268     }
269   }
270   if (nextExpectedDim != nDims) {
271     if (invalidIndex)
272       *invalidIndex = reassociation.size() - 1;
273     return false;
274   }
275   return true;
276 }
277