199ef9eebSMatthias Springer //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer 
9ad9b5a4bSNirvedh #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10e54236dfSLei Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h"
1199ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1399ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
1499ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
1599ef9eebSMatthias Springer #include "mlir/IR/TypeUtilities.h"
1699ef9eebSMatthias Springer 
1799ef9eebSMatthias Springer #define DEBUG_TYPE "vector-drop-unit-dim"
1899ef9eebSMatthias Springer 
1999ef9eebSMatthias Springer using namespace mlir;
2099ef9eebSMatthias Springer using namespace mlir::vector;
2199ef9eebSMatthias Springer 
2299ef9eebSMatthias Springer // Trims leading one dimensions from `oldType` and returns the result type.
2399ef9eebSMatthias Springer // Returns `vector<1xT>` if `oldType` only has one element.
trimLeadingOneDims(VectorType oldType)2499ef9eebSMatthias Springer static VectorType trimLeadingOneDims(VectorType oldType) {
2599ef9eebSMatthias Springer   ArrayRef<int64_t> oldShape = oldType.getShape();
2699ef9eebSMatthias Springer   ArrayRef<int64_t> newShape =
2799ef9eebSMatthias Springer       oldShape.drop_while([](int64_t dim) { return dim == 1; });
2899ef9eebSMatthias Springer   // Make sure we have at least 1 dimension per vector type requirements.
2999ef9eebSMatthias Springer   if (newShape.empty())
3099ef9eebSMatthias Springer     newShape = oldShape.take_back();
3199ef9eebSMatthias Springer   return VectorType::get(newShape, oldType.getElementType());
3299ef9eebSMatthias Springer }
3399ef9eebSMatthias Springer 
3499ef9eebSMatthias Springer /// Return a smallVector of size `rank` containing all zeros.
splatZero(int64_t rank)3599ef9eebSMatthias Springer static SmallVector<int64_t> splatZero(int64_t rank) {
3699ef9eebSMatthias Springer   return SmallVector<int64_t>(rank, 0);
3799ef9eebSMatthias Springer }
3899ef9eebSMatthias Springer namespace {
3999ef9eebSMatthias Springer 
4099ef9eebSMatthias Springer // Casts away leading one dimensions in vector.extract_strided_slice's vector
41e54236dfSLei Zhang // input by inserting vector.broadcast.
4299ef9eebSMatthias Springer struct CastAwayExtractStridedSliceLeadingOneDim
4399ef9eebSMatthias Springer     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
4499ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
4599ef9eebSMatthias Springer 
matchAndRewrite__anon384a66160211::CastAwayExtractStridedSliceLeadingOneDim4699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
4799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
4899ef9eebSMatthias Springer     // vector.extract_strided_slice requires the input and output vector to have
4999ef9eebSMatthias Springer     // the same rank. Here we drop leading one dimensions from the input vector
5099ef9eebSMatthias Springer     // type to make sure we don't cause mismatch.
5199ef9eebSMatthias Springer     VectorType oldSrcType = extractOp.getVectorType();
5299ef9eebSMatthias Springer     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
5399ef9eebSMatthias Springer 
5499ef9eebSMatthias Springer     if (newSrcType.getRank() == oldSrcType.getRank())
5599ef9eebSMatthias Springer       return failure();
5699ef9eebSMatthias Springer 
5799ef9eebSMatthias Springer     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
5899ef9eebSMatthias Springer 
5999ef9eebSMatthias Springer     VectorType oldDstType = extractOp.getType();
6099ef9eebSMatthias Springer     VectorType newDstType =
6199ef9eebSMatthias Springer         VectorType::get(oldDstType.getShape().drop_front(dropCount),
6299ef9eebSMatthias Springer                         oldDstType.getElementType());
6399ef9eebSMatthias Springer 
6499ef9eebSMatthias Springer     Location loc = extractOp.getLoc();
6599ef9eebSMatthias Springer 
6699ef9eebSMatthias Springer     Value newSrcVector = rewriter.create<vector::ExtractOp>(
677c38fd60SJacques Pienaar         loc, extractOp.getVector(), splatZero(dropCount));
6899ef9eebSMatthias Springer 
6999ef9eebSMatthias Springer     // The offsets/sizes/strides attribute can have a less number of elements
7099ef9eebSMatthias Springer     // than the input vector's rank: it is meant for the leading dimensions.
7199ef9eebSMatthias Springer     auto newOffsets = rewriter.getArrayAttr(
727c38fd60SJacques Pienaar         extractOp.getOffsets().getValue().drop_front(dropCount));
7399ef9eebSMatthias Springer     auto newSizes = rewriter.getArrayAttr(
747c38fd60SJacques Pienaar         extractOp.getSizes().getValue().drop_front(dropCount));
7599ef9eebSMatthias Springer     auto newStrides = rewriter.getArrayAttr(
767c38fd60SJacques Pienaar         extractOp.getStrides().getValue().drop_front(dropCount));
7799ef9eebSMatthias Springer 
7899ef9eebSMatthias Springer     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
7999ef9eebSMatthias Springer         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
8099ef9eebSMatthias Springer 
8199ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
8299ef9eebSMatthias Springer                                                      newExtractOp);
8399ef9eebSMatthias Springer 
8499ef9eebSMatthias Springer     return success();
8599ef9eebSMatthias Springer   }
8699ef9eebSMatthias Springer };
8799ef9eebSMatthias Springer 
88e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert_strided_slice's vector
89e54236dfSLei Zhang // inputs by inserting vector.broadcast.
9099ef9eebSMatthias Springer struct CastAwayInsertStridedSliceLeadingOneDim
9199ef9eebSMatthias Springer     : public OpRewritePattern<vector::InsertStridedSliceOp> {
9299ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
9399ef9eebSMatthias Springer 
matchAndRewrite__anon384a66160211::CastAwayInsertStridedSliceLeadingOneDim9499ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
9599ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
9699ef9eebSMatthias Springer     VectorType oldSrcType = insertOp.getSourceVectorType();
9799ef9eebSMatthias Springer     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
9899ef9eebSMatthias Springer     VectorType oldDstType = insertOp.getDestVectorType();
9999ef9eebSMatthias Springer     VectorType newDstType = trimLeadingOneDims(oldDstType);
10099ef9eebSMatthias Springer 
10199ef9eebSMatthias Springer     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
10299ef9eebSMatthias Springer     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
10399ef9eebSMatthias Springer     if (srcDropCount == 0 && dstDropCount == 0)
10499ef9eebSMatthias Springer       return failure();
10599ef9eebSMatthias Springer 
10699ef9eebSMatthias Springer     // Trim leading one dimensions from both operands.
10799ef9eebSMatthias Springer     Location loc = insertOp.getLoc();
10899ef9eebSMatthias Springer 
10999ef9eebSMatthias Springer     Value newSrcVector = rewriter.create<vector::ExtractOp>(
1107c38fd60SJacques Pienaar         loc, insertOp.getSource(), splatZero(srcDropCount));
11199ef9eebSMatthias Springer     Value newDstVector = rewriter.create<vector::ExtractOp>(
1127c38fd60SJacques Pienaar         loc, insertOp.getDest(), splatZero(dstDropCount));
11399ef9eebSMatthias Springer 
11499ef9eebSMatthias Springer     auto newOffsets = rewriter.getArrayAttr(
1157c38fd60SJacques Pienaar         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
11699ef9eebSMatthias Springer     auto newStrides = rewriter.getArrayAttr(
1177c38fd60SJacques Pienaar         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
11899ef9eebSMatthias Springer 
11999ef9eebSMatthias Springer     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
12099ef9eebSMatthias Springer         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
12199ef9eebSMatthias Springer 
12299ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
12399ef9eebSMatthias Springer                                                      newInsertOp);
12499ef9eebSMatthias Springer 
12599ef9eebSMatthias Springer     return success();
12699ef9eebSMatthias Springer   }
12799ef9eebSMatthias Springer };
12899ef9eebSMatthias Springer 
129e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert's vector inputs by
130e54236dfSLei Zhang // inserting vector.broadcast.
131e54236dfSLei Zhang struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132e54236dfSLei Zhang   using OpRewritePattern::OpRewritePattern;
133e54236dfSLei Zhang 
matchAndRewrite__anon384a66160211::CastAwayInsertLeadingOneDim134e54236dfSLei Zhang   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135e54236dfSLei Zhang                                 PatternRewriter &rewriter) const override {
136e54236dfSLei Zhang     Type oldSrcType = insertOp.getSourceType();
137e54236dfSLei Zhang     Type newSrcType = oldSrcType;
138e54236dfSLei Zhang     int64_t oldSrcRank = 0, newSrcRank = 0;
139e54236dfSLei Zhang     if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140e54236dfSLei Zhang       newSrcType = trimLeadingOneDims(type);
141e54236dfSLei Zhang       oldSrcRank = type.getRank();
142e54236dfSLei Zhang       newSrcRank = newSrcType.cast<VectorType>().getRank();
143e54236dfSLei Zhang     }
144e54236dfSLei Zhang 
145e54236dfSLei Zhang     VectorType oldDstType = insertOp.getDestVectorType();
146e54236dfSLei Zhang     VectorType newDstType = trimLeadingOneDims(oldDstType);
147e54236dfSLei Zhang 
148e54236dfSLei Zhang     int64_t srcDropCount = oldSrcRank - newSrcRank;
149e54236dfSLei Zhang     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150e54236dfSLei Zhang     if (srcDropCount == 0 && dstDropCount == 0)
151e54236dfSLei Zhang       return failure();
152e54236dfSLei Zhang 
153e54236dfSLei Zhang     // Trim leading one dimensions from both operands.
154e54236dfSLei Zhang     Location loc = insertOp.getLoc();
155e54236dfSLei Zhang 
156e54236dfSLei Zhang     Value newSrcVector = insertOp.getSource();
157e54236dfSLei Zhang     if (oldSrcRank != 0) {
158e54236dfSLei Zhang       newSrcVector = rewriter.create<vector::ExtractOp>(
159e54236dfSLei Zhang           loc, insertOp.getSource(), splatZero(srcDropCount));
160e54236dfSLei Zhang     }
161e54236dfSLei Zhang     Value newDstVector = rewriter.create<vector::ExtractOp>(
162e54236dfSLei Zhang         loc, insertOp.getDest(), splatZero(dstDropCount));
163e54236dfSLei Zhang 
164e54236dfSLei Zhang     unsigned oldPosRank = insertOp.getPosition().getValue().size();
165e54236dfSLei Zhang     unsigned newPosRank = newDstType.getRank() - newSrcRank;
166e54236dfSLei Zhang     SmallVector<Attribute> newPositions = llvm::to_vector(
167e54236dfSLei Zhang         insertOp.getPosition().getValue().take_back(newPosRank));
168e54236dfSLei Zhang     if (newPosRank > oldPosRank) {
169e54236dfSLei Zhang       auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170e54236dfSLei Zhang       newPositions.resize(newPosRank, zeroAttr);
171e54236dfSLei Zhang     }
172e54236dfSLei Zhang 
173e54236dfSLei Zhang     auto newInsertOp = rewriter.create<vector::InsertOp>(
174e54236dfSLei Zhang         loc, newDstType, newSrcVector, newDstVector,
175e54236dfSLei Zhang         rewriter.getArrayAttr(newPositions));
176e54236dfSLei Zhang 
177e54236dfSLei Zhang     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178e54236dfSLei Zhang                                                      newInsertOp);
179e54236dfSLei Zhang 
180e54236dfSLei Zhang     return success();
181e54236dfSLei Zhang   }
182e54236dfSLei Zhang };
183e54236dfSLei Zhang 
18499ef9eebSMatthias Springer // Turns vector.transfer_read on vector with leading 1 dimensions into
18599ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_read on vector without leading
18699ef9eebSMatthias Springer // 1 dimensions.
18799ef9eebSMatthias Springer struct CastAwayTransferReadLeadingOneDim
18899ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
18999ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
19099ef9eebSMatthias Springer 
matchAndRewrite__anon384a66160211::CastAwayTransferReadLeadingOneDim19199ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp read,
19299ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
19399ef9eebSMatthias Springer     // TODO: support 0-d corner case.
19499ef9eebSMatthias Springer     if (read.getTransferRank() == 0)
19599ef9eebSMatthias Springer       return failure();
19699ef9eebSMatthias Springer 
1977c38fd60SJacques Pienaar     if (read.getMask())
19899ef9eebSMatthias Springer       return failure();
19999ef9eebSMatthias Springer 
2007c38fd60SJacques Pienaar     auto shapedType = read.getSource().getType().cast<ShapedType>();
20199ef9eebSMatthias Springer     if (shapedType.getElementType() != read.getVectorType().getElementType())
20299ef9eebSMatthias Springer       return failure();
20399ef9eebSMatthias Springer 
20499ef9eebSMatthias Springer     VectorType oldType = read.getVectorType();
20599ef9eebSMatthias Springer     VectorType newType = trimLeadingOneDims(oldType);
20699ef9eebSMatthias Springer 
20799ef9eebSMatthias Springer     if (newType == oldType)
20899ef9eebSMatthias Springer       return failure();
20999ef9eebSMatthias Springer 
2107c38fd60SJacques Pienaar     AffineMap oldMap = read.getPermutationMap();
21199ef9eebSMatthias Springer     ArrayRef<AffineExpr> newResults =
21299ef9eebSMatthias Springer         oldMap.getResults().take_back(newType.getRank());
21399ef9eebSMatthias Springer     AffineMap newMap =
21499ef9eebSMatthias Springer         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
21599ef9eebSMatthias Springer                        rewriter.getContext());
21699ef9eebSMatthias Springer 
21799ef9eebSMatthias Springer     ArrayAttr inBoundsAttr;
2187c38fd60SJacques Pienaar     if (read.getInBounds())
21999ef9eebSMatthias Springer       inBoundsAttr = rewriter.getArrayAttr(
2207c38fd60SJacques Pienaar           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
22199ef9eebSMatthias Springer 
22299ef9eebSMatthias Springer     auto newRead = rewriter.create<vector::TransferReadOp>(
2237c38fd60SJacques Pienaar         read.getLoc(), newType, read.getSource(), read.getIndices(),
2247c38fd60SJacques Pienaar         AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
22599ef9eebSMatthias Springer         inBoundsAttr);
22699ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
22799ef9eebSMatthias Springer 
22899ef9eebSMatthias Springer     return success();
22999ef9eebSMatthias Springer   }
23099ef9eebSMatthias Springer };
23199ef9eebSMatthias Springer 
23299ef9eebSMatthias Springer // Turns vector.transfer_write on vector with leading 1 dimensions into
23399ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_write on vector without leading
23499ef9eebSMatthias Springer // 1 dimensions.
23599ef9eebSMatthias Springer struct CastAwayTransferWriteLeadingOneDim
23699ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
23799ef9eebSMatthias Springer   using OpRewritePattern::OpRewritePattern;
23899ef9eebSMatthias Springer 
matchAndRewrite__anon384a66160211::CastAwayTransferWriteLeadingOneDim23999ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
24099ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
24199ef9eebSMatthias Springer     // TODO: support 0-d corner case.
24299ef9eebSMatthias Springer     if (write.getTransferRank() == 0)
24399ef9eebSMatthias Springer       return failure();
24499ef9eebSMatthias Springer 
2457c38fd60SJacques Pienaar     if (write.getMask())
24699ef9eebSMatthias Springer       return failure();
24799ef9eebSMatthias Springer 
2487c38fd60SJacques Pienaar     auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
24999ef9eebSMatthias Springer     if (shapedType.getElementType() != write.getVectorType().getElementType())
25099ef9eebSMatthias Springer       return failure();
25199ef9eebSMatthias Springer 
25299ef9eebSMatthias Springer     VectorType oldType = write.getVectorType();
25399ef9eebSMatthias Springer     VectorType newType = trimLeadingOneDims(oldType);
25499ef9eebSMatthias Springer     if (newType == oldType)
25599ef9eebSMatthias Springer       return failure();
25699ef9eebSMatthias Springer     int64_t dropDim = oldType.getRank() - newType.getRank();
25799ef9eebSMatthias Springer 
2587c38fd60SJacques Pienaar     AffineMap oldMap = write.getPermutationMap();
25999ef9eebSMatthias Springer     ArrayRef<AffineExpr> newResults =
26099ef9eebSMatthias Springer         oldMap.getResults().take_back(newType.getRank());
26199ef9eebSMatthias Springer     AffineMap newMap =
26299ef9eebSMatthias Springer         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
26399ef9eebSMatthias Springer                        rewriter.getContext());
26499ef9eebSMatthias Springer 
26599ef9eebSMatthias Springer     ArrayAttr inBoundsAttr;
2667c38fd60SJacques Pienaar     if (write.getInBounds())
26799ef9eebSMatthias Springer       inBoundsAttr = rewriter.getArrayAttr(
2687c38fd60SJacques Pienaar           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
26999ef9eebSMatthias Springer 
27099ef9eebSMatthias Springer     auto newVector = rewriter.create<vector::ExtractOp>(
2717c38fd60SJacques Pienaar         write.getLoc(), write.getVector(), splatZero(dropDim));
27299ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2737c38fd60SJacques Pienaar         write, newVector, write.getSource(), write.getIndices(),
27499ef9eebSMatthias Springer         AffineMapAttr::get(newMap), inBoundsAttr);
27599ef9eebSMatthias Springer 
27699ef9eebSMatthias Springer     return success();
27799ef9eebSMatthias Springer   }
27899ef9eebSMatthias Springer };
27999ef9eebSMatthias Springer 
280ad9b5a4bSNirvedh /// Turns vector.contract on vector with leading 1 dimensions into
281ad9b5a4bSNirvedh /// vector.extract followed by vector.contract on vector without leading
282ad9b5a4bSNirvedh /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283ad9b5a4bSNirvedh /// prior to extract.
284ad9b5a4bSNirvedh struct CastAwayContractionLeadingOneDim
285ad9b5a4bSNirvedh     : public OpRewritePattern<vector::ContractionOp> {
286ad9b5a4bSNirvedh   using OpRewritePattern::OpRewritePattern;
287ad9b5a4bSNirvedh 
matchAndRewrite__anon384a66160211::CastAwayContractionLeadingOneDim288ad9b5a4bSNirvedh   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289ad9b5a4bSNirvedh                                 PatternRewriter &rewriter) const override {
290ad9b5a4bSNirvedh     VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291ad9b5a4bSNirvedh     if (oldAccType == nullptr)
292ad9b5a4bSNirvedh       return failure();
293ad9b5a4bSNirvedh     if (oldAccType.getRank() < 2)
294ad9b5a4bSNirvedh       return failure();
295ad9b5a4bSNirvedh     // TODO: implement masks.
2967c38fd60SJacques Pienaar     if (llvm::size(contractOp.getMasks()) != 0)
297ad9b5a4bSNirvedh       return failure();
298ad9b5a4bSNirvedh     if (oldAccType.getShape()[0] != 1)
299ad9b5a4bSNirvedh       return failure();
300ad9b5a4bSNirvedh     // currently we support only dropping one dim but the pattern can be applied
301ad9b5a4bSNirvedh     // greedily to drop more.
302ad9b5a4bSNirvedh     int64_t dropDim = 1;
303ad9b5a4bSNirvedh 
304*d2c0572bSJacques Pienaar     auto oldIndexingMaps = contractOp.getIndexingMapsArray();
305ad9b5a4bSNirvedh     SmallVector<AffineMap> newIndexingMaps;
306ad9b5a4bSNirvedh 
3077c38fd60SJacques Pienaar     auto oldIteratorTypes = contractOp.getIteratorTypes();
308ad9b5a4bSNirvedh     SmallVector<Attribute> newIteratorTypes;
309ad9b5a4bSNirvedh 
310ad9b5a4bSNirvedh     int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
311ad9b5a4bSNirvedh 
312ad9b5a4bSNirvedh     if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
313ad9b5a4bSNirvedh       // only parallel type iterators can be dropped.
314ad9b5a4bSNirvedh       return failure();
315ad9b5a4bSNirvedh 
316ad9b5a4bSNirvedh     for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
317ad9b5a4bSNirvedh       int64_t currDim = it.index();
318ad9b5a4bSNirvedh       if (currDim == dimToDrop)
319ad9b5a4bSNirvedh         continue;
320ad9b5a4bSNirvedh       newIteratorTypes.push_back(it.value());
321ad9b5a4bSNirvedh     }
322ad9b5a4bSNirvedh 
3237c38fd60SJacques Pienaar     SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
3247c38fd60SJacques Pienaar                                    contractOp.getAcc()};
325ad9b5a4bSNirvedh     SmallVector<Value> newOperands;
326ad9b5a4bSNirvedh 
327ad9b5a4bSNirvedh     for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
328ad9b5a4bSNirvedh       // Check if the dim to be dropped exists as a leading dim in the operand
329ad9b5a4bSNirvedh       // if it does then we use vector.extract to drop it.
330ad9b5a4bSNirvedh       bool validExtract = false;
331ad9b5a4bSNirvedh       SmallVector<AffineExpr> results;
332ad9b5a4bSNirvedh       auto map = it.value();
333ad9b5a4bSNirvedh       int64_t orginalZeroDim = it.value().getDimPosition(0);
334ad9b5a4bSNirvedh       if (orginalZeroDim != dimToDrop) {
335ad9b5a4bSNirvedh         // There are two reasons to be in this path, 1. We need to
336ad9b5a4bSNirvedh         // tranpose the operand to make the dim to be dropped
337ad9b5a4bSNirvedh         // leading. 2. The dim to be dropped does not exist and in
338ad9b5a4bSNirvedh         // that case we dont want to add a unit tranpose but we must
339ad9b5a4bSNirvedh         // check all the indices to make sure this is the case.
340ad9b5a4bSNirvedh         bool tranposeNeeded = false;
341ad9b5a4bSNirvedh         SmallVector<int64_t> perm;
342ad9b5a4bSNirvedh         SmallVector<AffineExpr> transposeResults;
343ad9b5a4bSNirvedh 
344ad9b5a4bSNirvedh         for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345ad9b5a4bSNirvedh           int64_t currDim = map.getDimPosition(i);
346ad9b5a4bSNirvedh           if (currDim == dimToDrop) {
347ad9b5a4bSNirvedh             tranposeNeeded = true;
348ad9b5a4bSNirvedh             perm.insert(perm.begin(), i);
349ad9b5a4bSNirvedh             auto targetExpr = rewriter.getAffineDimExpr(currDim);
350ad9b5a4bSNirvedh             transposeResults.insert(transposeResults.begin(), targetExpr);
351ad9b5a4bSNirvedh           } else {
352ad9b5a4bSNirvedh             perm.push_back(i);
353ad9b5a4bSNirvedh             auto targetExpr = rewriter.getAffineDimExpr(currDim);
354ad9b5a4bSNirvedh             transposeResults.push_back(targetExpr);
355ad9b5a4bSNirvedh           }
356ad9b5a4bSNirvedh         }
357ad9b5a4bSNirvedh         // Do the tranpose now if needed so that we can drop the
358ad9b5a4bSNirvedh         // correct dim using extract later.
359ad9b5a4bSNirvedh         if (tranposeNeeded) {
360ad9b5a4bSNirvedh           map = AffineMap::get(map.getNumDims(), 0, transposeResults,
361ad9b5a4bSNirvedh                                contractOp.getContext());
362ad9b5a4bSNirvedh           operands[it.index()] = rewriter.create<vector::TransposeOp>(
363ad9b5a4bSNirvedh               contractOp.getLoc(), operands[it.index()], perm);
364ad9b5a4bSNirvedh         }
365ad9b5a4bSNirvedh       }
366ad9b5a4bSNirvedh       // We have taken care to have the dim to be dropped be
367ad9b5a4bSNirvedh       // the leading dim. If its still not leading that means it
368ad9b5a4bSNirvedh       // does not exist in this operand and hence we do not need
369ad9b5a4bSNirvedh       // an extract.
370ad9b5a4bSNirvedh       if (map.getDimPosition(0) == dimToDrop)
371ad9b5a4bSNirvedh         validExtract = true;
372ad9b5a4bSNirvedh 
373ad9b5a4bSNirvedh       for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374ad9b5a4bSNirvedh         int64_t currDim = map.getDimPosition(i);
375ad9b5a4bSNirvedh         if (currDim == dimToDrop)
376ad9b5a4bSNirvedh           // This is the dim we are dropping.
377ad9b5a4bSNirvedh           continue;
378ad9b5a4bSNirvedh         auto targetExpr = rewriter.getAffineDimExpr(
379ad9b5a4bSNirvedh             currDim < dimToDrop ? currDim : currDim - 1);
380ad9b5a4bSNirvedh         results.push_back(targetExpr);
381ad9b5a4bSNirvedh       }
382ad9b5a4bSNirvedh       newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
383ad9b5a4bSNirvedh                                                contractOp.getContext()));
384ad9b5a4bSNirvedh       // Extract if its a valid extraction, otherwise use the operand
385ad9b5a4bSNirvedh       // without extraction.
386ad9b5a4bSNirvedh       newOperands.push_back(validExtract
387ad9b5a4bSNirvedh                                 ? rewriter.create<vector::ExtractOp>(
388ad9b5a4bSNirvedh                                       contractOp.getLoc(), operands[it.index()],
389ad9b5a4bSNirvedh                                       splatZero(dropDim))
390ad9b5a4bSNirvedh                                 : operands[it.index()]);
391ad9b5a4bSNirvedh     }
392ad9b5a4bSNirvedh     auto newContractOp = rewriter.create<vector::ContractionOp>(
393ad9b5a4bSNirvedh         contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
394ad9b5a4bSNirvedh         rewriter.getAffineMapArrayAttr(newIndexingMaps),
3957c38fd60SJacques Pienaar         rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
396ad9b5a4bSNirvedh     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
397ad9b5a4bSNirvedh         contractOp, contractOp->getResultTypes()[0], newContractOp);
398ad9b5a4bSNirvedh     return success();
399ad9b5a4bSNirvedh   }
400ad9b5a4bSNirvedh };
401ad9b5a4bSNirvedh 
40299ef9eebSMatthias Springer class CastAwayElementwiseLeadingOneDim : public RewritePattern {
40399ef9eebSMatthias Springer public:
CastAwayElementwiseLeadingOneDim(MLIRContext * context)40499ef9eebSMatthias Springer   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
40599ef9eebSMatthias Springer       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
40699ef9eebSMatthias Springer 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const40799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
40899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
40999ef9eebSMatthias Springer     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
41099ef9eebSMatthias Springer       return failure();
41199ef9eebSMatthias Springer     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
41299ef9eebSMatthias Springer     if (!vecType)
41399ef9eebSMatthias Springer       return failure();
41499ef9eebSMatthias Springer     VectorType newVecType = trimLeadingOneDims(vecType);
41599ef9eebSMatthias Springer     if (newVecType == vecType)
41699ef9eebSMatthias Springer       return failure();
41799ef9eebSMatthias Springer     int64_t dropDim = vecType.getRank() - newVecType.getRank();
41899ef9eebSMatthias Springer     SmallVector<Value, 4> newOperands;
41999ef9eebSMatthias Springer     for (Value operand : op->getOperands()) {
42099ef9eebSMatthias Springer       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
42199ef9eebSMatthias Springer         newOperands.push_back(rewriter.create<vector::ExtractOp>(
42299ef9eebSMatthias Springer             op->getLoc(), operand, splatZero(dropDim)));
42399ef9eebSMatthias Springer       } else {
42499ef9eebSMatthias Springer         newOperands.push_back(operand);
42599ef9eebSMatthias Springer       }
42699ef9eebSMatthias Springer     }
42714ecafd0SChia-hung Duan     Operation *newOp =
42814ecafd0SChia-hung Duan         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
42914ecafd0SChia-hung Duan                         newOperands, newVecType, op->getAttrs());
43099ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
43199ef9eebSMatthias Springer                                                      newOp->getResult(0));
43299ef9eebSMatthias Springer     return success();
43399ef9eebSMatthias Springer   }
43499ef9eebSMatthias Springer };
43599ef9eebSMatthias Springer 
43699ef9eebSMatthias Springer } // namespace
43799ef9eebSMatthias Springer 
populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet & patterns)43899ef9eebSMatthias Springer void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
43999ef9eebSMatthias Springer     RewritePatternSet &patterns) {
440ad9b5a4bSNirvedh   patterns
441ad9b5a4bSNirvedh       .add<CastAwayExtractStridedSliceLeadingOneDim,
442e54236dfSLei Zhang            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
44399ef9eebSMatthias Springer            CastAwayTransferReadLeadingOneDim,
444ad9b5a4bSNirvedh            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
445ad9b5a4bSNirvedh            CastAwayContractionLeadingOneDim>(patterns.getContext());
44699ef9eebSMatthias Springer   populateShapeCastFoldingPatterns(patterns);
44799ef9eebSMatthias Springer }
448