1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Helper that picks the proper sequence for inserting.
21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
22                        Value into, int64_t offset) {
23   auto vectorType = into.getType().cast<VectorType>();
24   if (vectorType.getRank() > 1)
25     return rewriter.create<InsertOp>(loc, from, into, offset);
26   return rewriter.create<vector::InsertElementOp>(
27       loc, vectorType, from, into,
28       rewriter.create<arith::ConstantIndexOp>(loc, offset));
29 }
30 
31 // Helper that picks the proper sequence for extracting.
32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
33                         int64_t offset) {
34   auto vectorType = vector.getType().cast<VectorType>();
35   if (vectorType.getRank() > 1)
36     return rewriter.create<ExtractOp>(loc, vector, offset);
37   return rewriter.create<vector::ExtractElementOp>(
38       loc, vectorType.getElementType(), vector,
39       rewriter.create<arith::ConstantIndexOp>(loc, offset));
40 }
41 
42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
43 /// have different ranks.
44 ///
45 /// When ranks are different, InsertStridedSlice needs to extract a properly
46 /// ranked vector from the destination vector into which to insert. This pattern
47 /// only takes care of this extraction part and forwards the rest to
48 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
49 ///
50 /// For a k-D source and n-D destination vector (k < n), we emit:
51 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
52 ///      insert the k-D source.
53 ///   2. k-D -> (n-1)-D InsertStridedSlice op
54 ///   3. InsertOp that is the reverse of 1.
55 class VectorInsertStridedSliceOpDifferentRankRewritePattern
56     : public OpRewritePattern<InsertStridedSliceOp> {
57 public:
58   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
59 
60   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
61                                 PatternRewriter &rewriter) const override {
62     auto srcType = op.getSourceVectorType();
63     auto dstType = op.getDestVectorType();
64 
65     if (op.offsets().getValue().empty())
66       return failure();
67 
68     auto loc = op.getLoc();
69     int64_t rankDiff = dstType.getRank() - srcType.getRank();
70     assert(rankDiff >= 0);
71     if (rankDiff == 0)
72       return failure();
73 
74     int64_t rankRest = dstType.getRank() - rankDiff;
75     // Extract / insert the subvector of matching rank and InsertStridedSlice
76     // on it.
77     Value extracted =
78         rewriter.create<ExtractOp>(loc, op.dest(),
79                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
80                                                   /*dropBack=*/rankRest));
81 
82     // A different pattern will kick in for InsertStridedSlice with matching
83     // ranks.
84     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
85         loc, op.source(), extracted,
86         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
87         getI64SubArray(op.strides(), /*dropFront=*/0));
88 
89     rewriter.replaceOpWithNewOp<InsertOp>(
90         op, stridedSliceInnerOp.getResult(), op.dest(),
91         getI64SubArray(op.offsets(), /*dropFront=*/0,
92                        /*dropBack=*/rankRest));
93     return success();
94   }
95 };
96 
97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
98 /// have the same rank. For each outermost index in the slice:
99 ///   begin    end             stride
100 /// [offset : offset+size*stride : stride]
101 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
102 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
103 ///   3. the destination subvector is inserted back in the proper place
104 ///   3. InsertOp that is the reverse of 1.
105 class VectorInsertStridedSliceOpSameRankRewritePattern
106     : public OpRewritePattern<InsertStridedSliceOp> {
107 public:
108   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
109 
110   void initialize() {
111     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
112     // bounded as the rank is strictly decreasing.
113     setHasBoundedRewriteRecursion();
114   }
115 
116   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
117                                 PatternRewriter &rewriter) const override {
118     auto srcType = op.getSourceVectorType();
119     auto dstType = op.getDestVectorType();
120 
121     if (op.offsets().getValue().empty())
122       return failure();
123 
124     int64_t srcRank = srcType.getRank();
125     int64_t dstRank = dstType.getRank();
126     assert(dstRank >= srcRank);
127     if (dstRank != srcRank)
128       return failure();
129 
130     if (srcType == dstType) {
131       rewriter.replaceOp(op, op.source());
132       return success();
133     }
134 
135     int64_t offset =
136         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
137     int64_t size = srcType.getShape().front();
138     int64_t stride =
139         op.strides().getValue().front().cast<IntegerAttr>().getInt();
140 
141     auto loc = op.getLoc();
142     Value res = op.dest();
143 
144     if (srcRank == 1) {
145       int nSrc = srcType.getShape().front();
146       int nDest = dstType.getShape().front();
147       // 1. Scale source to destType so we can shufflevector them together.
148       SmallVector<int64_t> offsets(nDest, 0);
149       for (int64_t i = 0; i < nSrc; ++i)
150         offsets[i] = i;
151       Value scaledSource =
152           rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
153 
154       // 2. Create a mask where we take the value from scaledSource of dest
155       // depending on the offset.
156       offsets.clear();
157       for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158         if (i < offset || i >= e || (i - offset) % stride != 0)
159           offsets.push_back(nDest + i);
160         else
161           offsets.push_back((i - offset) / stride);
162       }
163 
164       // 3. Replace with a ShuffleOp.
165       rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
166                                              offsets);
167 
168       return success();
169     }
170 
171     // For each slice of the source vector along the most major dimension.
172     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173          off += stride, ++idx) {
174       // 1. extract the proper subvector (or element) from source
175       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
176       if (extractedSource.getType().isa<VectorType>()) {
177         // 2. If we have a vector, extract the proper subvector from destination
178         // Otherwise we are at the element level and no need to recurse.
179         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
180         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
181         // smaller rank.
182         extractedSource = rewriter.create<InsertStridedSliceOp>(
183             loc, extractedSource, extractedDest,
184             getI64SubArray(op.offsets(), /* dropFront=*/1),
185             getI64SubArray(op.strides(), /* dropFront=*/1));
186       }
187       // 4. Insert the extractedSource into the res vector.
188       res = insertOne(rewriter, loc, extractedSource, res, off);
189     }
190 
191     rewriter.replaceOp(op, res);
192     return success();
193   }
194 };
195 
196 /// Progressive lowering of ExtractStridedSliceOp to either:
197 ///   1. single offset extract as a direct vector::ShuffleOp.
198 ///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
199 ///      InsertOp/InsertElementOp for the n-D case.
200 class VectorExtractStridedSliceOpRewritePattern
201     : public OpRewritePattern<ExtractStridedSliceOp> {
202 public:
203   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
204 
205   void initialize() {
206     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
207     // is bounded as the rank is strictly decreasing.
208     setHasBoundedRewriteRecursion();
209   }
210 
211   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
212                                 PatternRewriter &rewriter) const override {
213     auto dstType = op.getType();
214 
215     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
216 
217     int64_t offset =
218         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
219     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
220     int64_t stride =
221         op.strides().getValue().front().cast<IntegerAttr>().getInt();
222 
223     auto loc = op.getLoc();
224     auto elemType = dstType.getElementType();
225     assert(elemType.isSignlessIntOrIndexOrFloat());
226 
227     // Single offset can be more efficiently shuffled.
228     if (op.offsets().getValue().size() == 1) {
229       SmallVector<int64_t, 4> offsets;
230       offsets.reserve(size);
231       for (int64_t off = offset, e = offset + size * stride; off < e;
232            off += stride)
233         offsets.push_back(off);
234       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
235                                              op.vector(),
236                                              rewriter.getI64ArrayAttr(offsets));
237       return success();
238     }
239 
240     // Extract/insert on a lower ranked extract strided slice op.
241     Value zero = rewriter.create<arith::ConstantOp>(
242         loc, elemType, rewriter.getZeroAttr(elemType));
243     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
244     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
245          off += stride, ++idx) {
246       Value one = extractOne(rewriter, loc, op.vector(), off);
247       Value extracted = rewriter.create<ExtractStridedSliceOp>(
248           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
249           getI64SubArray(op.sizes(), /* dropFront=*/1),
250           getI64SubArray(op.strides(), /* dropFront=*/1));
251       res = insertOne(rewriter, loc, extracted, res, idx);
252     }
253     rewriter.replaceOp(op, res);
254     return success();
255   }
256 };
257 
258 /// Populate the given list with patterns that convert from Vector to LLVM.
259 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
260     RewritePatternSet &patterns) {
261   patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
262                VectorInsertStridedSliceOpSameRankRewritePattern,
263                VectorExtractStridedSliceOpRewritePattern>(
264       patterns.getContext());
265 }
266