1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
11 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/IR/TypeUtilities.h"
15 
16 #define DEBUG_TYPE "vector-drop-unit-dim"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 // Trims leading one dimensions from `oldType` and returns the result type.
22 // Returns `vector<1xT>` if `oldType` only has one element.
23 static VectorType trimLeadingOneDims(VectorType oldType) {
24   ArrayRef<int64_t> oldShape = oldType.getShape();
25   ArrayRef<int64_t> newShape =
26       oldShape.drop_while([](int64_t dim) { return dim == 1; });
27   // Make sure we have at least 1 dimension per vector type requirements.
28   if (newShape.empty())
29     newShape = oldShape.take_back();
30   return VectorType::get(newShape, oldType.getElementType());
31 }
32 
33 /// Return a smallVector of size `rank` containing all zeros.
34 static SmallVector<int64_t> splatZero(int64_t rank) {
35   return SmallVector<int64_t>(rank, 0);
36 }
37 namespace {
38 
39 // Casts away leading one dimensions in vector.extract_strided_slice's vector
40 // input by inserting vector.shape_cast.
41 struct CastAwayExtractStridedSliceLeadingOneDim
42     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
43   using OpRewritePattern::OpRewritePattern;
44 
45   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
46                                 PatternRewriter &rewriter) const override {
47     // vector.extract_strided_slice requires the input and output vector to have
48     // the same rank. Here we drop leading one dimensions from the input vector
49     // type to make sure we don't cause mismatch.
50     VectorType oldSrcType = extractOp.getVectorType();
51     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
52 
53     if (newSrcType.getRank() == oldSrcType.getRank())
54       return failure();
55 
56     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
57 
58     VectorType oldDstType = extractOp.getType();
59     VectorType newDstType =
60         VectorType::get(oldDstType.getShape().drop_front(dropCount),
61                         oldDstType.getElementType());
62 
63     Location loc = extractOp.getLoc();
64 
65     Value newSrcVector = rewriter.create<vector::ExtractOp>(
66         loc, extractOp.vector(), splatZero(dropCount));
67 
68     // The offsets/sizes/strides attribute can have a less number of elements
69     // than the input vector's rank: it is meant for the leading dimensions.
70     auto newOffsets = rewriter.getArrayAttr(
71         extractOp.offsets().getValue().drop_front(dropCount));
72     auto newSizes = rewriter.getArrayAttr(
73         extractOp.sizes().getValue().drop_front(dropCount));
74     auto newStrides = rewriter.getArrayAttr(
75         extractOp.strides().getValue().drop_front(dropCount));
76 
77     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
78         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
79 
80     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
81                                                      newExtractOp);
82 
83     return success();
84   }
85 };
86 
87 // Casts away leading one dimensions in vector.extract_strided_slice's vector
88 // inputs by inserting vector.shape_cast.
89 struct CastAwayInsertStridedSliceLeadingOneDim
90     : public OpRewritePattern<vector::InsertStridedSliceOp> {
91   using OpRewritePattern::OpRewritePattern;
92 
93   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
94                                 PatternRewriter &rewriter) const override {
95     VectorType oldSrcType = insertOp.getSourceVectorType();
96     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
97     VectorType oldDstType = insertOp.getDestVectorType();
98     VectorType newDstType = trimLeadingOneDims(oldDstType);
99 
100     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
101     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
102     if (srcDropCount == 0 && dstDropCount == 0)
103       return failure();
104 
105     // Trim leading one dimensions from both operands.
106     Location loc = insertOp.getLoc();
107 
108     Value newSrcVector = rewriter.create<vector::ExtractOp>(
109         loc, insertOp.source(), splatZero(srcDropCount));
110     Value newDstVector = rewriter.create<vector::ExtractOp>(
111         loc, insertOp.dest(), splatZero(dstDropCount));
112 
113     auto newOffsets = rewriter.getArrayAttr(
114         insertOp.offsets().getValue().take_back(newDstType.getRank()));
115     auto newStrides = rewriter.getArrayAttr(
116         insertOp.strides().getValue().take_back(newSrcType.getRank()));
117 
118     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
119         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
120 
121     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
122                                                      newInsertOp);
123 
124     return success();
125   }
126 };
127 
128 // Turns vector.transfer_read on vector with leading 1 dimensions into
129 // vector.shape_cast followed by vector.transfer_read on vector without leading
130 // 1 dimensions.
131 struct CastAwayTransferReadLeadingOneDim
132     : public OpRewritePattern<vector::TransferReadOp> {
133   using OpRewritePattern::OpRewritePattern;
134 
135   LogicalResult matchAndRewrite(vector::TransferReadOp read,
136                                 PatternRewriter &rewriter) const override {
137     // TODO: support 0-d corner case.
138     if (read.getTransferRank() == 0)
139       return failure();
140 
141     if (read.mask())
142       return failure();
143 
144     auto shapedType = read.source().getType().cast<ShapedType>();
145     if (shapedType.getElementType() != read.getVectorType().getElementType())
146       return failure();
147 
148     VectorType oldType = read.getVectorType();
149     VectorType newType = trimLeadingOneDims(oldType);
150 
151     if (newType == oldType)
152       return failure();
153 
154     AffineMap oldMap = read.permutation_map();
155     ArrayRef<AffineExpr> newResults =
156         oldMap.getResults().take_back(newType.getRank());
157     AffineMap newMap =
158         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
159                        rewriter.getContext());
160 
161     ArrayAttr inBoundsAttr;
162     if (read.in_bounds())
163       inBoundsAttr = rewriter.getArrayAttr(
164           read.in_boundsAttr().getValue().take_back(newType.getRank()));
165 
166     auto newRead = rewriter.create<vector::TransferReadOp>(
167         read.getLoc(), newType, read.source(), read.indices(),
168         AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
169         inBoundsAttr);
170     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
171 
172     return success();
173   }
174 };
175 
176 // Turns vector.transfer_write on vector with leading 1 dimensions into
177 // vector.shape_cast followed by vector.transfer_write on vector without leading
178 // 1 dimensions.
179 struct CastAwayTransferWriteLeadingOneDim
180     : public OpRewritePattern<vector::TransferWriteOp> {
181   using OpRewritePattern::OpRewritePattern;
182 
183   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
184                                 PatternRewriter &rewriter) const override {
185     // TODO: support 0-d corner case.
186     if (write.getTransferRank() == 0)
187       return failure();
188 
189     if (write.mask())
190       return failure();
191 
192     auto shapedType = write.source().getType().dyn_cast<ShapedType>();
193     if (shapedType.getElementType() != write.getVectorType().getElementType())
194       return failure();
195 
196     VectorType oldType = write.getVectorType();
197     VectorType newType = trimLeadingOneDims(oldType);
198     if (newType == oldType)
199       return failure();
200     int64_t dropDim = oldType.getRank() - newType.getRank();
201 
202     AffineMap oldMap = write.permutation_map();
203     ArrayRef<AffineExpr> newResults =
204         oldMap.getResults().take_back(newType.getRank());
205     AffineMap newMap =
206         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
207                        rewriter.getContext());
208 
209     ArrayAttr inBoundsAttr;
210     if (write.in_bounds())
211       inBoundsAttr = rewriter.getArrayAttr(
212           write.in_boundsAttr().getValue().take_back(newType.getRank()));
213 
214     auto newVector = rewriter.create<vector::ExtractOp>(
215         write.getLoc(), write.vector(), splatZero(dropDim));
216     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
217         write, newVector, write.source(), write.indices(),
218         AffineMapAttr::get(newMap), inBoundsAttr);
219 
220     return success();
221   }
222 };
223 
224 /// Turns vector.contract on vector with leading 1 dimensions into
225 /// vector.extract followed by vector.contract on vector without leading
226 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
227 /// prior to extract.
228 struct CastAwayContractionLeadingOneDim
229     : public OpRewritePattern<vector::ContractionOp> {
230   using OpRewritePattern::OpRewritePattern;
231 
232   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
233                                 PatternRewriter &rewriter) const override {
234     VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
235     if (oldAccType == nullptr)
236       return failure();
237     if (oldAccType.getRank() < 2)
238       return failure();
239     // TODO: implement masks.
240     if (llvm::size(contractOp.masks()) != 0)
241       return failure();
242     if (oldAccType.getShape()[0] != 1)
243       return failure();
244     // currently we support only dropping one dim but the pattern can be applied
245     // greedily to drop more.
246     int64_t dropDim = 1;
247 
248     auto oldIndexingMaps = contractOp.getIndexingMaps();
249     SmallVector<AffineMap> newIndexingMaps;
250 
251     auto oldIteratorTypes = contractOp.iterator_types();
252     SmallVector<Attribute> newIteratorTypes;
253 
254     int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
255 
256     if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
257       // only parallel type iterators can be dropped.
258       return failure();
259 
260     for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
261       int64_t currDim = it.index();
262       if (currDim == dimToDrop)
263         continue;
264       newIteratorTypes.push_back(it.value());
265     }
266 
267     SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(),
268                                    contractOp.acc()};
269     SmallVector<Value> newOperands;
270 
271     for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
272       // Check if the dim to be dropped exists as a leading dim in the operand
273       // if it does then we use vector.extract to drop it.
274       bool validExtract = false;
275       SmallVector<AffineExpr> results;
276       auto map = it.value();
277       int64_t orginalZeroDim = it.value().getDimPosition(0);
278       if (orginalZeroDim != dimToDrop) {
279         // There are two reasons to be in this path, 1. We need to
280         // tranpose the operand to make the dim to be dropped
281         // leading. 2. The dim to be dropped does not exist and in
282         // that case we dont want to add a unit tranpose but we must
283         // check all the indices to make sure this is the case.
284         bool tranposeNeeded = false;
285         SmallVector<int64_t> perm;
286         SmallVector<AffineExpr> transposeResults;
287 
288         for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
289           int64_t currDim = map.getDimPosition(i);
290           if (currDim == dimToDrop) {
291             tranposeNeeded = true;
292             perm.insert(perm.begin(), i);
293             auto targetExpr = rewriter.getAffineDimExpr(currDim);
294             transposeResults.insert(transposeResults.begin(), targetExpr);
295           } else {
296             perm.push_back(i);
297             auto targetExpr = rewriter.getAffineDimExpr(currDim);
298             transposeResults.push_back(targetExpr);
299           }
300         }
301         // Do the tranpose now if needed so that we can drop the
302         // correct dim using extract later.
303         if (tranposeNeeded) {
304           map = AffineMap::get(map.getNumDims(), 0, transposeResults,
305                                contractOp.getContext());
306           operands[it.index()] = rewriter.create<vector::TransposeOp>(
307               contractOp.getLoc(), operands[it.index()], perm);
308         }
309       }
310       // We have taken care to have the dim to be dropped be
311       // the leading dim. If its still not leading that means it
312       // does not exist in this operand and hence we do not need
313       // an extract.
314       if (map.getDimPosition(0) == dimToDrop)
315         validExtract = true;
316 
317       for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
318         int64_t currDim = map.getDimPosition(i);
319         if (currDim == dimToDrop)
320           // This is the dim we are dropping.
321           continue;
322         auto targetExpr = rewriter.getAffineDimExpr(
323             currDim < dimToDrop ? currDim : currDim - 1);
324         results.push_back(targetExpr);
325       }
326       newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
327                                                contractOp.getContext()));
328       // Extract if its a valid extraction, otherwise use the operand
329       // without extraction.
330       newOperands.push_back(validExtract
331                                 ? rewriter.create<vector::ExtractOp>(
332                                       contractOp.getLoc(), operands[it.index()],
333                                       splatZero(dropDim))
334                                 : operands[it.index()]);
335     }
336     auto newContractOp = rewriter.create<vector::ContractionOp>(
337         contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
338         rewriter.getAffineMapArrayAttr(newIndexingMaps),
339         rewriter.getArrayAttr(newIteratorTypes), contractOp.kind());
340     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
341         contractOp, contractOp->getResultTypes()[0], newContractOp);
342     return success();
343   }
344 };
345 
346 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
347 public:
348   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
349       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
350 
351   LogicalResult matchAndRewrite(Operation *op,
352                                 PatternRewriter &rewriter) const override {
353     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
354       return failure();
355     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
356     if (!vecType)
357       return failure();
358     VectorType newVecType = trimLeadingOneDims(vecType);
359     if (newVecType == vecType)
360       return failure();
361     int64_t dropDim = vecType.getRank() - newVecType.getRank();
362     SmallVector<Value, 4> newOperands;
363     for (Value operand : op->getOperands()) {
364       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
365         newOperands.push_back(rewriter.create<vector::ExtractOp>(
366             op->getLoc(), operand, splatZero(dropDim)));
367       } else {
368         newOperands.push_back(operand);
369       }
370     }
371     OperationState state(op->getLoc(), op->getName());
372     state.addAttributes(op->getAttrs());
373     state.addOperands(newOperands);
374     state.addTypes(newVecType);
375     Operation *newOp = rewriter.createOperation(state);
376     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
377                                                      newOp->getResult(0));
378     return success();
379   }
380 };
381 
382 } // namespace
383 
384 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
385     RewritePatternSet &patterns) {
386   patterns
387       .add<CastAwayExtractStridedSliceLeadingOneDim,
388            CastAwayInsertStridedSliceLeadingOneDim,
389            CastAwayTransferReadLeadingOneDim,
390            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
391            CastAwayContractionLeadingOneDim>(patterns.getContext());
392   populateShapeCastFoldingPatterns(patterns);
393 }
394