1 //===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities and common canonicalization patterns for
10 // reshape operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
16 
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/StringRef.h"
21 
22 namespace mlir {
23 
/// A group of dimension indices that together form one reassociation group,
/// i.e. the dimensions on the expanded side that correspond to a single
/// dimension on the collapsed side.
using ReassociationIndices = SmallVector<int64_t, 2>;
/// Non-owning view of a single reassociation group.
using ReassociationIndicesRef = ArrayRef<int64_t>;
/// A reassociation group expressed as affine expressions instead of indices.
using ReassociationExprs = SmallVector<AffineExpr, 2>;

/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
30 
/// Compose reassociation maps that are used in a pair of reshape ops where one
/// is the producer and the other is the consumer. Only valid to use this
/// method when both the producer and consumer are collapsing dimensions or
/// both are expanding dimensions. Returns llvm::None when the maps cannot be
/// composed.
///
/// For example,
///   producerReassociation = [[0, 1], [2], [3, 4]]
///   consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
///   result = [[0, 1, 2], [3, 4]].
Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context);

/// Convert reassociation indices to affine expressions.
SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);

/// Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);

/// Wraps a list of reassociations in an ArrayAttr.
ArrayAttr
getReassociationIndicesAttribute(OpBuilder &b,
                                 ArrayRef<ReassociationIndices> reassociation);

/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>, i.e. the inverse
/// direction of `convertReassociationIndicesToExprs`.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);

/// Return the reassociation maps to use to reshape the given source type into
/// the target type, when possible. Returns llvm::None when no such
/// reassociation could be computed.
Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);

/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
/// possible, llvm::None otherwise.
Optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape);

/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map (it is left untouched when null).
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
                          int *invalidIndex = nullptr);
