1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
10 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/ImplicitLocOpBuilder.h"
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #define DEBUG_TYPE "vector-drop-unit-dim"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Trims leading one dimensions from `oldType` and returns the result type.
21 // Returns `vector<1xT>` if `oldType` only has one element.
22 static VectorType trimLeadingOneDims(VectorType oldType) {
23   ArrayRef<int64_t> oldShape = oldType.getShape();
24   ArrayRef<int64_t> newShape =
25       oldShape.drop_while([](int64_t dim) { return dim == 1; });
26   // Make sure we have at least 1 dimension per vector type requirements.
27   if (newShape.empty())
28     newShape = oldShape.take_back();
29   return VectorType::get(newShape, oldType.getElementType());
30 }
31 
32 /// Return a smallVector of size `rank` containing all zeros.
33 static SmallVector<int64_t> splatZero(int64_t rank) {
34   return SmallVector<int64_t>(rank, 0);
35 }
36 namespace {
37 
38 // Casts away leading one dimensions in vector.extract_strided_slice's vector
39 // input by inserting vector.shape_cast.
40 struct CastAwayExtractStridedSliceLeadingOneDim
41     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
42   using OpRewritePattern::OpRewritePattern;
43 
44   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
45                                 PatternRewriter &rewriter) const override {
46     // vector.extract_strided_slice requires the input and output vector to have
47     // the same rank. Here we drop leading one dimensions from the input vector
48     // type to make sure we don't cause mismatch.
49     VectorType oldSrcType = extractOp.getVectorType();
50     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
51 
52     if (newSrcType.getRank() == oldSrcType.getRank())
53       return failure();
54 
55     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
56 
57     VectorType oldDstType = extractOp.getType();
58     VectorType newDstType =
59         VectorType::get(oldDstType.getShape().drop_front(dropCount),
60                         oldDstType.getElementType());
61 
62     Location loc = extractOp.getLoc();
63 
64     Value newSrcVector = rewriter.create<vector::ExtractOp>(
65         loc, extractOp.vector(), splatZero(dropCount));
66 
67     // The offsets/sizes/strides attribute can have a less number of elements
68     // than the input vector's rank: it is meant for the leading dimensions.
69     auto newOffsets = rewriter.getArrayAttr(
70         extractOp.offsets().getValue().drop_front(dropCount));
71     auto newSizes = rewriter.getArrayAttr(
72         extractOp.sizes().getValue().drop_front(dropCount));
73     auto newStrides = rewriter.getArrayAttr(
74         extractOp.strides().getValue().drop_front(dropCount));
75 
76     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
77         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
78 
79     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
80                                                      newExtractOp);
81 
82     return success();
83   }
84 };
85 
86 // Casts away leading one dimensions in vector.extract_strided_slice's vector
87 // inputs by inserting vector.shape_cast.
88 struct CastAwayInsertStridedSliceLeadingOneDim
89     : public OpRewritePattern<vector::InsertStridedSliceOp> {
90   using OpRewritePattern::OpRewritePattern;
91 
92   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
93                                 PatternRewriter &rewriter) const override {
94     VectorType oldSrcType = insertOp.getSourceVectorType();
95     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
96     VectorType oldDstType = insertOp.getDestVectorType();
97     VectorType newDstType = trimLeadingOneDims(oldDstType);
98 
99     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
100     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
101     if (srcDropCount == 0 && dstDropCount == 0)
102       return failure();
103 
104     // Trim leading one dimensions from both operands.
105     Location loc = insertOp.getLoc();
106 
107     Value newSrcVector = rewriter.create<vector::ExtractOp>(
108         loc, insertOp.source(), splatZero(srcDropCount));
109     Value newDstVector = rewriter.create<vector::ExtractOp>(
110         loc, insertOp.dest(), splatZero(dstDropCount));
111 
112     auto newOffsets = rewriter.getArrayAttr(
113         insertOp.offsets().getValue().take_back(newDstType.getRank()));
114     auto newStrides = rewriter.getArrayAttr(
115         insertOp.strides().getValue().take_back(newSrcType.getRank()));
116 
117     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
118         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
119 
120     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
121                                                      newInsertOp);
122 
123     return success();
124   }
125 };
126 
127 // Turns vector.transfer_read on vector with leading 1 dimensions into
128 // vector.shape_cast followed by vector.transfer_read on vector without leading
129 // 1 dimensions.
130 struct CastAwayTransferReadLeadingOneDim
131     : public OpRewritePattern<vector::TransferReadOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
134   LogicalResult matchAndRewrite(vector::TransferReadOp read,
135                                 PatternRewriter &rewriter) const override {
136     // TODO: support 0-d corner case.
137     if (read.getTransferRank() == 0)
138       return failure();
139 
140     if (read.mask())
141       return failure();
142 
143     auto shapedType = read.source().getType().cast<ShapedType>();
144     if (shapedType.getElementType() != read.getVectorType().getElementType())
145       return failure();
146 
147     VectorType oldType = read.getVectorType();
148     VectorType newType = trimLeadingOneDims(oldType);
149 
150     if (newType == oldType)
151       return failure();
152 
153     AffineMap oldMap = read.permutation_map();
154     ArrayRef<AffineExpr> newResults =
155         oldMap.getResults().take_back(newType.getRank());
156     AffineMap newMap =
157         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
158                        rewriter.getContext());
159 
160     ArrayAttr inBoundsAttr;
161     if (read.in_bounds())
162       inBoundsAttr = rewriter.getArrayAttr(
163           read.in_boundsAttr().getValue().take_back(newType.getRank()));
164 
165     auto newRead = rewriter.create<vector::TransferReadOp>(
166         read.getLoc(), newType, read.source(), read.indices(),
167         AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
168         inBoundsAttr);
169     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
170 
171     return success();
172   }
173 };
174 
175 // Turns vector.transfer_write on vector with leading 1 dimensions into
176 // vector.shape_cast followed by vector.transfer_write on vector without leading
177 // 1 dimensions.
178 struct CastAwayTransferWriteLeadingOneDim
179     : public OpRewritePattern<vector::TransferWriteOp> {
180   using OpRewritePattern::OpRewritePattern;
181 
182   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
183                                 PatternRewriter &rewriter) const override {
184     // TODO: support 0-d corner case.
185     if (write.getTransferRank() == 0)
186       return failure();
187 
188     if (write.mask())
189       return failure();
190 
191     auto shapedType = write.source().getType().dyn_cast<ShapedType>();
192     if (shapedType.getElementType() != write.getVectorType().getElementType())
193       return failure();
194 
195     VectorType oldType = write.getVectorType();
196     VectorType newType = trimLeadingOneDims(oldType);
197     if (newType == oldType)
198       return failure();
199     int64_t dropDim = oldType.getRank() - newType.getRank();
200 
201     AffineMap oldMap = write.permutation_map();
202     ArrayRef<AffineExpr> newResults =
203         oldMap.getResults().take_back(newType.getRank());
204     AffineMap newMap =
205         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
206                        rewriter.getContext());
207 
208     ArrayAttr inBoundsAttr;
209     if (write.in_bounds())
210       inBoundsAttr = rewriter.getArrayAttr(
211           write.in_boundsAttr().getValue().take_back(newType.getRank()));
212 
213     auto newVector = rewriter.create<vector::ExtractOp>(
214         write.getLoc(), write.vector(), splatZero(dropDim));
215     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
216         write, newVector, write.source(), write.indices(),
217         AffineMapAttr::get(newMap), inBoundsAttr);
218 
219     return success();
220   }
221 };
222 
223 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
224 public:
225   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
226       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
227 
228   LogicalResult matchAndRewrite(Operation *op,
229                                 PatternRewriter &rewriter) const override {
230     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
231       return failure();
232     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
233     if (!vecType)
234       return failure();
235     VectorType newVecType = trimLeadingOneDims(vecType);
236     if (newVecType == vecType)
237       return failure();
238     int64_t dropDim = vecType.getRank() - newVecType.getRank();
239     SmallVector<Value, 4> newOperands;
240     for (Value operand : op->getOperands()) {
241       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
242         newOperands.push_back(rewriter.create<vector::ExtractOp>(
243             op->getLoc(), operand, splatZero(dropDim)));
244       } else {
245         newOperands.push_back(operand);
246       }
247     }
248     OperationState state(op->getLoc(), op->getName());
249     state.addAttributes(op->getAttrs());
250     state.addOperands(newOperands);
251     state.addTypes(newVecType);
252     Operation *newOp = rewriter.createOperation(state);
253     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
254                                                      newOp->getResult(0));
255     return success();
256   }
257 };
258 
259 } // namespace
260 
261 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
262     RewritePatternSet &patterns) {
263   patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
264                CastAwayInsertStridedSliceLeadingOneDim,
265                CastAwayTransferReadLeadingOneDim,
266                CastAwayTransferWriteLeadingOneDim,
267                CastAwayElementwiseLeadingOneDim>(patterns.getContext());
268   populateShapeCastFoldingPatterns(patterns);
269 }
270