1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Interfaces/VectorInterfaces.h"
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
23 /// Transpose a vector transfer op's `in_bounds` attribute according to given
24 /// indices.
25 static ArrayAttr
26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27                       const SmallVector<unsigned> &permutation) {
28   SmallVector<bool> newInBoundsValues;
29   for (unsigned pos : permutation)
30     newInBoundsValues.push_back(
31         attr.getValue()[pos].cast<BoolAttr>().getValue());
32   return builder.getBoolArrayAttr(newInBoundsValues);
33 }
34 
35 /// Lower transfer_read op with permutation into a transfer_read with a
36 /// permutation map composed of leading zeros followed by a minor identiy +
37 /// vector.transpose op.
38 /// Ex:
39 ///     vector.transfer_read ...
40 ///         permutation_map: (d0, d1, d2) -> (0, d1)
41 /// into:
42 ///     %v = vector.transfer_read ...
43 ///         permutation_map: (d0, d1, d2) -> (d1, 0)
44 ///     vector.transpose %v, [1, 0]
45 ///
46 ///     vector.transfer_read ...
47 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
48 /// into:
49 ///     %v = vector.transfer_read ...
50 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
51 ///     vector.transpose %v, [0, 1, 3, 2, 4]
52 /// Note that an alternative is to transform it to linalg.transpose +
53 /// vector.transfer_read to do the transpose in memory instead.
54 struct TransferReadPermutationLowering
55     : public OpRewritePattern<vector::TransferReadOp> {
56   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
57 
58   LogicalResult matchAndRewrite(vector::TransferReadOp op,
59                                 PatternRewriter &rewriter) const override {
60     // TODO: support 0-d corner case.
61     if (op.getTransferRank() == 0)
62       return failure();
63 
64     SmallVector<unsigned> permutation;
65     AffineMap map = op.getPermutationMap();
66     if (map.getNumResults() == 0)
67       return failure();
68     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
69       return failure();
70     AffineMap permutationMap =
71         map.getPermutationMap(permutation, op.getContext());
72     if (permutationMap.isIdentity())
73       return failure();
74 
75     permutationMap = map.getPermutationMap(permutation, op.getContext());
76     // Caluclate the map of the new read by applying the inverse permutation.
77     permutationMap = inversePermutation(permutationMap);
78     AffineMap newMap = permutationMap.compose(map);
79     // Apply the reverse transpose to deduce the type of the transfer_read.
80     ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
81     SmallVector<int64_t> newVectorShape(originalShape.size());
82     for (const auto &pos : llvm::enumerate(permutation)) {
83       newVectorShape[pos.value()] = originalShape[pos.index()];
84     }
85 
86     // Transpose mask operand.
87     Value newMask;
88     if (op.getMask()) {
89       // Remove unused dims from the permutation map. E.g.:
90       // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91       // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92       auto comp = compressUnusedDims(map);
93       // Get positions of remaining result dims.
94       // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95       // maskTransposeIndices = [ 2,     1,    0]
96       SmallVector<int64_t> maskTransposeIndices;
97       for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98         if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99           maskTransposeIndices.push_back(expr.getPosition());
100       }
101 
102       newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103                                                      maskTransposeIndices);
104     }
105 
106     // Transpose in_bounds attribute.
107     ArrayAttr newInBoundsAttr =
108         op.getInBounds()
109             ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(),
110                                     permutation)
111             : ArrayAttr();
112 
113     // Generate new transfer_read operation.
114     VectorType newReadType =
115         VectorType::get(newVectorShape, op.getVectorType().getElementType());
116     Value newRead = rewriter.create<vector::TransferReadOp>(
117         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
118         AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
119 
120     // Transpose result of transfer_read.
121     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
122     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
123                                                      transposePerm);
124     return success();
125   }
126 };
127 
128 /// Lower transfer_write op with permutation into a transfer_write with a
129 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
130 /// Ex:
131 ///     vector.transfer_write %v ...
132 ///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
133 /// into:
134 ///     %tmp = vector.transpose %v, [2, 0, 1]
135 ///     vector.transfer_write %tmp ...
136 ///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
137 ///
138 ///     vector.transfer_write %v ...
139 ///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
140 /// into:
141 ///     %tmp = vector.transpose %v, [1, 0]
142 ///     %v = vector.transfer_write %tmp ...
143 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
144 struct TransferWritePermutationLowering
145     : public OpRewritePattern<vector::TransferWriteOp> {
146   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
147 
148   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
149                                 PatternRewriter &rewriter) const override {
150     // TODO: support 0-d corner case.
151     if (op.getTransferRank() == 0)
152       return failure();
153 
154     SmallVector<unsigned> permutation;
155     AffineMap map = op.getPermutationMap();
156     if (map.isMinorIdentity())
157       return failure();
158     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
159       return failure();
160 
161     // Remove unused dims from the permutation map. E.g.:
162     // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
163     // comp = (d0, d1, d2) -> (d2, d0, d1)
164     auto comp = compressUnusedDims(map);
165     // Get positions of remaining result dims.
166     SmallVector<int64_t> indices;
167     llvm::transform(comp.getResults(), std::back_inserter(indices),
168                     [](AffineExpr expr) {
169                       return expr.dyn_cast<AffineDimExpr>().getPosition();
170                     });
171 
172     // Transpose mask operand.
173     Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
174                                        op.getLoc(), op.getMask(), indices)
175                                  : Value();
176 
177     // Transpose in_bounds attribute.
178     ArrayAttr newInBoundsAttr =
179         op.getInBounds()
180             ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(),
181                                     permutation)
182             : ArrayAttr();
183 
184     // Generate new transfer_write operation.
185     Value newVec = rewriter.create<vector::TransposeOp>(
186         op.getLoc(), op.getVector(), indices);
187     auto newMap = AffineMap::getMinorIdentityMap(
188         map.getNumDims(), map.getNumResults(), rewriter.getContext());
189     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
190         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
191         newMask, newInBoundsAttr);
192 
193     return success();
194   }
195 };
196 
197 /// Lower transfer_read op with broadcast in the leading dimensions into
198 /// transfer_read of lower rank + vector.broadcast.
199 /// Ex: vector.transfer_read ...
200 ///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
201 /// into:
202 ///     %v = vector.transfer_read ...
203 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
204 ///     vector.broadcast %v
205 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
206   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
207 
208   LogicalResult matchAndRewrite(vector::TransferReadOp op,
209                                 PatternRewriter &rewriter) const override {
210     // TODO: support 0-d corner case.
211     if (op.getTransferRank() == 0)
212       return failure();
213 
214     AffineMap map = op.getPermutationMap();
215     unsigned numLeadingBroadcast = 0;
216     for (auto expr : map.getResults()) {
217       auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
218       if (!dimExpr || dimExpr.getValue() != 0)
219         break;
220       numLeadingBroadcast++;
221     }
222     // If there are no leading zeros in the map there is nothing to do.
223     if (numLeadingBroadcast == 0)
224       return failure();
225     VectorType originalVecType = op.getVectorType();
226     unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
227     // Calculate new map, vector type and masks without the leading zeros.
228     AffineMap newMap = AffineMap::get(
229         map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
230         op.getContext());
231     // Only remove the leading zeros if the rest of the map is a minor identity
232     // with broadasting. Otherwise we first want to permute the map.
233     if (!newMap.isMinorIdentityWithBroadcasting())
234       return failure();
235 
236     // TODO: support zero-dimension vectors natively.  See:
237     // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
238     // In the meantime, lower these to a scalar load when they pop up.
239     if (reducedShapeRank == 0) {
240       Value newRead;
241       if (op.getShapedType().isa<TensorType>()) {
242         newRead = rewriter.create<tensor::ExtractOp>(
243             op.getLoc(), op.getSource(), op.getIndices());
244       } else {
245         newRead = rewriter.create<memref::LoadOp>(
246             op.getLoc(), originalVecType.getElementType(), op.getSource(),
247             op.getIndices());
248       }
249       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
250                                                        newRead);
251       return success();
252     }
253     SmallVector<int64_t> newShape = llvm::to_vector<4>(
254         originalVecType.getShape().take_back(reducedShapeRank));
255     // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
256     if (newShape.empty())
257       return failure();
258     VectorType newReadType =
259         VectorType::get(newShape, originalVecType.getElementType());
260     ArrayAttr newInBoundsAttr =
261         op.getInBounds()
262             ? rewriter.getArrayAttr(
263                   op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
264             : ArrayAttr();
265     Value newRead = rewriter.create<vector::TransferReadOp>(
266         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
267         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
268         newInBoundsAttr);
269     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
270                                                      newRead);
271     return success();
272   }
273 };
274 
275 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
276     RewritePatternSet &patterns) {
277   patterns.add<TransferReadPermutationLowering,
278                TransferWritePermutationLowering, TransferOpReduceRank>(
279       patterns.getContext());
280 }
281