1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrite patterns for the permutation_map attribute of
10 // vector.transfer operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Interfaces/VectorInterfaces.h"
19
20 using namespace mlir;
21 using namespace mlir::vector;
22
23 /// Transpose a vector transfer op's `in_bounds` attribute according to given
24 /// indices.
25 static ArrayAttr
transposeInBoundsAttr(OpBuilder & builder,ArrayAttr attr,const SmallVector<unsigned> & permutation)26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
27 const SmallVector<unsigned> &permutation) {
28 SmallVector<bool> newInBoundsValues;
29 for (unsigned pos : permutation)
30 newInBoundsValues.push_back(
31 attr.getValue()[pos].cast<BoolAttr>().getValue());
32 return builder.getBoolArrayAttr(newInBoundsValues);
33 }
34
35 /// Lower transfer_read op with permutation into a transfer_read with a
36 /// permutation map composed of leading zeros followed by a minor identiy +
37 /// vector.transpose op.
38 /// Ex:
39 /// vector.transfer_read ...
40 /// permutation_map: (d0, d1, d2) -> (0, d1)
41 /// into:
42 /// %v = vector.transfer_read ...
43 /// permutation_map: (d0, d1, d2) -> (d1, 0)
44 /// vector.transpose %v, [1, 0]
45 ///
46 /// vector.transfer_read ...
47 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
48 /// into:
49 /// %v = vector.transfer_read ...
50 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
51 /// vector.transpose %v, [0, 1, 3, 2, 4]
52 /// Note that an alternative is to transform it to linalg.transpose +
53 /// vector.transfer_read to do the transpose in memory instead.
54 struct TransferReadPermutationLowering
55 : public OpRewritePattern<vector::TransferReadOp> {
56 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
57
matchAndRewriteTransferReadPermutationLowering58 LogicalResult matchAndRewrite(vector::TransferReadOp op,
59 PatternRewriter &rewriter) const override {
60 // TODO: support 0-d corner case.
61 if (op.getTransferRank() == 0)
62 return failure();
63
64 SmallVector<unsigned> permutation;
65 AffineMap map = op.getPermutationMap();
66 if (map.getNumResults() == 0)
67 return failure();
68 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
69 return failure();
70 AffineMap permutationMap =
71 map.getPermutationMap(permutation, op.getContext());
72 if (permutationMap.isIdentity())
73 return failure();
74
75 permutationMap = map.getPermutationMap(permutation, op.getContext());
76 // Caluclate the map of the new read by applying the inverse permutation.
77 permutationMap = inversePermutation(permutationMap);
78 AffineMap newMap = permutationMap.compose(map);
79 // Apply the reverse transpose to deduce the type of the transfer_read.
80 ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
81 SmallVector<int64_t> newVectorShape(originalShape.size());
82 for (const auto &pos : llvm::enumerate(permutation)) {
83 newVectorShape[pos.value()] = originalShape[pos.index()];
84 }
85
86 // Transpose mask operand.
87 Value newMask;
88 if (op.getMask()) {
89 // Remove unused dims from the permutation map. E.g.:
90 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91 // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92 auto comp = compressUnusedDims(map);
93 // Get positions of remaining result dims.
94 // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95 // maskTransposeIndices = [ 2, 1, 0]
96 SmallVector<int64_t> maskTransposeIndices;
97 for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98 if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99 maskTransposeIndices.push_back(expr.getPosition());
100 }
101
102 newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103 maskTransposeIndices);
104 }
105
106 // Transpose in_bounds attribute.
107 ArrayAttr newInBoundsAttr =
108 op.getInBounds() ? transposeInBoundsAttr(
109 rewriter, op.getInBounds().value(), permutation)
110 : ArrayAttr();
111
112 // Generate new transfer_read operation.
113 VectorType newReadType =
114 VectorType::get(newVectorShape, op.getVectorType().getElementType());
115 Value newRead = rewriter.create<vector::TransferReadOp>(
116 op.getLoc(), newReadType, op.getSource(), op.getIndices(),
117 AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
118
119 // Transpose result of transfer_read.
120 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
121 rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
122 transposePerm);
123 return success();
124 }
125 };
126
127 /// Lower transfer_write op with permutation into a transfer_write with a
128 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
129 /// Ex:
130 /// vector.transfer_write %v ...
131 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
132 /// into:
133 /// %tmp = vector.transpose %v, [2, 0, 1]
134 /// vector.transfer_write %tmp ...
135 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
136 ///
137 /// vector.transfer_write %v ...
138 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
139 /// into:
140 /// %tmp = vector.transpose %v, [1, 0]
141 /// %v = vector.transfer_write %tmp ...
142 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
143 struct TransferWritePermutationLowering
144 : public OpRewritePattern<vector::TransferWriteOp> {
145 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
146
matchAndRewriteTransferWritePermutationLowering147 LogicalResult matchAndRewrite(vector::TransferWriteOp op,
148 PatternRewriter &rewriter) const override {
149 // TODO: support 0-d corner case.
150 if (op.getTransferRank() == 0)
151 return failure();
152
153 SmallVector<unsigned> permutation;
154 AffineMap map = op.getPermutationMap();
155 if (map.isMinorIdentity())
156 return failure();
157 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
158 return failure();
159
160 // Remove unused dims from the permutation map. E.g.:
161 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
162 // comp = (d0, d1, d2) -> (d2, d0, d1)
163 auto comp = compressUnusedDims(map);
164 // Get positions of remaining result dims.
165 SmallVector<int64_t> indices;
166 llvm::transform(comp.getResults(), std::back_inserter(indices),
167 [](AffineExpr expr) {
168 return expr.dyn_cast<AffineDimExpr>().getPosition();
169 });
170
171 // Transpose mask operand.
172 Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
173 op.getLoc(), op.getMask(), indices)
174 : Value();
175
176 // Transpose in_bounds attribute.
177 ArrayAttr newInBoundsAttr =
178 op.getInBounds() ? transposeInBoundsAttr(
179 rewriter, op.getInBounds().value(), permutation)
180 : ArrayAttr();
181
182 // Generate new transfer_write operation.
183 Value newVec = rewriter.create<vector::TransposeOp>(
184 op.getLoc(), op.getVector(), indices);
185 auto newMap = AffineMap::getMinorIdentityMap(
186 map.getNumDims(), map.getNumResults(), rewriter.getContext());
187 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
188 op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189 newMask, newInBoundsAttr);
190
191 return success();
192 }
193 };
194
195 /// Lower transfer_read op with broadcast in the leading dimensions into
196 /// transfer_read of lower rank + vector.broadcast.
197 /// Ex: vector.transfer_read ...
198 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
199 /// into:
200 /// %v = vector.transfer_read ...
201 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
202 /// vector.broadcast %v
203 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
204 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
205
matchAndRewriteTransferOpReduceRank206 LogicalResult matchAndRewrite(vector::TransferReadOp op,
207 PatternRewriter &rewriter) const override {
208 // TODO: support 0-d corner case.
209 if (op.getTransferRank() == 0)
210 return failure();
211
212 AffineMap map = op.getPermutationMap();
213 unsigned numLeadingBroadcast = 0;
214 for (auto expr : map.getResults()) {
215 auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
216 if (!dimExpr || dimExpr.getValue() != 0)
217 break;
218 numLeadingBroadcast++;
219 }
220 // If there are no leading zeros in the map there is nothing to do.
221 if (numLeadingBroadcast == 0)
222 return failure();
223 VectorType originalVecType = op.getVectorType();
224 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
225 // Calculate new map, vector type and masks without the leading zeros.
226 AffineMap newMap = AffineMap::get(
227 map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
228 op.getContext());
229 // Only remove the leading zeros if the rest of the map is a minor identity
230 // with broadasting. Otherwise we first want to permute the map.
231 if (!newMap.isMinorIdentityWithBroadcasting())
232 return failure();
233
234 // TODO: support zero-dimension vectors natively. See:
235 // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
236 // In the meantime, lower these to a scalar load when they pop up.
237 if (reducedShapeRank == 0) {
238 Value newRead;
239 if (op.getShapedType().isa<TensorType>()) {
240 newRead = rewriter.create<tensor::ExtractOp>(
241 op.getLoc(), op.getSource(), op.getIndices());
242 } else {
243 newRead = rewriter.create<memref::LoadOp>(
244 op.getLoc(), originalVecType.getElementType(), op.getSource(),
245 op.getIndices());
246 }
247 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
248 newRead);
249 return success();
250 }
251 SmallVector<int64_t> newShape = llvm::to_vector<4>(
252 originalVecType.getShape().take_back(reducedShapeRank));
253 // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
254 if (newShape.empty())
255 return failure();
256 VectorType newReadType =
257 VectorType::get(newShape, originalVecType.getElementType());
258 ArrayAttr newInBoundsAttr =
259 op.getInBounds()
260 ? rewriter.getArrayAttr(
261 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
262 : ArrayAttr();
263 Value newRead = rewriter.create<vector::TransferReadOp>(
264 op.getLoc(), newReadType, op.getSource(), op.getIndices(),
265 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
266 newInBoundsAttr);
267 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
268 newRead);
269 return success();
270 }
271 };
272
populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet & patterns)273 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
274 RewritePatternSet &patterns) {
275 patterns.add<TransferReadPermutationLowering,
276 TransferWritePermutationLowering, TransferOpReduceRank>(
277 patterns.getContext());
278 }
279