//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//
1399ef9eebSMatthias Springer
1499ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1899ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
1999ef9eebSMatthias Springer
2099ef9eebSMatthias Springer using namespace mlir;
2199ef9eebSMatthias Springer using namespace mlir::vector;
2299ef9eebSMatthias Springer
2399ef9eebSMatthias Springer /// Transpose a vector transfer op's `in_bounds` attribute according to given
2499ef9eebSMatthias Springer /// indices.
2599ef9eebSMatthias Springer static ArrayAttr
transposeInBoundsAttr(OpBuilder & builder,ArrayAttr attr,const SmallVector<unsigned> & permutation)2699ef9eebSMatthias Springer transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
2799ef9eebSMatthias Springer const SmallVector<unsigned> &permutation) {
2899ef9eebSMatthias Springer SmallVector<bool> newInBoundsValues;
2999ef9eebSMatthias Springer for (unsigned pos : permutation)
3099ef9eebSMatthias Springer newInBoundsValues.push_back(
3199ef9eebSMatthias Springer attr.getValue()[pos].cast<BoolAttr>().getValue());
3299ef9eebSMatthias Springer return builder.getBoolArrayAttr(newInBoundsValues);
3399ef9eebSMatthias Springer }
3499ef9eebSMatthias Springer
3599ef9eebSMatthias Springer /// Lower transfer_read op with permutation into a transfer_read with a
3699ef9eebSMatthias Springer /// permutation map composed of leading zeros followed by a minor identiy +
3799ef9eebSMatthias Springer /// vector.transpose op.
3899ef9eebSMatthias Springer /// Ex:
3999ef9eebSMatthias Springer /// vector.transfer_read ...
4099ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2) -> (0, d1)
4199ef9eebSMatthias Springer /// into:
4299ef9eebSMatthias Springer /// %v = vector.transfer_read ...
4399ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2) -> (d1, 0)
4499ef9eebSMatthias Springer /// vector.transpose %v, [1, 0]
4599ef9eebSMatthias Springer ///
4699ef9eebSMatthias Springer /// vector.transfer_read ...
4799ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
4899ef9eebSMatthias Springer /// into:
4999ef9eebSMatthias Springer /// %v = vector.transfer_read ...
5099ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
5199ef9eebSMatthias Springer /// vector.transpose %v, [0, 1, 3, 2, 4]
5299ef9eebSMatthias Springer /// Note that an alternative is to transform it to linalg.transpose +
5399ef9eebSMatthias Springer /// vector.transfer_read to do the transpose in memory instead.
5499ef9eebSMatthias Springer struct TransferReadPermutationLowering
5599ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> {
5699ef9eebSMatthias Springer using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
5799ef9eebSMatthias Springer
matchAndRewriteTransferReadPermutationLowering5899ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp op,
5999ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
6099ef9eebSMatthias Springer // TODO: support 0-d corner case.
6199ef9eebSMatthias Springer if (op.getTransferRank() == 0)
6299ef9eebSMatthias Springer return failure();
6399ef9eebSMatthias Springer
6499ef9eebSMatthias Springer SmallVector<unsigned> permutation;
657c38fd60SJacques Pienaar AffineMap map = op.getPermutationMap();
6699ef9eebSMatthias Springer if (map.getNumResults() == 0)
6799ef9eebSMatthias Springer return failure();
6899ef9eebSMatthias Springer if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
6999ef9eebSMatthias Springer return failure();
7099ef9eebSMatthias Springer AffineMap permutationMap =
7199ef9eebSMatthias Springer map.getPermutationMap(permutation, op.getContext());
7299ef9eebSMatthias Springer if (permutationMap.isIdentity())
7399ef9eebSMatthias Springer return failure();
7499ef9eebSMatthias Springer
7599ef9eebSMatthias Springer permutationMap = map.getPermutationMap(permutation, op.getContext());
7699ef9eebSMatthias Springer // Caluclate the map of the new read by applying the inverse permutation.
7799ef9eebSMatthias Springer permutationMap = inversePermutation(permutationMap);
7899ef9eebSMatthias Springer AffineMap newMap = permutationMap.compose(map);
7999ef9eebSMatthias Springer // Apply the reverse transpose to deduce the type of the transfer_read.
8099ef9eebSMatthias Springer ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
8199ef9eebSMatthias Springer SmallVector<int64_t> newVectorShape(originalShape.size());
8299ef9eebSMatthias Springer for (const auto &pos : llvm::enumerate(permutation)) {
8399ef9eebSMatthias Springer newVectorShape[pos.value()] = originalShape[pos.index()];
8499ef9eebSMatthias Springer }
8599ef9eebSMatthias Springer
8699ef9eebSMatthias Springer // Transpose mask operand.
8799ef9eebSMatthias Springer Value newMask;
887c38fd60SJacques Pienaar if (op.getMask()) {
8999ef9eebSMatthias Springer // Remove unused dims from the permutation map. E.g.:
9099ef9eebSMatthias Springer // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
9199ef9eebSMatthias Springer // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
9299ef9eebSMatthias Springer auto comp = compressUnusedDims(map);
9399ef9eebSMatthias Springer // Get positions of remaining result dims.
9499ef9eebSMatthias Springer // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
9599ef9eebSMatthias Springer // maskTransposeIndices = [ 2, 1, 0]
9699ef9eebSMatthias Springer SmallVector<int64_t> maskTransposeIndices;
9799ef9eebSMatthias Springer for (unsigned i = 0; i < comp.getNumResults(); ++i) {
9899ef9eebSMatthias Springer if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
9999ef9eebSMatthias Springer maskTransposeIndices.push_back(expr.getPosition());
10099ef9eebSMatthias Springer }
10199ef9eebSMatthias Springer
1027c38fd60SJacques Pienaar newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
10399ef9eebSMatthias Springer maskTransposeIndices);
10499ef9eebSMatthias Springer }
10599ef9eebSMatthias Springer
10699ef9eebSMatthias Springer // Transpose in_bounds attribute.
10799ef9eebSMatthias Springer ArrayAttr newInBoundsAttr =
108*c27d8152SKazu Hirata op.getInBounds() ? transposeInBoundsAttr(
109*c27d8152SKazu Hirata rewriter, op.getInBounds().value(), permutation)
11099ef9eebSMatthias Springer : ArrayAttr();
11199ef9eebSMatthias Springer
11299ef9eebSMatthias Springer // Generate new transfer_read operation.
11399ef9eebSMatthias Springer VectorType newReadType =
11499ef9eebSMatthias Springer VectorType::get(newVectorShape, op.getVectorType().getElementType());
11599ef9eebSMatthias Springer Value newRead = rewriter.create<vector::TransferReadOp>(
1167c38fd60SJacques Pienaar op.getLoc(), newReadType, op.getSource(), op.getIndices(),
1177c38fd60SJacques Pienaar AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
11899ef9eebSMatthias Springer
11999ef9eebSMatthias Springer // Transpose result of transfer_read.
12099ef9eebSMatthias Springer SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
12199ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
12299ef9eebSMatthias Springer transposePerm);
12399ef9eebSMatthias Springer return success();
12499ef9eebSMatthias Springer }
12599ef9eebSMatthias Springer };
12699ef9eebSMatthias Springer
12799ef9eebSMatthias Springer /// Lower transfer_write op with permutation into a transfer_write with a
12899ef9eebSMatthias Springer /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
12999ef9eebSMatthias Springer /// Ex:
13099ef9eebSMatthias Springer /// vector.transfer_write %v ...
13199ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
13299ef9eebSMatthias Springer /// into:
13399ef9eebSMatthias Springer /// %tmp = vector.transpose %v, [2, 0, 1]
13499ef9eebSMatthias Springer /// vector.transfer_write %tmp ...
13599ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
13699ef9eebSMatthias Springer ///
13799ef9eebSMatthias Springer /// vector.transfer_write %v ...
13899ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
13999ef9eebSMatthias Springer /// into:
14099ef9eebSMatthias Springer /// %tmp = vector.transpose %v, [1, 0]
14199ef9eebSMatthias Springer /// %v = vector.transfer_write %tmp ...
14299ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
14399ef9eebSMatthias Springer struct TransferWritePermutationLowering
14499ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
14599ef9eebSMatthias Springer using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
14699ef9eebSMatthias Springer
matchAndRewriteTransferWritePermutationLowering14799ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp op,
14899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
14999ef9eebSMatthias Springer // TODO: support 0-d corner case.
15099ef9eebSMatthias Springer if (op.getTransferRank() == 0)
15199ef9eebSMatthias Springer return failure();
15299ef9eebSMatthias Springer
15399ef9eebSMatthias Springer SmallVector<unsigned> permutation;
1547c38fd60SJacques Pienaar AffineMap map = op.getPermutationMap();
15599ef9eebSMatthias Springer if (map.isMinorIdentity())
15699ef9eebSMatthias Springer return failure();
15799ef9eebSMatthias Springer if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
15899ef9eebSMatthias Springer return failure();
15999ef9eebSMatthias Springer
16099ef9eebSMatthias Springer // Remove unused dims from the permutation map. E.g.:
16199ef9eebSMatthias Springer // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
16299ef9eebSMatthias Springer // comp = (d0, d1, d2) -> (d2, d0, d1)
16399ef9eebSMatthias Springer auto comp = compressUnusedDims(map);
16499ef9eebSMatthias Springer // Get positions of remaining result dims.
16599ef9eebSMatthias Springer SmallVector<int64_t> indices;
16699ef9eebSMatthias Springer llvm::transform(comp.getResults(), std::back_inserter(indices),
16799ef9eebSMatthias Springer [](AffineExpr expr) {
16899ef9eebSMatthias Springer return expr.dyn_cast<AffineDimExpr>().getPosition();
16999ef9eebSMatthias Springer });
17099ef9eebSMatthias Springer
17199ef9eebSMatthias Springer // Transpose mask operand.
1727c38fd60SJacques Pienaar Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
1737c38fd60SJacques Pienaar op.getLoc(), op.getMask(), indices)
17499ef9eebSMatthias Springer : Value();
17599ef9eebSMatthias Springer
17699ef9eebSMatthias Springer // Transpose in_bounds attribute.
17799ef9eebSMatthias Springer ArrayAttr newInBoundsAttr =
178*c27d8152SKazu Hirata op.getInBounds() ? transposeInBoundsAttr(
179*c27d8152SKazu Hirata rewriter, op.getInBounds().value(), permutation)
18099ef9eebSMatthias Springer : ArrayAttr();
18199ef9eebSMatthias Springer
18299ef9eebSMatthias Springer // Generate new transfer_write operation.
1837c38fd60SJacques Pienaar Value newVec = rewriter.create<vector::TransposeOp>(
1847c38fd60SJacques Pienaar op.getLoc(), op.getVector(), indices);
18599ef9eebSMatthias Springer auto newMap = AffineMap::getMinorIdentityMap(
18699ef9eebSMatthias Springer map.getNumDims(), map.getNumResults(), rewriter.getContext());
18799ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1887c38fd60SJacques Pienaar op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189dc82547bSgysit newMask, newInBoundsAttr);
19099ef9eebSMatthias Springer
19199ef9eebSMatthias Springer return success();
19299ef9eebSMatthias Springer }
19399ef9eebSMatthias Springer };
19499ef9eebSMatthias Springer
19599ef9eebSMatthias Springer /// Lower transfer_read op with broadcast in the leading dimensions into
19699ef9eebSMatthias Springer /// transfer_read of lower rank + vector.broadcast.
19799ef9eebSMatthias Springer /// Ex: vector.transfer_read ...
19899ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
19999ef9eebSMatthias Springer /// into:
20099ef9eebSMatthias Springer /// %v = vector.transfer_read ...
20199ef9eebSMatthias Springer /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
20299ef9eebSMatthias Springer /// vector.broadcast %v
20399ef9eebSMatthias Springer struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
20499ef9eebSMatthias Springer using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
20599ef9eebSMatthias Springer
matchAndRewriteTransferOpReduceRank20699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp op,
20799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
20899ef9eebSMatthias Springer // TODO: support 0-d corner case.
20999ef9eebSMatthias Springer if (op.getTransferRank() == 0)
21099ef9eebSMatthias Springer return failure();
21199ef9eebSMatthias Springer
2127c38fd60SJacques Pienaar AffineMap map = op.getPermutationMap();
21399ef9eebSMatthias Springer unsigned numLeadingBroadcast = 0;
21499ef9eebSMatthias Springer for (auto expr : map.getResults()) {
21599ef9eebSMatthias Springer auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
21699ef9eebSMatthias Springer if (!dimExpr || dimExpr.getValue() != 0)
21799ef9eebSMatthias Springer break;
21899ef9eebSMatthias Springer numLeadingBroadcast++;
21999ef9eebSMatthias Springer }
22099ef9eebSMatthias Springer // If there are no leading zeros in the map there is nothing to do.
22199ef9eebSMatthias Springer if (numLeadingBroadcast == 0)
22299ef9eebSMatthias Springer return failure();
22399ef9eebSMatthias Springer VectorType originalVecType = op.getVectorType();
22499ef9eebSMatthias Springer unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
22599ef9eebSMatthias Springer // Calculate new map, vector type and masks without the leading zeros.
22699ef9eebSMatthias Springer AffineMap newMap = AffineMap::get(
22799ef9eebSMatthias Springer map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
22899ef9eebSMatthias Springer op.getContext());
22999ef9eebSMatthias Springer // Only remove the leading zeros if the rest of the map is a minor identity
23099ef9eebSMatthias Springer // with broadasting. Otherwise we first want to permute the map.
23199ef9eebSMatthias Springer if (!newMap.isMinorIdentityWithBroadcasting())
23299ef9eebSMatthias Springer return failure();
23399ef9eebSMatthias Springer
23499ef9eebSMatthias Springer // TODO: support zero-dimension vectors natively. See:
23599ef9eebSMatthias Springer // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
23699ef9eebSMatthias Springer // In the meantime, lower these to a scalar load when they pop up.
23799ef9eebSMatthias Springer if (reducedShapeRank == 0) {
23899ef9eebSMatthias Springer Value newRead;
23999ef9eebSMatthias Springer if (op.getShapedType().isa<TensorType>()) {
2407c38fd60SJacques Pienaar newRead = rewriter.create<tensor::ExtractOp>(
2417c38fd60SJacques Pienaar op.getLoc(), op.getSource(), op.getIndices());
24299ef9eebSMatthias Springer } else {
24399ef9eebSMatthias Springer newRead = rewriter.create<memref::LoadOp>(
2447c38fd60SJacques Pienaar op.getLoc(), originalVecType.getElementType(), op.getSource(),
2457c38fd60SJacques Pienaar op.getIndices());
24699ef9eebSMatthias Springer }
24799ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
24899ef9eebSMatthias Springer newRead);
24999ef9eebSMatthias Springer return success();
25099ef9eebSMatthias Springer }
25199ef9eebSMatthias Springer SmallVector<int64_t> newShape = llvm::to_vector<4>(
25299ef9eebSMatthias Springer originalVecType.getShape().take_back(reducedShapeRank));
25399ef9eebSMatthias Springer // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
25499ef9eebSMatthias Springer if (newShape.empty())
25599ef9eebSMatthias Springer return failure();
25699ef9eebSMatthias Springer VectorType newReadType =
25799ef9eebSMatthias Springer VectorType::get(newShape, originalVecType.getElementType());
25899ef9eebSMatthias Springer ArrayAttr newInBoundsAttr =
2597c38fd60SJacques Pienaar op.getInBounds()
26099ef9eebSMatthias Springer ? rewriter.getArrayAttr(
2617c38fd60SJacques Pienaar op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
26299ef9eebSMatthias Springer : ArrayAttr();
26399ef9eebSMatthias Springer Value newRead = rewriter.create<vector::TransferReadOp>(
2647c38fd60SJacques Pienaar op.getLoc(), newReadType, op.getSource(), op.getIndices(),
2657c38fd60SJacques Pienaar AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
2667c38fd60SJacques Pienaar newInBoundsAttr);
26799ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
26899ef9eebSMatthias Springer newRead);
26999ef9eebSMatthias Springer return success();
27099ef9eebSMatthias Springer }
27199ef9eebSMatthias Springer };
27299ef9eebSMatthias Springer
populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet & patterns)27399ef9eebSMatthias Springer void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
27499ef9eebSMatthias Springer RewritePatternSet &patterns) {
27599ef9eebSMatthias Springer patterns.add<TransferReadPermutationLowering,
27699ef9eebSMatthias Springer TransferWritePermutationLowering, TransferOpReduceRank>(
27799ef9eebSMatthias Springer patterns.getContext());
27899ef9eebSMatthias Springer }
279