82 
83 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
foldReshapeOp(ReshapeOpTy reshapeOp,ArrayRef<Attribute> operands)84 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
85                                   ArrayRef<Attribute> operands) {
86   // Fold producer-consumer reshape ops that where the operand type of the
87   // producer is same as the return type of the consumer.
88   auto reshapeSrcOp =
89       reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
90   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
91     return reshapeSrcOp.getSrc();
92   // Reshape of a constant can be replaced with a new constant.
93   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
94     return elements.reshape(
95         reshapeOp.getResult().getType().template cast<ShapedType>());
96   }
97   return nullptr;
98 }
99 
100 /// Common verifier for reshape-like types. Fills `expandedType` and
101 ///`collapsedType` with the proper `src` or `result` type.
102 template <typename Op, typename T>
verifyReshapeLikeTypes(Op op,T expandedType,T collapsedType,bool isExpansion)103 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
104                                             T collapsedType, bool isExpansion) {
105   unsigned expandedRank = expandedType.getRank();
106   unsigned collapsedRank = collapsedType.getRank();
107   if (expandedRank < collapsedRank)
108     return op.emitOpError("expected the type ")
109            << expandedType
110            << " to have higher rank than the type = " << collapsedType;
111   if (expandedRank == 0)
112     return op.emitOpError("expected non-zero memref ranks");
113   if (expandedRank == collapsedRank)
114     return op.emitOpError("expected to collapse or expand dims");
115 
116   if (collapsedRank == 0) {
117     // If collapsed rank is 0, then expanded type must be static shaped and of
118     // sizes 1.
119     if (llvm::any_of(expandedType.getShape(),
120                      [](int64_t dim) -> bool { return dim != 1; }))
121       return op.emitOpError("invalid to reshape tensor/memref with non-unit "
122                             "extent dimensions to zero-rank tensor/memref");
123     return success();
124   }
125   if (collapsedRank != op.getReassociation().size())
126     return op.emitOpError("expected rank of the collapsed type(")
127            << collapsedRank << ") to be the number of reassociation maps("
128            << op.getReassociation().size() << ")";
129   auto maps = op.getReassociationMaps();
130   for (auto it : llvm::enumerate(maps))
131     if (it.value().getNumDims() != expandedRank)
132       return op.emitOpError("expected reassociation map #")
133              << it.index() << " of same rank as expanded memref("
134              << expandedRank << "), but got " << it.value().getNumDims();
135   int invalidIdx = 0;
136   if (!isReassociationValid(maps, &invalidIdx))
137     return op.emitOpError("expected reassociation map #")
138            << invalidIdx << " to be valid and contiguous";
139   return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
140 }
141 
/// Verify that the shapes of the reshaped types satisfy the following rules:
/// 1) if a dimension in the collapsed type is static, then the corresponding
///    dimensions in the expanded shape should
///    a) all be static, and
///    b) have a product equal to the collapsed dimension's size.
/// 2) if a dimension in the collapsed type is dynamic, one and only one of the
///    corresponding dimensions in the expanded type should be dynamic. This
///    rule is only needed with reshape operations that are expanding.
///
/// `reassociationMaps[i]` lists the expanded dimensions that correspond to
/// collapsed dimension `i`. Diagnostics are emitted through `emitError`.
LogicalResult reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
154 
155 template <typename OpTy>
verifyReshapeLikeShapes(OpTy op,ShapedType collapsedType,ShapedType expandedType,bool isExpandingReshape)156 static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
157                                              ShapedType expandedType,
158                                              bool isExpandingReshape) {
159   return reshapeLikeShapesAreCompatible(
160       [&](const Twine &msg) { return op->emitOpError(msg); },
161       collapsedType.getShape(), expandedType.getShape(),
162       op.getReassociationIndices(), isExpandingReshape);
163 }
164 
/// Returns true iff the type is a MemRefType and has a non-identity layout.
/// Any other type (e.g. a tensor type) yields false. The composition patterns
/// in this file use it to bail out when a layout is involved.
bool hasNonIdentityLayout(Type type);
167 
168 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
169 /// dimensions or are both expanding dimensions.
170 template <typename ReshapeOpTy>
171 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
172   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
matchAndRewriteComposeReassociativeReshapeOps173   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
174                                 PatternRewriter &rewriter) const override {
175     auto srcReshapeOp =
176         reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
177     if (!srcReshapeOp)
178       return failure();
179 
180     ShapedType resultType = reshapeOp.getResultType();
181 
182     if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
183         hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
184         hasNonIdentityLayout(reshapeOp.getResult().getType()))
185       return failure();
186 
187     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
188         composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
189                                     reshapeOp.getReassociationIndices(),
190                                     rewriter.getContext());
191     if (!reassociationIndices)
192       return failure();
193     rewriter.replaceOpWithNewOp<ReshapeOpTy>(
194         reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
195     return success();
196   }
197 };
198 
199 /// Pattern to compose
200 /// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
201 /// In that case both `srcType` and `resultType` can be expressed as a function
202 /// of `intermediateType`.
203 /// In order to demonstrate the approach, let's assume that `rank(srcType) >
204 /// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
205 /// In that case, we can iterate over every set of indices in `reassociation_2`
206 /// and try to find ids of sets of indices in `reassociation_1` that cover it
207 /// completely.
208 ///
209 /// Example:
210 ///
211 ///   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
212 ///     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
213 ///   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
214 ///     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
215 ///
216 /// can be canonicalized into
217 ///
218 ///   %0 = tensor.collapse_shape %arg [[0, 1], [2]]
219 ///     : tensor<?x?x?xi64> into tensor<?x?xi64>
220 ///
221 /// because [0] and [1] from `expand_shape` reassociation cover completely
222 /// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
223 /// indices, then we fail.
224 //
225 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
226 /// `reassociation_2` and produce `expand_shape`.
227 template <typename CollapseOpTy, typename ExpandOpTy>
228 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
229   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
matchAndRewriteComposeCollapseOfExpandOp230   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
231                                 PatternRewriter &rewriter) const override {
232     auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
233     if (!expandOp)
234       return failure();
235 
236     ShapedType srcType = expandOp.getSrcType();
237     ShapedType resultType = collapseOp.getResultType();
238 
239     if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
240         hasNonIdentityLayout(expandOp.getSrc().getType()) ||
241         hasNonIdentityLayout(expandOp.getResult().getType()))
242       return failure();
243 
244     int64_t srcRank = srcType.getRank();
245     int64_t resultRank = resultType.getRank();
246     if (srcType == resultType)
247       return failure();
248 
249     SmallVector<ReassociationIndices, 4> higherRankReassociation,
250         lowerRankReassociation;
251 
252     bool isResultCollapsed = srcRank > resultRank;
253     if (isResultCollapsed) {
254       higherRankReassociation = expandOp.getReassociationIndices();
255       lowerRankReassociation = collapseOp.getReassociationIndices();
256     } else {
257       higherRankReassociation = collapseOp.getReassociationIndices();
258       lowerRankReassociation = expandOp.getReassociationIndices();
259     }
260 
261     size_t higherRankIndicesID = 0;
262     SmallVector<ReassociationIndices, 4> composedReassociation;
263     for (const auto &lowerRankIndices : lowerRankReassociation) {
264       ReassociationIndices composedIndices;
265       while (higherRankIndicesID < higherRankReassociation.size()) {
266         auto rightmostIndex =
267             higherRankReassociation[higherRankIndicesID].back();
268         if (rightmostIndex > lowerRankIndices.back())
269           return failure();
270         composedIndices.push_back(higherRankIndicesID++);
271         if (rightmostIndex == lowerRankIndices.back())
272           break;
273       }
274       composedReassociation.push_back(composedIndices);
275     }
276     if (isResultCollapsed)
277       rewriter.replaceOpWithNewOp<CollapseOpTy>(
278           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
279     else
280       rewriter.replaceOpWithNewOp<ExpandOpTy>(
281           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
282     return success();
283   }
284 };
285 
286 template <typename ExpandOpTy, typename CollapseOpTy>
287 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
288   using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
matchAndRewriteComposeExpandOfCollapseOp289   LogicalResult matchAndRewrite(ExpandOpTy expandOp,
290                                 PatternRewriter &rewriter) const override {
291     auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
292     if (!collapseOp)
293       return failure();
294 
295     ShapedType srcType = collapseOp.getSrcType();
296     ShapedType resultType = expandOp.getResultType();
297 
298     if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
299         hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
300         hasNonIdentityLayout(collapseOp.getResult().getType()))
301       return failure();
302 
303     int64_t srcRank = srcType.getRank();
304     int64_t resultRank = resultType.getRank();
305     if (srcType == resultType)
306       return failure();
307 
308     auto srcReassociation = collapseOp.getReassociationIndices();
309     auto resultReassociation = expandOp.getReassociationIndices();
310     if (srcRank > resultRank) {
311       auto composedReassociation = findCollapsingReassociation(
312           srcReassociation, resultReassociation, srcType.getShape(),
313           resultType.getShape());
314       if (!composedReassociation)
315         return failure();
316 
317       rewriter.replaceOpWithNewOp<CollapseOpTy>(
318           expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
319       return success();
320     }
321     auto composedReassociation =
322         findCollapsingReassociation(resultReassociation, srcReassociation,
323                                     resultType.getShape(), srcType.getShape());
324     if (!composedReassociation)
325       return failure();
326 
327     rewriter.replaceOpWithNewOp<ExpandOpTy>(
328         expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
329     return success();
330   }
331 
332 private:
333   // Attempts to find a way to collapse `srcShape` to `resultShape` by
334   // collapsing subshapes defined by the reassociation indices.
findCollapsingReassociationComposeExpandOfCollapseOp335   Optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
336       ArrayRef<ReassociationIndices> srcReassociation,
337       ArrayRef<ReassociationIndices> resultReassociation,
338       ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
339     SmallVector<ReassociationIndices, 4> composedReassociation;
340 
341     if (srcReassociation.empty())
342       return {getReassociationIndicesForCollapse(srcShape, resultShape)};
343 
344     for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
345       auto &srcIndices = std::get<0>(item);
346       auto &resultIndices = std::get<1>(item);
347       auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
348       auto resultSubShape =
349           resultShape.slice(resultIndices.front(), resultIndices.size());
350 
351       if (srcSubShape.size() == resultSubShape.size()) {
352         if (srcSubShape == resultSubShape)
353           composedReassociation.push_back(srcIndices);
354         else
355           return llvm::None;
356       }
357 
358       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
359       auto subShapeReassociation =
360           getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
361       if (!subShapeReassociation)
362         return llvm::None;
363 
364       // Remap the subshape indices back to the original srcShape.
365       for (auto &subshape_indices : *subShapeReassociation) {
366         ReassociationIndices shape_indices;
367         for (int64_t index : subshape_indices)
368           shape_indices.push_back(srcIndices.front() + index);
369         composedReassociation.push_back(shape_indices);
370       }
371     }
372     return {std::move(composedReassociation)};
373   }
374 };
375 
376 } // namespace mlir
377 
378 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
379