1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/TypeUtilities.h"
16
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18
19 using namespace mlir;
20 using namespace mlir::vector;
21
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
trimLeadingOneDims(VectorType oldType)24 static VectorType trimLeadingOneDims(VectorType oldType) {
25 ArrayRef<int64_t> oldShape = oldType.getShape();
26 ArrayRef<int64_t> newShape =
27 oldShape.drop_while([](int64_t dim) { return dim == 1; });
28 // Make sure we have at least 1 dimension per vector type requirements.
29 if (newShape.empty())
30 newShape = oldShape.take_back();
31 return VectorType::get(newShape, oldType.getElementType());
32 }
33
34 /// Return a smallVector of size `rank` containing all zeros.
splatZero(int64_t rank)35 static SmallVector<int64_t> splatZero(int64_t rank) {
36 return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
44 using OpRewritePattern::OpRewritePattern;
45
matchAndRewrite__anon384a66160211::CastAwayExtractStridedSliceLeadingOneDim46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47 PatternRewriter &rewriter) const override {
48 // vector.extract_strided_slice requires the input and output vector to have
49 // the same rank. Here we drop leading one dimensions from the input vector
50 // type to make sure we don't cause mismatch.
51 VectorType oldSrcType = extractOp.getVectorType();
52 VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53
54 if (newSrcType.getRank() == oldSrcType.getRank())
55 return failure();
56
57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58
59 VectorType oldDstType = extractOp.getType();
60 VectorType newDstType =
61 VectorType::get(oldDstType.getShape().drop_front(dropCount),
62 oldDstType.getElementType());
63
64 Location loc = extractOp.getLoc();
65
66 Value newSrcVector = rewriter.create<vector::ExtractOp>(
67 loc, extractOp.getVector(), splatZero(dropCount));
68
69 // The offsets/sizes/strides attribute can have a less number of elements
70 // than the input vector's rank: it is meant for the leading dimensions.
71 auto newOffsets = rewriter.getArrayAttr(
72 extractOp.getOffsets().getValue().drop_front(dropCount));
73 auto newSizes = rewriter.getArrayAttr(
74 extractOp.getSizes().getValue().drop_front(dropCount));
75 auto newStrides = rewriter.getArrayAttr(
76 extractOp.getStrides().getValue().drop_front(dropCount));
77
78 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80
81 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82 newExtractOp);
83
84 return success();
85 }
86 };
87
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91 : public OpRewritePattern<vector::InsertStridedSliceOp> {
92 using OpRewritePattern::OpRewritePattern;
93
matchAndRewrite__anon384a66160211::CastAwayInsertStridedSliceLeadingOneDim94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95 PatternRewriter &rewriter) const override {
96 VectorType oldSrcType = insertOp.getSourceVectorType();
97 VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98 VectorType oldDstType = insertOp.getDestVectorType();
99 VectorType newDstType = trimLeadingOneDims(oldDstType);
100
101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103 if (srcDropCount == 0 && dstDropCount == 0)
104 return failure();
105
106 // Trim leading one dimensions from both operands.
107 Location loc = insertOp.getLoc();
108
109 Value newSrcVector = rewriter.create<vector::ExtractOp>(
110 loc, insertOp.getSource(), splatZero(srcDropCount));
111 Value newDstVector = rewriter.create<vector::ExtractOp>(
112 loc, insertOp.getDest(), splatZero(dstDropCount));
113
114 auto newOffsets = rewriter.getArrayAttr(
115 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116 auto newStrides = rewriter.getArrayAttr(
117 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118
119 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121
122 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123 newInsertOp);
124
125 return success();
126 }
127 };
128
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // The source may be a scalar; in that case its rank stays 0 and no
    // leading dims are dropped from it.
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = newSrcType.cast<VectorType>().getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    // Bail out if neither operand loses any leading unit dims.
    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      // An all-zero vector.extract drops the leading unit dims of the source.
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is fixed by the trimmed dest/source ranks: keep
    // the trailing entries of the old position and, if the new rank exceeds
    // the old position's size, pad with zero indices.
    unsigned oldPosRank = insertOp.getPosition().getValue().size();
    unsigned newPosRank = newDstType.getRank() - newSrcRank;
    SmallVector<Attribute> newPositions = llvm::to_vector(
        insertOp.getPosition().getValue().take_back(newPosRank));
    if (newPosRank > oldPosRank) {
      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
      newPositions.resize(newPosRank, zeroAttr);
    }

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newDstType, newSrcVector, newDstVector,
        rewriter.getArrayAttr(newPositions));

    // Broadcast the trimmed insert result back to the original dest type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};
