199ef9eebSMatthias Springer //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer 
999ef9eebSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1099ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1199ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1599ef9eebSMatthias Springer #include "mlir/IR/BuiltinTypes.h"
1699ef9eebSMatthias Springer 
1799ef9eebSMatthias Springer using namespace mlir;
1899ef9eebSMatthias Springer using namespace mlir::vector;
1999ef9eebSMatthias Springer 
2099ef9eebSMatthias Springer // Helper that picks the proper sequence for inserting.
insertOne(PatternRewriter & rewriter,Location loc,Value from,Value into,int64_t offset)2199ef9eebSMatthias Springer static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
2299ef9eebSMatthias Springer                        Value into, int64_t offset) {
2399ef9eebSMatthias Springer   auto vectorType = into.getType().cast<VectorType>();
2499ef9eebSMatthias Springer   if (vectorType.getRank() > 1)
2599ef9eebSMatthias Springer     return rewriter.create<InsertOp>(loc, from, into, offset);
2699ef9eebSMatthias Springer   return rewriter.create<vector::InsertElementOp>(
2799ef9eebSMatthias Springer       loc, vectorType, from, into,
2899ef9eebSMatthias Springer       rewriter.create<arith::ConstantIndexOp>(loc, offset));
2999ef9eebSMatthias Springer }
3099ef9eebSMatthias Springer 
3199ef9eebSMatthias Springer // Helper that picks the proper sequence for extracting.
extractOne(PatternRewriter & rewriter,Location loc,Value vector,int64_t offset)3299ef9eebSMatthias Springer static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
3399ef9eebSMatthias Springer                         int64_t offset) {
3499ef9eebSMatthias Springer   auto vectorType = vector.getType().cast<VectorType>();
3599ef9eebSMatthias Springer   if (vectorType.getRank() > 1)
3699ef9eebSMatthias Springer     return rewriter.create<ExtractOp>(loc, vector, offset);
3799ef9eebSMatthias Springer   return rewriter.create<vector::ExtractElementOp>(
3899ef9eebSMatthias Springer       loc, vectorType.getElementType(), vector,
3999ef9eebSMatthias Springer       rewriter.create<arith::ConstantIndexOp>(loc, offset));
4099ef9eebSMatthias Springer }
4199ef9eebSMatthias Springer 
4299ef9eebSMatthias Springer /// RewritePattern for InsertStridedSliceOp where source and destination vectors
4399ef9eebSMatthias Springer /// have different ranks.
4499ef9eebSMatthias Springer ///
4599ef9eebSMatthias Springer /// When ranks are different, InsertStridedSlice needs to extract a properly
4699ef9eebSMatthias Springer /// ranked vector from the destination vector into which to insert. This pattern
4799ef9eebSMatthias Springer /// only takes care of this extraction part and forwards the rest to
4859d3a9e0SLei Zhang /// [ConvertSameRankInsertStridedSliceIntoShuffle].
4999ef9eebSMatthias Springer ///
5099ef9eebSMatthias Springer /// For a k-D source and n-D destination vector (k < n), we emit:
5199ef9eebSMatthias Springer ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
5299ef9eebSMatthias Springer ///      insert the k-D source.
5399ef9eebSMatthias Springer ///   2. k-D -> (n-1)-D InsertStridedSlice op
5499ef9eebSMatthias Springer ///   3. InsertOp that is the reverse of 1.
5559d3a9e0SLei Zhang class DecomposeDifferentRankInsertStridedSlice
5699ef9eebSMatthias Springer     : public OpRewritePattern<InsertStridedSliceOp> {
5799ef9eebSMatthias Springer public:
5899ef9eebSMatthias Springer   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
5999ef9eebSMatthias Springer 
matchAndRewrite(InsertStridedSliceOp op,PatternRewriter & rewriter) const6099ef9eebSMatthias Springer   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
6199ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
6299ef9eebSMatthias Springer     auto srcType = op.getSourceVectorType();
6399ef9eebSMatthias Springer     auto dstType = op.getDestVectorType();
6499ef9eebSMatthias Springer 
657c38fd60SJacques Pienaar     if (op.getOffsets().getValue().empty())
6699ef9eebSMatthias Springer       return failure();
6799ef9eebSMatthias Springer 
6899ef9eebSMatthias Springer     auto loc = op.getLoc();
6999ef9eebSMatthias Springer     int64_t rankDiff = dstType.getRank() - srcType.getRank();
7099ef9eebSMatthias Springer     assert(rankDiff >= 0);
7199ef9eebSMatthias Springer     if (rankDiff == 0)
7299ef9eebSMatthias Springer       return failure();
7399ef9eebSMatthias Springer 
7499ef9eebSMatthias Springer     int64_t rankRest = dstType.getRank() - rankDiff;
7599ef9eebSMatthias Springer     // Extract / insert the subvector of matching rank and InsertStridedSlice
7699ef9eebSMatthias Springer     // on it.
777c38fd60SJacques Pienaar     Value extracted = rewriter.create<ExtractOp>(
787c38fd60SJacques Pienaar         loc, op.getDest(),
797c38fd60SJacques Pienaar         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
8099ef9eebSMatthias Springer                        /*dropBack=*/rankRest));
8199ef9eebSMatthias Springer 
8299ef9eebSMatthias Springer     // A different pattern will kick in for InsertStridedSlice with matching
8399ef9eebSMatthias Springer     // ranks.
8499ef9eebSMatthias Springer     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
857c38fd60SJacques Pienaar         loc, op.getSource(), extracted,
867c38fd60SJacques Pienaar         getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
877c38fd60SJacques Pienaar         getI64SubArray(op.getStrides(), /*dropFront=*/0));
8899ef9eebSMatthias Springer 
8999ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<InsertOp>(
907c38fd60SJacques Pienaar         op, stridedSliceInnerOp.getResult(), op.getDest(),
917c38fd60SJacques Pienaar         getI64SubArray(op.getOffsets(), /*dropFront=*/0,
9299ef9eebSMatthias Springer                        /*dropBack=*/rankRest));
9399ef9eebSMatthias Springer     return success();
9499ef9eebSMatthias Springer   }
9599ef9eebSMatthias Springer };
9699ef9eebSMatthias Springer 
9799ef9eebSMatthias Springer /// RewritePattern for InsertStridedSliceOp where source and destination vectors
9899ef9eebSMatthias Springer /// have the same rank. For each outermost index in the slice:
9999ef9eebSMatthias Springer ///   begin    end             stride
10099ef9eebSMatthias Springer /// [offset : offset+size*stride : stride]
10199ef9eebSMatthias Springer ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
10299ef9eebSMatthias Springer ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
10399ef9eebSMatthias Springer ///   3. the destination subvector is inserted back in the proper place
10499ef9eebSMatthias Springer ///   3. InsertOp that is the reverse of 1.
10559d3a9e0SLei Zhang class ConvertSameRankInsertStridedSliceIntoShuffle
10699ef9eebSMatthias Springer     : public OpRewritePattern<InsertStridedSliceOp> {
10799ef9eebSMatthias Springer public:
10899ef9eebSMatthias Springer   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
10999ef9eebSMatthias Springer 
initialize()11099ef9eebSMatthias Springer   void initialize() {
11199ef9eebSMatthias Springer     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
11299ef9eebSMatthias Springer     // bounded as the rank is strictly decreasing.
11399ef9eebSMatthias Springer     setHasBoundedRewriteRecursion();
11499ef9eebSMatthias Springer   }
11599ef9eebSMatthias Springer 
matchAndRewrite(InsertStridedSliceOp op,PatternRewriter & rewriter) const11699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
11799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
11899ef9eebSMatthias Springer     auto srcType = op.getSourceVectorType();
11999ef9eebSMatthias Springer     auto dstType = op.getDestVectorType();
12099ef9eebSMatthias Springer 
1217c38fd60SJacques Pienaar     if (op.getOffsets().getValue().empty())
12299ef9eebSMatthias Springer       return failure();
12399ef9eebSMatthias Springer 
12499ef9eebSMatthias Springer     int64_t srcRank = srcType.getRank();
12599ef9eebSMatthias Springer     int64_t dstRank = dstType.getRank();
12699ef9eebSMatthias Springer     assert(dstRank >= srcRank);
12799ef9eebSMatthias Springer     if (dstRank != srcRank)
12899ef9eebSMatthias Springer       return failure();
12999ef9eebSMatthias Springer 
13099ef9eebSMatthias Springer     if (srcType == dstType) {
1317c38fd60SJacques Pienaar       rewriter.replaceOp(op, op.getSource());
13299ef9eebSMatthias Springer       return success();
13399ef9eebSMatthias Springer     }
13499ef9eebSMatthias Springer 
13599ef9eebSMatthias Springer     int64_t offset =
1367c38fd60SJacques Pienaar         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
13799ef9eebSMatthias Springer     int64_t size = srcType.getShape().front();
13899ef9eebSMatthias Springer     int64_t stride =
1397c38fd60SJacques Pienaar         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
14099ef9eebSMatthias Springer 
14199ef9eebSMatthias Springer     auto loc = op.getLoc();
1427c38fd60SJacques Pienaar     Value res = op.getDest();
14399ef9eebSMatthias Springer 
14499ef9eebSMatthias Springer     if (srcRank == 1) {
14599ef9eebSMatthias Springer       int nSrc = srcType.getShape().front();
14699ef9eebSMatthias Springer       int nDest = dstType.getShape().front();
14799ef9eebSMatthias Springer       // 1. Scale source to destType so we can shufflevector them together.
14899ef9eebSMatthias Springer       SmallVector<int64_t> offsets(nDest, 0);
14999ef9eebSMatthias Springer       for (int64_t i = 0; i < nSrc; ++i)
15099ef9eebSMatthias Springer         offsets[i] = i;
1517c38fd60SJacques Pienaar       Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
1527c38fd60SJacques Pienaar                                                       op.getSource(), offsets);
15399ef9eebSMatthias Springer 
15499ef9eebSMatthias Springer       // 2. Create a mask where we take the value from scaledSource of dest
15599ef9eebSMatthias Springer       // depending on the offset.
15699ef9eebSMatthias Springer       offsets.clear();
15799ef9eebSMatthias Springer       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
15899ef9eebSMatthias Springer         if (i < offset || i >= e || (i - offset) % stride != 0)
15999ef9eebSMatthias Springer           offsets.push_back(nDest + i);
16099ef9eebSMatthias Springer         else
16199ef9eebSMatthias Springer           offsets.push_back((i - offset) / stride);
16299ef9eebSMatthias Springer       }
16399ef9eebSMatthias Springer 
16499ef9eebSMatthias Springer       // 3. Replace with a ShuffleOp.
1657c38fd60SJacques Pienaar       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
16699ef9eebSMatthias Springer                                              offsets);
16799ef9eebSMatthias Springer 
16899ef9eebSMatthias Springer       return success();
16999ef9eebSMatthias Springer     }
17099ef9eebSMatthias Springer 
17199ef9eebSMatthias Springer     // For each slice of the source vector along the most major dimension.
17299ef9eebSMatthias Springer     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
17399ef9eebSMatthias Springer          off += stride, ++idx) {
17499ef9eebSMatthias Springer       // 1. extract the proper subvector (or element) from source
1757c38fd60SJacques Pienaar       Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
17699ef9eebSMatthias Springer       if (extractedSource.getType().isa<VectorType>()) {
17799ef9eebSMatthias Springer         // 2. If we have a vector, extract the proper subvector from destination
17899ef9eebSMatthias Springer         // Otherwise we are at the element level and no need to recurse.
1797c38fd60SJacques Pienaar         Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
18099ef9eebSMatthias Springer         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
18199ef9eebSMatthias Springer         // smaller rank.
18299ef9eebSMatthias Springer         extractedSource = rewriter.create<InsertStridedSliceOp>(
18399ef9eebSMatthias Springer             loc, extractedSource, extractedDest,
1847c38fd60SJacques Pienaar             getI64SubArray(op.getOffsets(), /* dropFront=*/1),
1857c38fd60SJacques Pienaar             getI64SubArray(op.getStrides(), /* dropFront=*/1));
18699ef9eebSMatthias Springer       }
18799ef9eebSMatthias Springer       // 4. Insert the extractedSource into the res vector.
18899ef9eebSMatthias Springer       res = insertOne(rewriter, loc, extractedSource, res, off);
18999ef9eebSMatthias Springer     }
19099ef9eebSMatthias Springer 
19199ef9eebSMatthias Springer     rewriter.replaceOp(op, res);
19299ef9eebSMatthias Springer     return success();
19399ef9eebSMatthias Springer   }
19499ef9eebSMatthias Springer };
19599ef9eebSMatthias Springer 
19659d3a9e0SLei Zhang /// RewritePattern for ExtractStridedSliceOp where source and destination
19759d3a9e0SLei Zhang /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
19859d3a9e0SLei Zhang class Convert1DExtractStridedSliceIntoShuffle
19959d3a9e0SLei Zhang     : public OpRewritePattern<ExtractStridedSliceOp> {
20059d3a9e0SLei Zhang public:
20159d3a9e0SLei Zhang   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
20259d3a9e0SLei Zhang 
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const20359d3a9e0SLei Zhang   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
20459d3a9e0SLei Zhang                                 PatternRewriter &rewriter) const override {
20559d3a9e0SLei Zhang     auto dstType = op.getType();
20659d3a9e0SLei Zhang 
20759d3a9e0SLei Zhang     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
20859d3a9e0SLei Zhang 
20959d3a9e0SLei Zhang     int64_t offset =
21059d3a9e0SLei Zhang         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
21159d3a9e0SLei Zhang     int64_t size =
21259d3a9e0SLei Zhang         op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
21359d3a9e0SLei Zhang     int64_t stride =
21459d3a9e0SLei Zhang         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
21559d3a9e0SLei Zhang 
216e7f05526SBenjamin Kramer     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
21759d3a9e0SLei Zhang 
21859d3a9e0SLei Zhang     // Single offset can be more efficiently shuffled.
21959d3a9e0SLei Zhang     if (op.getOffsets().getValue().size() != 1)
22059d3a9e0SLei Zhang       return failure();
22159d3a9e0SLei Zhang 
22259d3a9e0SLei Zhang     SmallVector<int64_t, 4> offsets;
22359d3a9e0SLei Zhang     offsets.reserve(size);
22459d3a9e0SLei Zhang     for (int64_t off = offset, e = offset + size * stride; off < e;
22559d3a9e0SLei Zhang          off += stride)
22659d3a9e0SLei Zhang       offsets.push_back(off);
22759d3a9e0SLei Zhang     rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
22859d3a9e0SLei Zhang                                            op.getVector(),
22959d3a9e0SLei Zhang                                            rewriter.getI64ArrayAttr(offsets));
23059d3a9e0SLei Zhang     return success();
23159d3a9e0SLei Zhang   }
23259d3a9e0SLei Zhang };
23359d3a9e0SLei Zhang 
23459d3a9e0SLei Zhang /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
23559d3a9e0SLei Zhang /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
23659d3a9e0SLei Zhang /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
23759d3a9e0SLei Zhang class DecomposeNDExtractStridedSlice
23899ef9eebSMatthias Springer     : public OpRewritePattern<ExtractStridedSliceOp> {
23999ef9eebSMatthias Springer public:
24099ef9eebSMatthias Springer   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
24199ef9eebSMatthias Springer 
initialize()24299ef9eebSMatthias Springer   void initialize() {
24399ef9eebSMatthias Springer     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
24499ef9eebSMatthias Springer     // is bounded as the rank is strictly decreasing.
24599ef9eebSMatthias Springer     setHasBoundedRewriteRecursion();
24699ef9eebSMatthias Springer   }
24799ef9eebSMatthias Springer 
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const24899ef9eebSMatthias Springer   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
24999ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
25099ef9eebSMatthias Springer     auto dstType = op.getType();
25199ef9eebSMatthias Springer 
2527c38fd60SJacques Pienaar     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
25399ef9eebSMatthias Springer 
25499ef9eebSMatthias Springer     int64_t offset =
2557c38fd60SJacques Pienaar         op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
2567c38fd60SJacques Pienaar     int64_t size =
2577c38fd60SJacques Pienaar         op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
25899ef9eebSMatthias Springer     int64_t stride =
2597c38fd60SJacques Pienaar         op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
26099ef9eebSMatthias Springer 
261*1acba8a4SBill Wendling     auto loc = op.getLoc();
26299ef9eebSMatthias Springer     auto elemType = dstType.getElementType();
26399ef9eebSMatthias Springer     assert(elemType.isSignlessIntOrIndexOrFloat());
26499ef9eebSMatthias Springer 
26559d3a9e0SLei Zhang     // Single offset can be more efficiently shuffled. It's handled in
26659d3a9e0SLei Zhang     // Convert1DExtractStridedSliceIntoShuffle.
26759d3a9e0SLei Zhang     if (op.getOffsets().getValue().size() == 1)
26859d3a9e0SLei Zhang       return failure();
26999ef9eebSMatthias Springer 
27099ef9eebSMatthias Springer     // Extract/insert on a lower ranked extract strided slice op.
27199ef9eebSMatthias Springer     Value zero = rewriter.create<arith::ConstantOp>(
27299ef9eebSMatthias Springer         loc, elemType, rewriter.getZeroAttr(elemType));
27399ef9eebSMatthias Springer     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
27499ef9eebSMatthias Springer     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
27599ef9eebSMatthias Springer          off += stride, ++idx) {
2767c38fd60SJacques Pienaar       Value one = extractOne(rewriter, loc, op.getVector(), off);
27799ef9eebSMatthias Springer       Value extracted = rewriter.create<ExtractStridedSliceOp>(
2787c38fd60SJacques Pienaar           loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
2797c38fd60SJacques Pienaar           getI64SubArray(op.getSizes(), /* dropFront=*/1),
2807c38fd60SJacques Pienaar           getI64SubArray(op.getStrides(), /* dropFront=*/1));
28199ef9eebSMatthias Springer       res = insertOne(rewriter, loc, extracted, res, idx);
28299ef9eebSMatthias Springer     }
28399ef9eebSMatthias Springer     rewriter.replaceOp(op, res);
28499ef9eebSMatthias Springer     return success();
28599ef9eebSMatthias Springer   }
28699ef9eebSMatthias Springer };
28799ef9eebSMatthias Springer 
populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet & patterns)28859d3a9e0SLei Zhang void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
28959d3a9e0SLei Zhang     RewritePatternSet &patterns) {
29059d3a9e0SLei Zhang   patterns.add<DecomposeDifferentRankInsertStridedSlice,
29159d3a9e0SLei Zhang                DecomposeNDExtractStridedSlice>(patterns.getContext());
29259d3a9e0SLei Zhang }
29359d3a9e0SLei Zhang 
29499ef9eebSMatthias Springer /// Populate the given list with patterns that convert from Vector to LLVM.
populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet & patterns)29599ef9eebSMatthias Springer void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
29699ef9eebSMatthias Springer     RewritePatternSet &patterns) {
29759d3a9e0SLei Zhang   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
29859d3a9e0SLei Zhang   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
29959d3a9e0SLei Zhang                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
30099ef9eebSMatthias Springer }
301