199ef9eebSMatthias Springer //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements rewrite patterns for the permutation_map attribute of
1099ef9eebSMatthias Springer // vector.transfer operations.
1199ef9eebSMatthias Springer //
1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1399ef9eebSMatthias Springer 
1499ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1899ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
1999ef9eebSMatthias Springer 
2099ef9eebSMatthias Springer using namespace mlir;
2199ef9eebSMatthias Springer using namespace mlir::vector;
2299ef9eebSMatthias Springer 
2399ef9eebSMatthias Springer /// Transpose a vector transfer op's `in_bounds` attribute according to given
2499ef9eebSMatthias Springer /// indices.
2599ef9eebSMatthias Springer static ArrayAttr
transposeInBoundsAttr(OpBuilder & builder,ArrayAttr attr,const SmallVector<unsigned> & permutation)2699ef9eebSMatthias Springer transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
2799ef9eebSMatthias Springer                       const SmallVector<unsigned> &permutation) {
2899ef9eebSMatthias Springer   SmallVector<bool> newInBoundsValues;
2999ef9eebSMatthias Springer   for (unsigned pos : permutation)
3099ef9eebSMatthias Springer     newInBoundsValues.push_back(
3199ef9eebSMatthias Springer         attr.getValue()[pos].cast<BoolAttr>().getValue());
3299ef9eebSMatthias Springer   return builder.getBoolArrayAttr(newInBoundsValues);
3399ef9eebSMatthias Springer }
3499ef9eebSMatthias Springer 
3599ef9eebSMatthias Springer /// Lower transfer_read op with permutation into a transfer_read with a
3699ef9eebSMatthias Springer /// permutation map composed of leading zeros followed by a minor identiy +
3799ef9eebSMatthias Springer /// vector.transpose op.
3899ef9eebSMatthias Springer /// Ex:
3999ef9eebSMatthias Springer ///     vector.transfer_read ...
4099ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2) -> (0, d1)
4199ef9eebSMatthias Springer /// into:
4299ef9eebSMatthias Springer ///     %v = vector.transfer_read ...
4399ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2) -> (d1, 0)
4499ef9eebSMatthias Springer ///     vector.transpose %v, [1, 0]
4599ef9eebSMatthias Springer ///
4699ef9eebSMatthias Springer ///     vector.transfer_read ...
4799ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
4899ef9eebSMatthias Springer /// into:
4999ef9eebSMatthias Springer ///     %v = vector.transfer_read ...
5099ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
5199ef9eebSMatthias Springer ///     vector.transpose %v, [0, 1, 3, 2, 4]
5299ef9eebSMatthias Springer /// Note that an alternative is to transform it to linalg.transpose +
5399ef9eebSMatthias Springer /// vector.transfer_read to do the transpose in memory instead.
5499ef9eebSMatthias Springer struct TransferReadPermutationLowering
5599ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
5699ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
5799ef9eebSMatthias Springer 
matchAndRewriteTransferReadPermutationLowering5899ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp op,
5999ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
6099ef9eebSMatthias Springer     // TODO: support 0-d corner case.
6199ef9eebSMatthias Springer     if (op.getTransferRank() == 0)
6299ef9eebSMatthias Springer       return failure();
6399ef9eebSMatthias Springer 
6499ef9eebSMatthias Springer     SmallVector<unsigned> permutation;
657c38fd60SJacques Pienaar     AffineMap map = op.getPermutationMap();
6699ef9eebSMatthias Springer     if (map.getNumResults() == 0)
6799ef9eebSMatthias Springer       return failure();
6899ef9eebSMatthias Springer     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
6999ef9eebSMatthias Springer       return failure();
7099ef9eebSMatthias Springer     AffineMap permutationMap =
7199ef9eebSMatthias Springer         map.getPermutationMap(permutation, op.getContext());
7299ef9eebSMatthias Springer     if (permutationMap.isIdentity())
7399ef9eebSMatthias Springer       return failure();
7499ef9eebSMatthias Springer 
7599ef9eebSMatthias Springer     permutationMap = map.getPermutationMap(permutation, op.getContext());
7699ef9eebSMatthias Springer     // Caluclate the map of the new read by applying the inverse permutation.
7799ef9eebSMatthias Springer     permutationMap = inversePermutation(permutationMap);
7899ef9eebSMatthias Springer     AffineMap newMap = permutationMap.compose(map);
7999ef9eebSMatthias Springer     // Apply the reverse transpose to deduce the type of the transfer_read.
8099ef9eebSMatthias Springer     ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
8199ef9eebSMatthias Springer     SmallVector<int64_t> newVectorShape(originalShape.size());
8299ef9eebSMatthias Springer     for (const auto &pos : llvm::enumerate(permutation)) {
8399ef9eebSMatthias Springer       newVectorShape[pos.value()] = originalShape[pos.index()];
8499ef9eebSMatthias Springer     }
8599ef9eebSMatthias Springer 
8699ef9eebSMatthias Springer     // Transpose mask operand.
8799ef9eebSMatthias Springer     Value newMask;
887c38fd60SJacques Pienaar     if (op.getMask()) {
8999ef9eebSMatthias Springer       // Remove unused dims from the permutation map. E.g.:
9099ef9eebSMatthias Springer       // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
9199ef9eebSMatthias Springer       // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
9299ef9eebSMatthias Springer       auto comp = compressUnusedDims(map);
9399ef9eebSMatthias Springer       // Get positions of remaining result dims.
9499ef9eebSMatthias Springer       // E.g.:  (d0, d1, d2) -> (d2, 0, d1, 0 d0)
9599ef9eebSMatthias Springer       // maskTransposeIndices = [ 2,     1,    0]
9699ef9eebSMatthias Springer       SmallVector<int64_t> maskTransposeIndices;
9799ef9eebSMatthias Springer       for (unsigned i = 0; i < comp.getNumResults(); ++i) {
9899ef9eebSMatthias Springer         if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
9999ef9eebSMatthias Springer           maskTransposeIndices.push_back(expr.getPosition());
10099ef9eebSMatthias Springer       }
10199ef9eebSMatthias Springer 
1027c38fd60SJacques Pienaar       newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
10399ef9eebSMatthias Springer                                                      maskTransposeIndices);
10499ef9eebSMatthias Springer     }
10599ef9eebSMatthias Springer 
10699ef9eebSMatthias Springer     // Transpose in_bounds attribute.
10799ef9eebSMatthias Springer     ArrayAttr newInBoundsAttr =
108*c27d8152SKazu Hirata         op.getInBounds() ? transposeInBoundsAttr(
109*c27d8152SKazu Hirata                                rewriter, op.getInBounds().value(), permutation)
11099ef9eebSMatthias Springer                          : ArrayAttr();
11199ef9eebSMatthias Springer 
11299ef9eebSMatthias Springer     // Generate new transfer_read operation.
11399ef9eebSMatthias Springer     VectorType newReadType =
11499ef9eebSMatthias Springer         VectorType::get(newVectorShape, op.getVectorType().getElementType());
11599ef9eebSMatthias Springer     Value newRead = rewriter.create<vector::TransferReadOp>(
1167c38fd60SJacques Pienaar         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
1177c38fd60SJacques Pienaar         AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
11899ef9eebSMatthias Springer 
11999ef9eebSMatthias Springer     // Transpose result of transfer_read.
12099ef9eebSMatthias Springer     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
12199ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
12299ef9eebSMatthias Springer                                                      transposePerm);
12399ef9eebSMatthias Springer     return success();
12499ef9eebSMatthias Springer   }
12599ef9eebSMatthias Springer };
12699ef9eebSMatthias Springer 
12799ef9eebSMatthias Springer /// Lower transfer_write op with permutation into a transfer_write with a
12899ef9eebSMatthias Springer /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
12999ef9eebSMatthias Springer /// Ex:
13099ef9eebSMatthias Springer ///     vector.transfer_write %v ...
13199ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
13299ef9eebSMatthias Springer /// into:
13399ef9eebSMatthias Springer ///     %tmp = vector.transpose %v, [2, 0, 1]
13499ef9eebSMatthias Springer ///     vector.transfer_write %tmp ...
13599ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
13699ef9eebSMatthias Springer ///
13799ef9eebSMatthias Springer ///     vector.transfer_write %v ...
13899ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
13999ef9eebSMatthias Springer /// into:
14099ef9eebSMatthias Springer ///     %tmp = vector.transpose %v, [1, 0]
14199ef9eebSMatthias Springer ///     %v = vector.transfer_write %tmp ...
14299ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
14399ef9eebSMatthias Springer struct TransferWritePermutationLowering
14499ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
14599ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
14699ef9eebSMatthias Springer 
matchAndRewriteTransferWritePermutationLowering14799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
14899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
14999ef9eebSMatthias Springer     // TODO: support 0-d corner case.
15099ef9eebSMatthias Springer     if (op.getTransferRank() == 0)
15199ef9eebSMatthias Springer       return failure();
15299ef9eebSMatthias Springer 
15399ef9eebSMatthias Springer     SmallVector<unsigned> permutation;
1547c38fd60SJacques Pienaar     AffineMap map = op.getPermutationMap();
15599ef9eebSMatthias Springer     if (map.isMinorIdentity())
15699ef9eebSMatthias Springer       return failure();
15799ef9eebSMatthias Springer     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
15899ef9eebSMatthias Springer       return failure();
15999ef9eebSMatthias Springer 
16099ef9eebSMatthias Springer     // Remove unused dims from the permutation map. E.g.:
16199ef9eebSMatthias Springer     // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
16299ef9eebSMatthias Springer     // comp = (d0, d1, d2) -> (d2, d0, d1)
16399ef9eebSMatthias Springer     auto comp = compressUnusedDims(map);
16499ef9eebSMatthias Springer     // Get positions of remaining result dims.
16599ef9eebSMatthias Springer     SmallVector<int64_t> indices;
16699ef9eebSMatthias Springer     llvm::transform(comp.getResults(), std::back_inserter(indices),
16799ef9eebSMatthias Springer                     [](AffineExpr expr) {
16899ef9eebSMatthias Springer                       return expr.dyn_cast<AffineDimExpr>().getPosition();
16999ef9eebSMatthias Springer                     });
17099ef9eebSMatthias Springer 
17199ef9eebSMatthias Springer     // Transpose mask operand.
1727c38fd60SJacques Pienaar     Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
1737c38fd60SJacques Pienaar                                        op.getLoc(), op.getMask(), indices)
17499ef9eebSMatthias Springer                                  : Value();
17599ef9eebSMatthias Springer 
17699ef9eebSMatthias Springer     // Transpose in_bounds attribute.
17799ef9eebSMatthias Springer     ArrayAttr newInBoundsAttr =
178*c27d8152SKazu Hirata         op.getInBounds() ? transposeInBoundsAttr(
179*c27d8152SKazu Hirata                                rewriter, op.getInBounds().value(), permutation)
18099ef9eebSMatthias Springer                          : ArrayAttr();
18199ef9eebSMatthias Springer 
18299ef9eebSMatthias Springer     // Generate new transfer_write operation.
1837c38fd60SJacques Pienaar     Value newVec = rewriter.create<vector::TransposeOp>(
1847c38fd60SJacques Pienaar         op.getLoc(), op.getVector(), indices);
18599ef9eebSMatthias Springer     auto newMap = AffineMap::getMinorIdentityMap(
18699ef9eebSMatthias Springer         map.getNumDims(), map.getNumResults(), rewriter.getContext());
18799ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1887c38fd60SJacques Pienaar         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189dc82547bSgysit         newMask, newInBoundsAttr);
19099ef9eebSMatthias Springer 
19199ef9eebSMatthias Springer     return success();
19299ef9eebSMatthias Springer   }
19399ef9eebSMatthias Springer };
19499ef9eebSMatthias Springer 
19599ef9eebSMatthias Springer /// Lower transfer_read op with broadcast in the leading dimensions into
19699ef9eebSMatthias Springer /// transfer_read of lower rank + vector.broadcast.
19799ef9eebSMatthias Springer /// Ex: vector.transfer_read ...
19899ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
19999ef9eebSMatthias Springer /// into:
20099ef9eebSMatthias Springer ///     %v = vector.transfer_read ...
20199ef9eebSMatthias Springer ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
20299ef9eebSMatthias Springer ///     vector.broadcast %v
20399ef9eebSMatthias Springer struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
20499ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
20599ef9eebSMatthias Springer 
matchAndRewriteTransferOpReduceRank20699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp op,
20799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
20899ef9eebSMatthias Springer     // TODO: support 0-d corner case.
20999ef9eebSMatthias Springer     if (op.getTransferRank() == 0)
21099ef9eebSMatthias Springer       return failure();
21199ef9eebSMatthias Springer 
2127c38fd60SJacques Pienaar     AffineMap map = op.getPermutationMap();
21399ef9eebSMatthias Springer     unsigned numLeadingBroadcast = 0;
21499ef9eebSMatthias Springer     for (auto expr : map.getResults()) {
21599ef9eebSMatthias Springer       auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
21699ef9eebSMatthias Springer       if (!dimExpr || dimExpr.getValue() != 0)
21799ef9eebSMatthias Springer         break;
21899ef9eebSMatthias Springer       numLeadingBroadcast++;
21999ef9eebSMatthias Springer     }
22099ef9eebSMatthias Springer     // If there are no leading zeros in the map there is nothing to do.
22199ef9eebSMatthias Springer     if (numLeadingBroadcast == 0)
22299ef9eebSMatthias Springer       return failure();
22399ef9eebSMatthias Springer     VectorType originalVecType = op.getVectorType();
22499ef9eebSMatthias Springer     unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
22599ef9eebSMatthias Springer     // Calculate new map, vector type and masks without the leading zeros.
22699ef9eebSMatthias Springer     AffineMap newMap = AffineMap::get(
22799ef9eebSMatthias Springer         map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
22899ef9eebSMatthias Springer         op.getContext());
22999ef9eebSMatthias Springer     // Only remove the leading zeros if the rest of the map is a minor identity
23099ef9eebSMatthias Springer     // with broadasting. Otherwise we first want to permute the map.
23199ef9eebSMatthias Springer     if (!newMap.isMinorIdentityWithBroadcasting())
23299ef9eebSMatthias Springer       return failure();
23399ef9eebSMatthias Springer 
23499ef9eebSMatthias Springer     // TODO: support zero-dimension vectors natively.  See:
23599ef9eebSMatthias Springer     // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
23699ef9eebSMatthias Springer     // In the meantime, lower these to a scalar load when they pop up.
23799ef9eebSMatthias Springer     if (reducedShapeRank == 0) {
23899ef9eebSMatthias Springer       Value newRead;
23999ef9eebSMatthias Springer       if (op.getShapedType().isa<TensorType>()) {
2407c38fd60SJacques Pienaar         newRead = rewriter.create<tensor::ExtractOp>(
2417c38fd60SJacques Pienaar             op.getLoc(), op.getSource(), op.getIndices());
24299ef9eebSMatthias Springer       } else {
24399ef9eebSMatthias Springer         newRead = rewriter.create<memref::LoadOp>(
2447c38fd60SJacques Pienaar             op.getLoc(), originalVecType.getElementType(), op.getSource(),
2457c38fd60SJacques Pienaar             op.getIndices());
24699ef9eebSMatthias Springer       }
24799ef9eebSMatthias Springer       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
24899ef9eebSMatthias Springer                                                        newRead);
24999ef9eebSMatthias Springer       return success();
25099ef9eebSMatthias Springer     }
25199ef9eebSMatthias Springer     SmallVector<int64_t> newShape = llvm::to_vector<4>(
25299ef9eebSMatthias Springer         originalVecType.getShape().take_back(reducedShapeRank));
25399ef9eebSMatthias Springer     // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
25499ef9eebSMatthias Springer     if (newShape.empty())
25599ef9eebSMatthias Springer       return failure();
25699ef9eebSMatthias Springer     VectorType newReadType =
25799ef9eebSMatthias Springer         VectorType::get(newShape, originalVecType.getElementType());
25899ef9eebSMatthias Springer     ArrayAttr newInBoundsAttr =
2597c38fd60SJacques Pienaar         op.getInBounds()
26099ef9eebSMatthias Springer             ? rewriter.getArrayAttr(
2617c38fd60SJacques Pienaar                   op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
26299ef9eebSMatthias Springer             : ArrayAttr();
26399ef9eebSMatthias Springer     Value newRead = rewriter.create<vector::TransferReadOp>(
2647c38fd60SJacques Pienaar         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
2657c38fd60SJacques Pienaar         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
2667c38fd60SJacques Pienaar         newInBoundsAttr);
26799ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
26899ef9eebSMatthias Springer                                                      newRead);
26999ef9eebSMatthias Springer     return success();
27099ef9eebSMatthias Springer   }
27199ef9eebSMatthias Springer };
27299ef9eebSMatthias Springer 
populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet & patterns)27399ef9eebSMatthias Springer void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
27499ef9eebSMatthias Springer     RewritePatternSet &patterns) {
27599ef9eebSMatthias Springer   patterns.add<TransferReadPermutationLowering,
27699ef9eebSMatthias Springer                TransferWritePermutationLowering, TransferOpReduceRank>(
27799ef9eebSMatthias Springer       patterns.getContext());
27899ef9eebSMatthias Springer }
279