183
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188 : public OpRewritePattern<vector::TransferReadOp> {
189 using OpRewritePattern::OpRewritePattern;
190
matchAndRewrite__anon384a66160211::CastAwayTransferReadLeadingOneDim191 LogicalResult matchAndRewrite(vector::TransferReadOp read,
192 PatternRewriter &rewriter) const override {
193 // TODO: support 0-d corner case.
194 if (read.getTransferRank() == 0)
195 return failure();
196
197 if (read.getMask())
198 return failure();
199
200 auto shapedType = read.getSource().getType().cast<ShapedType>();
201 if (shapedType.getElementType() != read.getVectorType().getElementType())
202 return failure();
203
204 VectorType oldType = read.getVectorType();
205 VectorType newType = trimLeadingOneDims(oldType);
206
207 if (newType == oldType)
208 return failure();
209
210 AffineMap oldMap = read.getPermutationMap();
211 ArrayRef<AffineExpr> newResults =
212 oldMap.getResults().take_back(newType.getRank());
213 AffineMap newMap =
214 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215 rewriter.getContext());
216
217 ArrayAttr inBoundsAttr;
218 if (read.getInBounds())
219 inBoundsAttr = rewriter.getArrayAttr(
220 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221
222 auto newRead = rewriter.create<vector::TransferReadOp>(
223 read.getLoc(), newType, read.getSource(), read.getIndices(),
224 AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225 inBoundsAttr);
226 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227
228 return success();
229 }
230 };
231
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236 : public OpRewritePattern<vector::TransferWriteOp> {
237 using OpRewritePattern::OpRewritePattern;
238
matchAndRewrite__anon384a66160211::CastAwayTransferWriteLeadingOneDim239 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240 PatternRewriter &rewriter) const override {
241 // TODO: support 0-d corner case.
242 if (write.getTransferRank() == 0)
243 return failure();
244
245 if (write.getMask())
246 return failure();
247
248 auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249 if (shapedType.getElementType() != write.getVectorType().getElementType())
250 return failure();
251
252 VectorType oldType = write.getVectorType();
253 VectorType newType = trimLeadingOneDims(oldType);
254 if (newType == oldType)
255 return failure();
256 int64_t dropDim = oldType.getRank() - newType.getRank();
257
258 AffineMap oldMap = write.getPermutationMap();
259 ArrayRef<AffineExpr> newResults =
260 oldMap.getResults().take_back(newType.getRank());
261 AffineMap newMap =
262 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263 rewriter.getContext());
264
265 ArrayAttr inBoundsAttr;
266 if (write.getInBounds())
267 inBoundsAttr = rewriter.getArrayAttr(
268 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269
270 auto newVector = rewriter.create<vector::ExtractOp>(
271 write.getLoc(), write.getVector(), splatZero(dropDim));
272 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273 write, newVector, write.getSource(), write.getIndices(),
274 AffineMapAttr::get(newMap), inBoundsAttr);
275
276 return success();
277 }
278 };
279
280 /// Turns vector.contract on vector with leading 1 dimensions into
281 /// vector.extract followed by vector.contract on vector without leading
282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283 /// prior to extract.
/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
/// prior to extract.
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Only handle contractions with a vector accumulator of rank >= 2 whose
    // leading dimension is a unit dim.
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    // TODO: implement masks.
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // currently we support only dropping one dim but the pattern can be applied
    // greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.getIteratorTypes();
    SmallVector<Attribute> newIteratorTypes;

    // The iteration-space dim to drop is the one mapped to the acc's leading
    // (unit) dimension; indexing map 2 is the accumulator's map.
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // only parallel type iterators can be dropped.
      return failure();

    // Rebuild the iterator-type list without the dropped dimension.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                   contractOp.getAcc()};
    SmallVector<Value> newOperands;

    // Per operand (lhs, rhs, acc): make the dropped dim leading (transposing
    // if necessary), then extract it away and renumber the indexing map.
    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path, 1. We need to
        // transpose the operand to make the dim to be dropped
        // leading. 2. The dim to be dropped does not exist and in
        // that case we dont want to add a unit transpose but we must
        // check all the indices to make sure this is the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation that moves the dropped dim to the front while
        // preserving the relative order of all other dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the transpose now if needed so that we can drop the
        // correct dim using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be
      // the leading dim. If its still not leading that means it
      // does not exist in this operand and hence we do not need
      // an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // Rebuild the indexing map with the dropped dim removed; dims beyond
      // the dropped one shift down by one.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if its a valid extraction, otherwise use the operand
      // without extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    // Build the reduced-rank contraction and broadcast its result back to
    // the original result type to restore the unit dim.
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};
401
402 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
403 public:
CastAwayElementwiseLeadingOneDim(MLIRContext * context)404 CastAwayElementwiseLeadingOneDim(MLIRContext *context)
405 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
406
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const407 LogicalResult matchAndRewrite(Operation *op,
408 PatternRewriter &rewriter) const override {
409 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
410 return failure();
411 auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
412 if (!vecType)
413 return failure();
414 VectorType newVecType = trimLeadingOneDims(vecType);
415 if (newVecType == vecType)
416 return failure();
417 int64_t dropDim = vecType.getRank() - newVecType.getRank();
418 SmallVector<Value, 4> newOperands;
419 for (Value operand : op->getOperands()) {
420 if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
421 newOperands.push_back(rewriter.create<vector::ExtractOp>(
422 op->getLoc(), operand, splatZero(dropDim)));
423 } else {
424 newOperands.push_back(operand);
425 }
426 }
427 Operation *newOp =
428 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
429 newOperands, newVecType, op->getAttrs());
430 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
431 newOp->getResult(0));
432 return success();
433 }
434 };
435
436 } // namespace
437
populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet & patterns)438 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
439 RewritePatternSet &patterns) {
440 patterns
441 .add<CastAwayExtractStridedSliceLeadingOneDim,
442 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
443 CastAwayTransferReadLeadingOneDim,
444 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
445 CastAwayContractionLeadingOneDim>(patterns.getContext());
446 populateShapeCastFoldingPatterns(patterns);
447 }
448