1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Interfaces/VectorInterfaces.h"
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute according to given
24 /// indices.
25 static ArrayAttr
26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27                       const SmallVector<unsigned> &permutation) {
28   SmallVector<bool> newInBoundsValues;
29   for (unsigned pos : permutation)
30     newInBoundsValues.push_back(
31         attr.getValue()[pos].cast<BoolAttr>().getValue());
32   return builder.getBoolArrayAttr(newInBoundsValues);
33 }
34 
35 /// Lower transfer_read op with permutation into a transfer_read with a
36 /// permutation map composed of leading zeros followed by a minor identiy +
37 /// vector.transpose op.
38 /// Ex:
39 ///     vector.transfer_read ...
40 ///         permutation_map: (d0, d1, d2) -> (0, d1)
41 /// into:
42 ///     %v = vector.transfer_read ...
43 ///         permutation_map: (d0, d1, d2) -> (d1, 0)
44 ///     vector.transpose %v, [1, 0]
45 ///
46 ///     vector.transfer_read ...
47 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
48 /// into:
49 ///     %v = vector.transfer_read ...
50 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
51 ///     vector.transpose %v, [0, 1, 3, 2, 4]
52 /// Note that an alternative is to transform it to linalg.transpose +
53 /// vector.transfer_read to do the transpose in memory instead.
54 struct TransferReadPermutationLowering
55     : public OpRewritePattern<vector::TransferReadOp> {
56   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
57 
58   LogicalResult matchAndRewrite(vector::TransferReadOp op,
59                                 PatternRewriter &rewriter) const override {
60     // TODO: support 0-d corner case.
61     if (op.getTransferRank() == 0)
62       return failure();
63 
64     SmallVector<unsigned> permutation;
65     AffineMap map = op.getPermutationMap();
66     if (map.getNumResults() == 0)
67       return failure();
68     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
69       return failure();
70     AffineMap permutationMap =
71         map.getPermutationMap(permutation, op.getContext());
72     if (permutationMap.isIdentity())
73       return failure();
74 
75     permutationMap = map.getPermutationMap(permutation, op.getContext());
76     // Caluclate the map of the new read by applying the inverse permutation.
77     permutationMap = inversePermutation(permutationMap);
78     AffineMap newMap = permutationMap.compose(map);
79     // Apply the reverse transpose to deduce the type of the transfer_read.
80     ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
81     SmallVector<int64_t> newVectorShape(originalShape.size());
82     for (const auto &pos : llvm::enumerate(permutation)) {
83       newVectorShape[pos.value()] = originalShape[pos.index()];
84     }
85 
86     // Transpose mask operand.
87     Value newMask;
88     if (op.getMask()) {
89       // Remove unused dims from the permutation map. E.g.:
90       // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91       // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92       auto comp = compressUnusedDims(map);
93       // Get positions of remaining result dims.
94       // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95       // maskTransposeIndices = [ 2,     1,    0]
96       SmallVector<int64_t> maskTransposeIndices;
97       for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98         if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99           maskTransposeIndices.push_back(expr.getPosition());
100       }
101 
102       newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103                                                      maskTransposeIndices);
104     }
105 
106     // Transpose in_bounds attribute.
107     ArrayAttr newInBoundsAttr =
108         op.getInBounds() ? transposeInBoundsAttr(
109                                rewriter, op.getInBounds().value(), permutation)
110                          : ArrayAttr();
111 
112     // Generate new transfer_read operation.
113     VectorType newReadType =
114         VectorType::get(newVectorShape, op.getVectorType().getElementType());
115     Value newRead = rewriter.create<vector::TransferReadOp>(
116         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
117         AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
118 
119     // Transpose result of transfer_read.
120     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
121     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
122                                                      transposePerm);
123     return success();
124   }
125 };
126 
127 /// Lower transfer_write op with permutation into a transfer_write with a
128 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
129 /// Ex:
130 ///     vector.transfer_write %v ...
131 ///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
132 /// into:
133 ///     %tmp = vector.transpose %v, [2, 0, 1]
134 ///     vector.transfer_write %tmp ...
135 ///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
136 ///
137 ///     vector.transfer_write %v ...
138 ///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
139 /// into:
140 ///     %tmp = vector.transpose %v, [1, 0]
141 ///     %v = vector.transfer_write %tmp ...
142 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
143 struct TransferWritePermutationLowering
144     : public OpRewritePattern<vector::TransferWriteOp> {
145   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
146 
147   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
148                                 PatternRewriter &rewriter) const override {
149     // TODO: support 0-d corner case.
150     if (op.getTransferRank() == 0)
151       return failure();
152 
153     SmallVector<unsigned> permutation;
154     AffineMap map = op.getPermutationMap();
155     if (map.isMinorIdentity())
156       return failure();
157     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
158       return failure();
159 
160     // Remove unused dims from the permutation map. E.g.:
161     // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
162     // comp = (d0, d1, d2) -> (d2, d0, d1)
163     auto comp = compressUnusedDims(map);
164     // Get positions of remaining result dims.
165     SmallVector<int64_t> indices;
166     llvm::transform(comp.getResults(), std::back_inserter(indices),
167                     [](AffineExpr expr) {
168                       return expr.dyn_cast<AffineDimExpr>().getPosition();
169                     });
170 
171     // Transpose mask operand.
172     Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
173                                        op.getLoc(), op.getMask(), indices)
174                                  : Value();
175 
176     // Transpose in_bounds attribute.
177     ArrayAttr newInBoundsAttr =
178         op.getInBounds() ? transposeInBoundsAttr(
179                                rewriter, op.getInBounds().value(), permutation)
180                          : ArrayAttr();
181 
182     // Generate new transfer_write operation.
183     Value newVec = rewriter.create<vector::TransposeOp>(
184         op.getLoc(), op.getVector(), indices);
185     auto newMap = AffineMap::getMinorIdentityMap(
186         map.getNumDims(), map.getNumResults(), rewriter.getContext());
187     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
188         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189         newMask, newInBoundsAttr);
190 
191     return success();
192   }
193 };
194 
195 /// Lower transfer_read op with broadcast in the leading dimensions into
196 /// transfer_read of lower rank + vector.broadcast.
197 /// Ex: vector.transfer_read ...
198 ///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
199 /// into:
200 ///     %v = vector.transfer_read ...
201 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
202 ///     vector.broadcast %v
203 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
204   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
205 
206   LogicalResult matchAndRewrite(vector::TransferReadOp op,
207                                 PatternRewriter &rewriter) const override {
208     // TODO: support 0-d corner case.
209     if (op.getTransferRank() == 0)
210       return failure();
211 
212     AffineMap map = op.getPermutationMap();
213     unsigned numLeadingBroadcast = 0;
214     for (auto expr : map.getResults()) {
215       auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
216       if (!dimExpr || dimExpr.getValue() != 0)
217         break;
218       numLeadingBroadcast++;
219     }
220     // If there are no leading zeros in the map there is nothing to do.
221     if (numLeadingBroadcast == 0)
222       return failure();
223     VectorType originalVecType = op.getVectorType();
224     unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
225     // Calculate new map, vector type and masks without the leading zeros.
226     AffineMap newMap = AffineMap::get(
227         map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
228         op.getContext());
229     // Only remove the leading zeros if the rest of the map is a minor identity
230     // with broadasting. Otherwise we first want to permute the map.
231     if (!newMap.isMinorIdentityWithBroadcasting())
232       return failure();
233 
234     // TODO: support zero-dimension vectors natively.  See:
235     // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
236     // In the meantime, lower these to a scalar load when they pop up.
237     if (reducedShapeRank == 0) {
238       Value newRead;
239       if (op.getShapedType().isa<TensorType>()) {
240         newRead = rewriter.create<tensor::ExtractOp>(
241             op.getLoc(), op.getSource(), op.getIndices());
242       } else {
243         newRead = rewriter.create<memref::LoadOp>(
244             op.getLoc(), originalVecType.getElementType(), op.getSource(),
245             op.getIndices());
246       }
247       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
248                                                        newRead);
249       return success();
250     }
251     SmallVector<int64_t> newShape = llvm::to_vector<4>(
252         originalVecType.getShape().take_back(reducedShapeRank));
253     // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
254     if (newShape.empty())
255       return failure();
256     VectorType newReadType =
257         VectorType::get(newShape, originalVecType.getElementType());
258     ArrayAttr newInBoundsAttr =
259         op.getInBounds()
260             ? rewriter.getArrayAttr(
261                   op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
262             : ArrayAttr();
263     Value newRead = rewriter.create<vector::TransferReadOp>(
264         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
265         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
266         newInBoundsAttr);
267     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
268                                                      newRead);
269     return success();
270   }
271 };
272 
273 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
274     RewritePatternSet &patterns) {
275   patterns.add<TransferReadPermutationLowering,
276                TransferWritePermutationLowering, TransferOpReduceRank>(
277       patterns.getContext());
278 }
279