1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BuiltinTypes.h"
16
17 using namespace mlir;
18 using namespace mlir::vector;
19
20 // Helper that picks the proper sequence for inserting.
insertOne(PatternRewriter & rewriter,Location loc,Value from,Value into,int64_t offset)21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
22 Value into, int64_t offset) {
23 auto vectorType = into.getType().cast<VectorType>();
24 if (vectorType.getRank() > 1)
25 return rewriter.create<InsertOp>(loc, from, into, offset);
26 return rewriter.create<vector::InsertElementOp>(
27 loc, vectorType, from, into,
28 rewriter.create<arith::ConstantIndexOp>(loc, offset));
29 }
30
31 // Helper that picks the proper sequence for extracting.
extractOne(PatternRewriter & rewriter,Location loc,Value vector,int64_t offset)32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
33 int64_t offset) {
34 auto vectorType = vector.getType().cast<VectorType>();
35 if (vectorType.getRank() > 1)
36 return rewriter.create<ExtractOp>(loc, vector, offset);
37 return rewriter.create<vector::ExtractElementOp>(
38 loc, vectorType.getElementType(), vector,
39 rewriter.create<arith::ConstantIndexOp>(loc, offset));
40 }
41
42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
43 /// have different ranks.
44 ///
45 /// When ranks are different, InsertStridedSlice needs to extract a properly
46 /// ranked vector from the destination vector into which to insert. This pattern
47 /// only takes care of this extraction part and forwards the rest to
48 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
49 ///
50 /// For a k-D source and n-D destination vector (k < n), we emit:
51 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
52 /// insert the k-D source.
53 /// 2. k-D -> (n-1)-D InsertStridedSlice op
54 /// 3. InsertOp that is the reverse of 1.
55 class DecomposeDifferentRankInsertStridedSlice
56 : public OpRewritePattern<InsertStridedSliceOp> {
57 public:
58 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
59
matchAndRewrite(InsertStridedSliceOp op,PatternRewriter & rewriter) const60 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
61 PatternRewriter &rewriter) const override {
62 auto srcType = op.getSourceVectorType();
63 auto dstType = op.getDestVectorType();
64
65 if (op.getOffsets().getValue().empty())
66 return failure();
67
68 auto loc = op.getLoc();
69 int64_t rankDiff = dstType.getRank() - srcType.getRank();
70 assert(rankDiff >= 0);
71 if (rankDiff == 0)
72 return failure();
73
74 int64_t rankRest = dstType.getRank() - rankDiff;
75 // Extract / insert the subvector of matching rank and InsertStridedSlice
76 // on it.
77 Value extracted = rewriter.create<ExtractOp>(
78 loc, op.getDest(),
79 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
80 /*dropBack=*/rankRest));
81
82 // A different pattern will kick in for InsertStridedSlice with matching
83 // ranks.
84 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
85 loc, op.getSource(), extracted,
86 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
87 getI64SubArray(op.getStrides(), /*dropFront=*/0));
88
89 rewriter.replaceOpWithNewOp<InsertOp>(
90 op, stridedSliceInnerOp.getResult(), op.getDest(),
91 getI64SubArray(op.getOffsets(), /*dropFront=*/0,
92 /*dropBack=*/rankRest));
93 return success();
94 }
95 };
96
97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
98 /// have the same rank. For each outermost index in the slice:
99 /// begin end stride
100 /// [offset : offset+size*stride : stride]
101 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
102 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
103 /// 3. the destination subvector is inserted back in the proper place
104 /// 3. InsertOp that is the reverse of 1.
105 class ConvertSameRankInsertStridedSliceIntoShuffle
106 : public OpRewritePattern<InsertStridedSliceOp> {
107 public:
108 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
109
initialize()110 void initialize() {
111 // This pattern creates recursive InsertStridedSliceOp, but the recursion is
112 // bounded as the rank is strictly decreasing.
113 setHasBoundedRewriteRecursion();
114 }
115
matchAndRewrite(InsertStridedSliceOp op,PatternRewriter & rewriter) const116 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
117 PatternRewriter &rewriter) const override {
118 auto srcType = op.getSourceVectorType();
119 auto dstType = op.getDestVectorType();
120
121 if (op.getOffsets().getValue().empty())
122 return failure();
123
124 int64_t srcRank = srcType.getRank();
125 int64_t dstRank = dstType.getRank();
126 assert(dstRank >= srcRank);
127 if (dstRank != srcRank)
128 return failure();
129
130 if (srcType == dstType) {
131 rewriter.replaceOp(op, op.getSource());
132 return success();
133 }
134
135 int64_t offset =
136 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
137 int64_t size = srcType.getShape().front();
138 int64_t stride =
139 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
140
141 auto loc = op.getLoc();
142 Value res = op.getDest();
143
144 if (srcRank == 1) {
145 int nSrc = srcType.getShape().front();
146 int nDest = dstType.getShape().front();
147 // 1. Scale source to destType so we can shufflevector them together.
148 SmallVector<int64_t> offsets(nDest, 0);
149 for (int64_t i = 0; i < nSrc; ++i)
150 offsets[i] = i;
151 Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
152 op.getSource(), offsets);
153
154 // 2. Create a mask where we take the value from scaledSource of dest
155 // depending on the offset.
156 offsets.clear();
157 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158 if (i < offset || i >= e || (i - offset) % stride != 0)
159 offsets.push_back(nDest + i);
160 else
161 offsets.push_back((i - offset) / stride);
162 }
163
164 // 3. Replace with a ShuffleOp.
165 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
166 offsets);
167
168 return success();
169 }
170
171 // For each slice of the source vector along the most major dimension.
172 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173 off += stride, ++idx) {
174 // 1. extract the proper subvector (or element) from source
175 Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
176 if (extractedSource.getType().isa<VectorType>()) {
177 // 2. If we have a vector, extract the proper subvector from destination
178 // Otherwise we are at the element level and no need to recurse.
179 Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
180 // 3. Reduce the problem to lowering a new InsertStridedSlice op with
181 // smaller rank.
182 extractedSource = rewriter.create<InsertStridedSliceOp>(
183 loc, extractedSource, extractedDest,
184 getI64SubArray(op.getOffsets(), /* dropFront=*/1),
185 getI64SubArray(op.getStrides(), /* dropFront=*/1));
186 }
187 // 4. Insert the extractedSource into the res vector.
188 res = insertOne(rewriter, loc, extractedSource, res, off);
189 }
190
191 rewriter.replaceOp(op, res);
192 return success();
193 }
194 };
195
196 /// RewritePattern for ExtractStridedSliceOp where source and destination
197 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
198 class Convert1DExtractStridedSliceIntoShuffle
199 : public OpRewritePattern<ExtractStridedSliceOp> {
200 public:
201 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
202
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const203 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
204 PatternRewriter &rewriter) const override {
205 auto dstType = op.getType();
206
207 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
208
209 int64_t offset =
210 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
211 int64_t size =
212 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
213 int64_t stride =
214 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
215
216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217
218 // Single offset can be more efficiently shuffled.
219 if (op.getOffsets().getValue().size() != 1)
220 return failure();
221
222 SmallVector<int64_t, 4> offsets;
223 offsets.reserve(size);
224 for (int64_t off = offset, e = offset + size * stride; off < e;
225 off += stride)
226 offsets.push_back(off);
227 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228 op.getVector(),
229 rewriter.getI64ArrayAttr(offsets));
230 return success();
231 }
232 };
233
234 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
235 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
236 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
237 class DecomposeNDExtractStridedSlice
238 : public OpRewritePattern<ExtractStridedSliceOp> {
239 public:
240 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
241
initialize()242 void initialize() {
243 // This pattern creates recursive ExtractStridedSliceOp, but the recursion
244 // is bounded as the rank is strictly decreasing.
245 setHasBoundedRewriteRecursion();
246 }
247
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const248 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
249 PatternRewriter &rewriter) const override {
250 auto dstType = op.getType();
251
252 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
253
254 int64_t offset =
255 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
256 int64_t size =
257 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
258 int64_t stride =
259 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
260
261 auto loc = op.getLoc();
262 auto elemType = dstType.getElementType();
263 assert(elemType.isSignlessIntOrIndexOrFloat());
264
265 // Single offset can be more efficiently shuffled. It's handled in
266 // Convert1DExtractStridedSliceIntoShuffle.
267 if (op.getOffsets().getValue().size() == 1)
268 return failure();
269
270 // Extract/insert on a lower ranked extract strided slice op.
271 Value zero = rewriter.create<arith::ConstantOp>(
272 loc, elemType, rewriter.getZeroAttr(elemType));
273 Value res = rewriter.create<SplatOp>(loc, dstType, zero);
274 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
275 off += stride, ++idx) {
276 Value one = extractOne(rewriter, loc, op.getVector(), off);
277 Value extracted = rewriter.create<ExtractStridedSliceOp>(
278 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
279 getI64SubArray(op.getSizes(), /* dropFront=*/1),
280 getI64SubArray(op.getStrides(), /* dropFront=*/1));
281 res = insertOne(rewriter, loc, extracted, res, idx);
282 }
283 rewriter.replaceOp(op, res);
284 return success();
285 }
286 };
287
populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet & patterns)288 void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
289 RewritePatternSet &patterns) {
290 patterns.add<DecomposeDifferentRankInsertStridedSlice,
291 DecomposeNDExtractStridedSlice>(patterns.getContext());
292 }
293
294 /// Populate the given list with patterns that convert from Vector to LLVM.
populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet & patterns)295 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
296 RewritePatternSet &patterns) {
297 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
298 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
299 Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
300 }
301