199ef9eebSMatthias Springer //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer
9ad9b5a4bSNirvedh #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10e54236dfSLei Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h"
1199ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1399ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
1499ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
1599ef9eebSMatthias Springer #include "mlir/IR/TypeUtilities.h"
1699ef9eebSMatthias Springer
1799ef9eebSMatthias Springer #define DEBUG_TYPE "vector-drop-unit-dim"
1899ef9eebSMatthias Springer
1999ef9eebSMatthias Springer using namespace mlir;
2099ef9eebSMatthias Springer using namespace mlir::vector;
2199ef9eebSMatthias Springer
2299ef9eebSMatthias Springer // Trims leading one dimensions from `oldType` and returns the result type.
2399ef9eebSMatthias Springer // Returns `vector<1xT>` if `oldType` only has one element.
trimLeadingOneDims(VectorType oldType)2499ef9eebSMatthias Springer static VectorType trimLeadingOneDims(VectorType oldType) {
2599ef9eebSMatthias Springer ArrayRef<int64_t> oldShape = oldType.getShape();
2699ef9eebSMatthias Springer ArrayRef<int64_t> newShape =
2799ef9eebSMatthias Springer oldShape.drop_while([](int64_t dim) { return dim == 1; });
2899ef9eebSMatthias Springer // Make sure we have at least 1 dimension per vector type requirements.
2999ef9eebSMatthias Springer if (newShape.empty())
3099ef9eebSMatthias Springer newShape = oldShape.take_back();
3199ef9eebSMatthias Springer return VectorType::get(newShape, oldType.getElementType());
3299ef9eebSMatthias Springer }
3399ef9eebSMatthias Springer
3499ef9eebSMatthias Springer /// Return a smallVector of size `rank` containing all zeros.
splatZero(int64_t rank)3599ef9eebSMatthias Springer static SmallVector<int64_t> splatZero(int64_t rank) {
3699ef9eebSMatthias Springer return SmallVector<int64_t>(rank, 0);
3799ef9eebSMatthias Springer }
3899ef9eebSMatthias Springer namespace {
3999ef9eebSMatthias Springer
4099ef9eebSMatthias Springer // Casts away leading one dimensions in vector.extract_strided_slice's vector
41e54236dfSLei Zhang // input by inserting vector.broadcast.
4299ef9eebSMatthias Springer struct CastAwayExtractStridedSliceLeadingOneDim
4399ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractStridedSliceOp> {
4499ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
4599ef9eebSMatthias Springer
matchAndRewrite__anon384a66160211::CastAwayExtractStridedSliceLeadingOneDim4699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
4799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
4899ef9eebSMatthias Springer // vector.extract_strided_slice requires the input and output vector to have
4999ef9eebSMatthias Springer // the same rank. Here we drop leading one dimensions from the input vector
5099ef9eebSMatthias Springer // type to make sure we don't cause mismatch.
5199ef9eebSMatthias Springer VectorType oldSrcType = extractOp.getVectorType();
5299ef9eebSMatthias Springer VectorType newSrcType = trimLeadingOneDims(oldSrcType);
5399ef9eebSMatthias Springer
5499ef9eebSMatthias Springer if (newSrcType.getRank() == oldSrcType.getRank())
5599ef9eebSMatthias Springer return failure();
5699ef9eebSMatthias Springer
5799ef9eebSMatthias Springer int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
5899ef9eebSMatthias Springer
5999ef9eebSMatthias Springer VectorType oldDstType = extractOp.getType();
6099ef9eebSMatthias Springer VectorType newDstType =
6199ef9eebSMatthias Springer VectorType::get(oldDstType.getShape().drop_front(dropCount),
6299ef9eebSMatthias Springer oldDstType.getElementType());
6399ef9eebSMatthias Springer
6499ef9eebSMatthias Springer Location loc = extractOp.getLoc();
6599ef9eebSMatthias Springer
6699ef9eebSMatthias Springer Value newSrcVector = rewriter.create<vector::ExtractOp>(
677c38fd60SJacques Pienaar loc, extractOp.getVector(), splatZero(dropCount));
6899ef9eebSMatthias Springer
6999ef9eebSMatthias Springer // The offsets/sizes/strides attribute can have a less number of elements
7099ef9eebSMatthias Springer // than the input vector's rank: it is meant for the leading dimensions.
7199ef9eebSMatthias Springer auto newOffsets = rewriter.getArrayAttr(
727c38fd60SJacques Pienaar extractOp.getOffsets().getValue().drop_front(dropCount));
7399ef9eebSMatthias Springer auto newSizes = rewriter.getArrayAttr(
747c38fd60SJacques Pienaar extractOp.getSizes().getValue().drop_front(dropCount));
7599ef9eebSMatthias Springer auto newStrides = rewriter.getArrayAttr(
767c38fd60SJacques Pienaar extractOp.getStrides().getValue().drop_front(dropCount));
7799ef9eebSMatthias Springer
7899ef9eebSMatthias Springer auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
7999ef9eebSMatthias Springer loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
8099ef9eebSMatthias Springer
8199ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
8299ef9eebSMatthias Springer newExtractOp);
8399ef9eebSMatthias Springer
8499ef9eebSMatthias Springer return success();
8599ef9eebSMatthias Springer }
8699ef9eebSMatthias Springer };
8799ef9eebSMatthias Springer
88e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert_strided_slice's vector
89e54236dfSLei Zhang // inputs by inserting vector.broadcast.
9099ef9eebSMatthias Springer struct CastAwayInsertStridedSliceLeadingOneDim
9199ef9eebSMatthias Springer : public OpRewritePattern<vector::InsertStridedSliceOp> {
9299ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
9399ef9eebSMatthias Springer
matchAndRewrite__anon384a66160211::CastAwayInsertStridedSliceLeadingOneDim9499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
9599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
9699ef9eebSMatthias Springer VectorType oldSrcType = insertOp.getSourceVectorType();
9799ef9eebSMatthias Springer VectorType newSrcType = trimLeadingOneDims(oldSrcType);
9899ef9eebSMatthias Springer VectorType oldDstType = insertOp.getDestVectorType();
9999ef9eebSMatthias Springer VectorType newDstType = trimLeadingOneDims(oldDstType);
10099ef9eebSMatthias Springer
10199ef9eebSMatthias Springer int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
10299ef9eebSMatthias Springer int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
10399ef9eebSMatthias Springer if (srcDropCount == 0 && dstDropCount == 0)
10499ef9eebSMatthias Springer return failure();
10599ef9eebSMatthias Springer
10699ef9eebSMatthias Springer // Trim leading one dimensions from both operands.
10799ef9eebSMatthias Springer Location loc = insertOp.getLoc();
10899ef9eebSMatthias Springer
10999ef9eebSMatthias Springer Value newSrcVector = rewriter.create<vector::ExtractOp>(
1107c38fd60SJacques Pienaar loc, insertOp.getSource(), splatZero(srcDropCount));
11199ef9eebSMatthias Springer Value newDstVector = rewriter.create<vector::ExtractOp>(
1127c38fd60SJacques Pienaar loc, insertOp.getDest(), splatZero(dstDropCount));
11399ef9eebSMatthias Springer
11499ef9eebSMatthias Springer auto newOffsets = rewriter.getArrayAttr(
1157c38fd60SJacques Pienaar insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
11699ef9eebSMatthias Springer auto newStrides = rewriter.getArrayAttr(
1177c38fd60SJacques Pienaar insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
11899ef9eebSMatthias Springer
11999ef9eebSMatthias Springer auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
12099ef9eebSMatthias Springer loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
12199ef9eebSMatthias Springer
12299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
12399ef9eebSMatthias Springer newInsertOp);
12499ef9eebSMatthias Springer
12599ef9eebSMatthias Springer return success();
12699ef9eebSMatthias Springer }
12799ef9eebSMatthias Springer };
12899ef9eebSMatthias Springer
129e54236dfSLei Zhang // Casts away leading one dimensions in vector.insert's vector inputs by
130e54236dfSLei Zhang // inserting vector.broadcast.
131e54236dfSLei Zhang struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132e54236dfSLei Zhang using OpRewritePattern::OpRewritePattern;
133e54236dfSLei Zhang
matchAndRewrite__anon384a66160211::CastAwayInsertLeadingOneDim134e54236dfSLei Zhang LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135e54236dfSLei Zhang PatternRewriter &rewriter) const override {
136e54236dfSLei Zhang Type oldSrcType = insertOp.getSourceType();
137e54236dfSLei Zhang Type newSrcType = oldSrcType;
138e54236dfSLei Zhang int64_t oldSrcRank = 0, newSrcRank = 0;
139e54236dfSLei Zhang if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140e54236dfSLei Zhang newSrcType = trimLeadingOneDims(type);
141e54236dfSLei Zhang oldSrcRank = type.getRank();
142e54236dfSLei Zhang newSrcRank = newSrcType.cast<VectorType>().getRank();
143e54236dfSLei Zhang }
144e54236dfSLei Zhang
145e54236dfSLei Zhang VectorType oldDstType = insertOp.getDestVectorType();
146e54236dfSLei Zhang VectorType newDstType = trimLeadingOneDims(oldDstType);
147e54236dfSLei Zhang
148e54236dfSLei Zhang int64_t srcDropCount = oldSrcRank - newSrcRank;
149e54236dfSLei Zhang int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150e54236dfSLei Zhang if (srcDropCount == 0 && dstDropCount == 0)
151e54236dfSLei Zhang return failure();
152e54236dfSLei Zhang
153e54236dfSLei Zhang // Trim leading one dimensions from both operands.
154e54236dfSLei Zhang Location loc = insertOp.getLoc();
155e54236dfSLei Zhang
156e54236dfSLei Zhang Value newSrcVector = insertOp.getSource();
157e54236dfSLei Zhang if (oldSrcRank != 0) {
158e54236dfSLei Zhang newSrcVector = rewriter.create<vector::ExtractOp>(
159e54236dfSLei Zhang loc, insertOp.getSource(), splatZero(srcDropCount));
160e54236dfSLei Zhang }
161e54236dfSLei Zhang Value newDstVector = rewriter.create<vector::ExtractOp>(
162e54236dfSLei Zhang loc, insertOp.getDest(), splatZero(dstDropCount));
163e54236dfSLei Zhang
164e54236dfSLei Zhang unsigned oldPosRank = insertOp.getPosition().getValue().size();
165e54236dfSLei Zhang unsigned newPosRank = newDstType.getRank() - newSrcRank;
166e54236dfSLei Zhang SmallVector<Attribute> newPositions = llvm::to_vector(
167e54236dfSLei Zhang insertOp.getPosition().getValue().take_back(newPosRank));
168e54236dfSLei Zhang if (newPosRank > oldPosRank) {
169e54236dfSLei Zhang auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170e54236dfSLei Zhang newPositions.resize(newPosRank, zeroAttr);
171e54236dfSLei Zhang }
172e54236dfSLei Zhang
173e54236dfSLei Zhang auto newInsertOp = rewriter.create<vector::InsertOp>(
174e54236dfSLei Zhang loc, newDstType, newSrcVector, newDstVector,
175e54236dfSLei Zhang rewriter.getArrayAttr(newPositions));
176e54236dfSLei Zhang
177e54236dfSLei Zhang rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178e54236dfSLei Zhang newInsertOp);
179e54236dfSLei Zhang
180e54236dfSLei Zhang return success();
181e54236dfSLei Zhang }
182e54236dfSLei Zhang };
183e54236dfSLei Zhang
18499ef9eebSMatthias Springer // Turns vector.transfer_read on vector with leading 1 dimensions into
18599ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_read on vector without leading
18699ef9eebSMatthias Springer // 1 dimensions.
18799ef9eebSMatthias Springer struct CastAwayTransferReadLeadingOneDim
18899ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> {
18999ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
19099ef9eebSMatthias Springer
matchAndRewrite__anon384a66160211::CastAwayTransferReadLeadingOneDim19199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp read,
19299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
19399ef9eebSMatthias Springer // TODO: support 0-d corner case.
19499ef9eebSMatthias Springer if (read.getTransferRank() == 0)
19599ef9eebSMatthias Springer return failure();
19699ef9eebSMatthias Springer
1977c38fd60SJacques Pienaar if (read.getMask())
19899ef9eebSMatthias Springer return failure();
19999ef9eebSMatthias Springer
2007c38fd60SJacques Pienaar auto shapedType = read.getSource().getType().cast<ShapedType>();
20199ef9eebSMatthias Springer if (shapedType.getElementType() != read.getVectorType().getElementType())
20299ef9eebSMatthias Springer return failure();
20399ef9eebSMatthias Springer
20499ef9eebSMatthias Springer VectorType oldType = read.getVectorType();
20599ef9eebSMatthias Springer VectorType newType = trimLeadingOneDims(oldType);
20699ef9eebSMatthias Springer
20799ef9eebSMatthias Springer if (newType == oldType)
20899ef9eebSMatthias Springer return failure();
20999ef9eebSMatthias Springer
2107c38fd60SJacques Pienaar AffineMap oldMap = read.getPermutationMap();
21199ef9eebSMatthias Springer ArrayRef<AffineExpr> newResults =
21299ef9eebSMatthias Springer oldMap.getResults().take_back(newType.getRank());
21399ef9eebSMatthias Springer AffineMap newMap =
21499ef9eebSMatthias Springer AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
21599ef9eebSMatthias Springer rewriter.getContext());
21699ef9eebSMatthias Springer
21799ef9eebSMatthias Springer ArrayAttr inBoundsAttr;
2187c38fd60SJacques Pienaar if (read.getInBounds())
21999ef9eebSMatthias Springer inBoundsAttr = rewriter.getArrayAttr(
2207c38fd60SJacques Pienaar read.getInBoundsAttr().getValue().take_back(newType.getRank()));
22199ef9eebSMatthias Springer
22299ef9eebSMatthias Springer auto newRead = rewriter.create<vector::TransferReadOp>(
2237c38fd60SJacques Pienaar read.getLoc(), newType, read.getSource(), read.getIndices(),
2247c38fd60SJacques Pienaar AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
22599ef9eebSMatthias Springer inBoundsAttr);
22699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
22799ef9eebSMatthias Springer
22899ef9eebSMatthias Springer return success();
22999ef9eebSMatthias Springer }
23099ef9eebSMatthias Springer };
23199ef9eebSMatthias Springer
23299ef9eebSMatthias Springer // Turns vector.transfer_write on vector with leading 1 dimensions into
23399ef9eebSMatthias Springer // vector.shape_cast followed by vector.transfer_write on vector without leading
23499ef9eebSMatthias Springer // 1 dimensions.
23599ef9eebSMatthias Springer struct CastAwayTransferWriteLeadingOneDim
23699ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
23799ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
23899ef9eebSMatthias Springer
matchAndRewrite__anon384a66160211::CastAwayTransferWriteLeadingOneDim23999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp write,
24099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
24199ef9eebSMatthias Springer // TODO: support 0-d corner case.
24299ef9eebSMatthias Springer if (write.getTransferRank() == 0)
24399ef9eebSMatthias Springer return failure();
24499ef9eebSMatthias Springer
2457c38fd60SJacques Pienaar if (write.getMask())
24699ef9eebSMatthias Springer return failure();
24799ef9eebSMatthias Springer
2487c38fd60SJacques Pienaar auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
24999ef9eebSMatthias Springer if (shapedType.getElementType() != write.getVectorType().getElementType())
25099ef9eebSMatthias Springer return failure();
25199ef9eebSMatthias Springer
25299ef9eebSMatthias Springer VectorType oldType = write.getVectorType();
25399ef9eebSMatthias Springer VectorType newType = trimLeadingOneDims(oldType);
25499ef9eebSMatthias Springer if (newType == oldType)
25599ef9eebSMatthias Springer return failure();
25699ef9eebSMatthias Springer int64_t dropDim = oldType.getRank() - newType.getRank();
25799ef9eebSMatthias Springer
2587c38fd60SJacques Pienaar AffineMap oldMap = write.getPermutationMap();
25999ef9eebSMatthias Springer ArrayRef<AffineExpr> newResults =
26099ef9eebSMatthias Springer oldMap.getResults().take_back(newType.getRank());
26199ef9eebSMatthias Springer AffineMap newMap =
26299ef9eebSMatthias Springer AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
26399ef9eebSMatthias Springer rewriter.getContext());
26499ef9eebSMatthias Springer
26599ef9eebSMatthias Springer ArrayAttr inBoundsAttr;
2667c38fd60SJacques Pienaar if (write.getInBounds())
26799ef9eebSMatthias Springer inBoundsAttr = rewriter.getArrayAttr(
2687c38fd60SJacques Pienaar write.getInBoundsAttr().getValue().take_back(newType.getRank()));
26999ef9eebSMatthias Springer
27099ef9eebSMatthias Springer auto newVector = rewriter.create<vector::ExtractOp>(
2717c38fd60SJacques Pienaar write.getLoc(), write.getVector(), splatZero(dropDim));
27299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2737c38fd60SJacques Pienaar write, newVector, write.getSource(), write.getIndices(),
27499ef9eebSMatthias Springer AffineMapAttr::get(newMap), inBoundsAttr);
27599ef9eebSMatthias Springer
27699ef9eebSMatthias Springer return success();
27799ef9eebSMatthias Springer }
27899ef9eebSMatthias Springer };
27999ef9eebSMatthias Springer
280ad9b5a4bSNirvedh /// Turns vector.contract on vector with leading 1 dimensions into
281ad9b5a4bSNirvedh /// vector.extract followed by vector.contract on vector without leading
282ad9b5a4bSNirvedh /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283ad9b5a4bSNirvedh /// prior to extract.
284ad9b5a4bSNirvedh struct CastAwayContractionLeadingOneDim
285ad9b5a4bSNirvedh : public OpRewritePattern<vector::ContractionOp> {
286ad9b5a4bSNirvedh using OpRewritePattern::OpRewritePattern;
287ad9b5a4bSNirvedh
matchAndRewrite__anon384a66160211::CastAwayContractionLeadingOneDim288ad9b5a4bSNirvedh LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289ad9b5a4bSNirvedh PatternRewriter &rewriter) const override {
290ad9b5a4bSNirvedh VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291ad9b5a4bSNirvedh if (oldAccType == nullptr)
292ad9b5a4bSNirvedh return failure();
293ad9b5a4bSNirvedh if (oldAccType.getRank() < 2)
294ad9b5a4bSNirvedh return failure();
295ad9b5a4bSNirvedh // TODO: implement masks.
2967c38fd60SJacques Pienaar if (llvm::size(contractOp.getMasks()) != 0)
297ad9b5a4bSNirvedh return failure();
298ad9b5a4bSNirvedh if (oldAccType.getShape()[0] != 1)
299ad9b5a4bSNirvedh return failure();
300ad9b5a4bSNirvedh // currently we support only dropping one dim but the pattern can be applied
301ad9b5a4bSNirvedh // greedily to drop more.
302ad9b5a4bSNirvedh int64_t dropDim = 1;
303ad9b5a4bSNirvedh
304*d2c0572bSJacques Pienaar auto oldIndexingMaps = contractOp.getIndexingMapsArray();
305ad9b5a4bSNirvedh SmallVector<AffineMap> newIndexingMaps;
306ad9b5a4bSNirvedh
3077c38fd60SJacques Pienaar auto oldIteratorTypes = contractOp.getIteratorTypes();
308ad9b5a4bSNirvedh SmallVector<Attribute> newIteratorTypes;
309ad9b5a4bSNirvedh
310ad9b5a4bSNirvedh int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
311ad9b5a4bSNirvedh
312ad9b5a4bSNirvedh if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
313ad9b5a4bSNirvedh // only parallel type iterators can be dropped.
314ad9b5a4bSNirvedh return failure();
315ad9b5a4bSNirvedh
316ad9b5a4bSNirvedh for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
317ad9b5a4bSNirvedh int64_t currDim = it.index();
318ad9b5a4bSNirvedh if (currDim == dimToDrop)
319ad9b5a4bSNirvedh continue;
320ad9b5a4bSNirvedh newIteratorTypes.push_back(it.value());
321ad9b5a4bSNirvedh }
322ad9b5a4bSNirvedh
3237c38fd60SJacques Pienaar SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
3247c38fd60SJacques Pienaar contractOp.getAcc()};
325ad9b5a4bSNirvedh SmallVector<Value> newOperands;
326ad9b5a4bSNirvedh
327ad9b5a4bSNirvedh for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
328ad9b5a4bSNirvedh // Check if the dim to be dropped exists as a leading dim in the operand
329ad9b5a4bSNirvedh // if it does then we use vector.extract to drop it.
330ad9b5a4bSNirvedh bool validExtract = false;
331ad9b5a4bSNirvedh SmallVector<AffineExpr> results;
332ad9b5a4bSNirvedh auto map = it.value();
333ad9b5a4bSNirvedh int64_t orginalZeroDim = it.value().getDimPosition(0);
334ad9b5a4bSNirvedh if (orginalZeroDim != dimToDrop) {
335ad9b5a4bSNirvedh // There are two reasons to be in this path, 1. We need to
336ad9b5a4bSNirvedh // tranpose the operand to make the dim to be dropped
337ad9b5a4bSNirvedh // leading. 2. The dim to be dropped does not exist and in
338ad9b5a4bSNirvedh // that case we dont want to add a unit tranpose but we must
339ad9b5a4bSNirvedh // check all the indices to make sure this is the case.
340ad9b5a4bSNirvedh bool tranposeNeeded = false;
341ad9b5a4bSNirvedh SmallVector<int64_t> perm;
342ad9b5a4bSNirvedh SmallVector<AffineExpr> transposeResults;
343ad9b5a4bSNirvedh
344ad9b5a4bSNirvedh for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345ad9b5a4bSNirvedh int64_t currDim = map.getDimPosition(i);
346ad9b5a4bSNirvedh if (currDim == dimToDrop) {
347ad9b5a4bSNirvedh tranposeNeeded = true;
348ad9b5a4bSNirvedh perm.insert(perm.begin(), i);
349ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr(currDim);
350ad9b5a4bSNirvedh transposeResults.insert(transposeResults.begin(), targetExpr);
351ad9b5a4bSNirvedh } else {
352ad9b5a4bSNirvedh perm.push_back(i);
353ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr(currDim);
354ad9b5a4bSNirvedh transposeResults.push_back(targetExpr);
355ad9b5a4bSNirvedh }
356ad9b5a4bSNirvedh }
357ad9b5a4bSNirvedh // Do the tranpose now if needed so that we can drop the
358ad9b5a4bSNirvedh // correct dim using extract later.
359ad9b5a4bSNirvedh if (tranposeNeeded) {
360ad9b5a4bSNirvedh map = AffineMap::get(map.getNumDims(), 0, transposeResults,
361ad9b5a4bSNirvedh contractOp.getContext());
362ad9b5a4bSNirvedh operands[it.index()] = rewriter.create<vector::TransposeOp>(
363ad9b5a4bSNirvedh contractOp.getLoc(), operands[it.index()], perm);
364ad9b5a4bSNirvedh }
365ad9b5a4bSNirvedh }
366ad9b5a4bSNirvedh // We have taken care to have the dim to be dropped be
367ad9b5a4bSNirvedh // the leading dim. If its still not leading that means it
368ad9b5a4bSNirvedh // does not exist in this operand and hence we do not need
369ad9b5a4bSNirvedh // an extract.
370ad9b5a4bSNirvedh if (map.getDimPosition(0) == dimToDrop)
371ad9b5a4bSNirvedh validExtract = true;
372ad9b5a4bSNirvedh
373ad9b5a4bSNirvedh for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374ad9b5a4bSNirvedh int64_t currDim = map.getDimPosition(i);
375ad9b5a4bSNirvedh if (currDim == dimToDrop)
376ad9b5a4bSNirvedh // This is the dim we are dropping.
377ad9b5a4bSNirvedh continue;
378ad9b5a4bSNirvedh auto targetExpr = rewriter.getAffineDimExpr(
379ad9b5a4bSNirvedh currDim < dimToDrop ? currDim : currDim - 1);
380ad9b5a4bSNirvedh results.push_back(targetExpr);
381ad9b5a4bSNirvedh }
382ad9b5a4bSNirvedh newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
383ad9b5a4bSNirvedh contractOp.getContext()));
384ad9b5a4bSNirvedh // Extract if its a valid extraction, otherwise use the operand
385ad9b5a4bSNirvedh // without extraction.
386ad9b5a4bSNirvedh newOperands.push_back(validExtract
387ad9b5a4bSNirvedh ? rewriter.create<vector::ExtractOp>(
388ad9b5a4bSNirvedh contractOp.getLoc(), operands[it.index()],
389ad9b5a4bSNirvedh splatZero(dropDim))
390ad9b5a4bSNirvedh : operands[it.index()]);
391ad9b5a4bSNirvedh }
392ad9b5a4bSNirvedh auto newContractOp = rewriter.create<vector::ContractionOp>(
393ad9b5a4bSNirvedh contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
394ad9b5a4bSNirvedh rewriter.getAffineMapArrayAttr(newIndexingMaps),
3957c38fd60SJacques Pienaar rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
396ad9b5a4bSNirvedh rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
397ad9b5a4bSNirvedh contractOp, contractOp->getResultTypes()[0], newContractOp);
398ad9b5a4bSNirvedh return success();
399ad9b5a4bSNirvedh }
400ad9b5a4bSNirvedh };
401ad9b5a4bSNirvedh
40299ef9eebSMatthias Springer class CastAwayElementwiseLeadingOneDim : public RewritePattern {
40399ef9eebSMatthias Springer public:
CastAwayElementwiseLeadingOneDim(MLIRContext * context)40499ef9eebSMatthias Springer CastAwayElementwiseLeadingOneDim(MLIRContext *context)
40599ef9eebSMatthias Springer : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
40699ef9eebSMatthias Springer
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const40799ef9eebSMatthias Springer LogicalResult matchAndRewrite(Operation *op,
40899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
40999ef9eebSMatthias Springer if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
41099ef9eebSMatthias Springer return failure();
41199ef9eebSMatthias Springer auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
41299ef9eebSMatthias Springer if (!vecType)
41399ef9eebSMatthias Springer return failure();
41499ef9eebSMatthias Springer VectorType newVecType = trimLeadingOneDims(vecType);
41599ef9eebSMatthias Springer if (newVecType == vecType)
41699ef9eebSMatthias Springer return failure();
41799ef9eebSMatthias Springer int64_t dropDim = vecType.getRank() - newVecType.getRank();
41899ef9eebSMatthias Springer SmallVector<Value, 4> newOperands;
41999ef9eebSMatthias Springer for (Value operand : op->getOperands()) {
42099ef9eebSMatthias Springer if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
42199ef9eebSMatthias Springer newOperands.push_back(rewriter.create<vector::ExtractOp>(
42299ef9eebSMatthias Springer op->getLoc(), operand, splatZero(dropDim)));
42399ef9eebSMatthias Springer } else {
42499ef9eebSMatthias Springer newOperands.push_back(operand);
42599ef9eebSMatthias Springer }
42699ef9eebSMatthias Springer }
42714ecafd0SChia-hung Duan Operation *newOp =
42814ecafd0SChia-hung Duan rewriter.create(op->getLoc(), op->getName().getIdentifier(),
42914ecafd0SChia-hung Duan newOperands, newVecType, op->getAttrs());
43099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
43199ef9eebSMatthias Springer newOp->getResult(0));
43299ef9eebSMatthias Springer return success();
43399ef9eebSMatthias Springer }
43499ef9eebSMatthias Springer };
43599ef9eebSMatthias Springer
43699ef9eebSMatthias Springer } // namespace
43799ef9eebSMatthias Springer
populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet & patterns)43899ef9eebSMatthias Springer void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
43999ef9eebSMatthias Springer RewritePatternSet &patterns) {
440ad9b5a4bSNirvedh patterns
441ad9b5a4bSNirvedh .add<CastAwayExtractStridedSliceLeadingOneDim,
442e54236dfSLei Zhang CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
44399ef9eebSMatthias Springer CastAwayTransferReadLeadingOneDim,
444ad9b5a4bSNirvedh CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
445ad9b5a4bSNirvedh CastAwayContractionLeadingOneDim>(patterns.getContext());
44699ef9eebSMatthias Springer populateShapeCastFoldingPatterns(patterns);
44799ef9eebSMatthias Springer }
448