//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to be fully overwritten, another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
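///
/// Example (illustrative only; SSA names and types are made up):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///
/// With no read of %buf reachable between the two writes, the first
/// transfer_write is dead and can be erased.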
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // No need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
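///
/// Example (illustrative only; SSA names and types are made up):
///
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// With no aliasing write reachable in between, %r can be replaced by %v0 and
/// the transfer_read erased.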
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}
225
226 /// Creates a rank-reducing memref.subview op that drops unit dims from its
227 /// input. Or just returns the input if it was already without unit dims.
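///
/// For example (illustrative; the printed result type assumes the input has
/// the identity layout):
///
///   %0 = memref.subview %m[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>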
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
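///
/// Example (illustrative only; SSA names and types are made up):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad {in_bounds = [true]}
///       : memref<1x1x4xf32>, vector<4xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 4] [1, 1, 1]
///       : memref<1x1x4xf32> to memref<4xf32>
///   %v = vector.transfer_read %sv[%c0], %pad : memref<4xf32>, vector<4xf32>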
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
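///
/// Example (illustrative only; SSA names and types are made up):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<1x1x4xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 4] [1, 1, 1]
///       : memref<1x1x4xf32> to memref<4xf32>
///   vector.transfer_write %v, %sv[%c0] : vector<4xf32>, memref<4xf32>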
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Returns the position of the first inner dimension that has contiguous
/// layout with at least `requiredContiguousSize` contiguous elements.
/// When such a dimension is found, the return value satisfies:
///   0 <= return_value <= memrefType.getRank() - 1.
/// When no such dimension is found, the return value is memrefType.getRank().
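///
/// For example (illustrative), for `memref<2x3x4xf32>` with the identity
/// layout and `requiredContiguousSize` = 12, dimensions {1, 2} are contiguous
/// and hold 3 * 4 = 12 elements, so the function returns 1.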
static int64_t getContiguousInnerDim(MemRefType memrefType,
                                     int64_t requiredContiguousSize) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  int64_t innerDim = shape.size();
  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
    int64_t innerSize = 1;
    while (true) {
      if (innerDim == 0)
        break;
      const int64_t nextDim = innerDim - 1;
      if (shape[nextDim] == ShapedType::kDynamicSize)
        break;
      if (strides[nextDim] != innerSize)
        break;
      innerSize *= shape[nextDim];
      innerDim = nextDim;
      if (innerSize >= requiredContiguousSize)
        break;
    }
  }
  return innerDim;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
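///
/// For example (illustrative), with `firstDimToCollapse` = 1:
///
///   %0 = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>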
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = input.getType().cast<ShapedType>();
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
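///
/// For example (illustrative), with indices (%i, %c0, %c0) and
/// `firstDimToCollapse` = 1, `outIndices` becomes (%i, %c0).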
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    arith::ConstantIndexOp cst =
        indices[i].getDefiningOp<arith::ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting a
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read reads a 1-D vector from the collapsed inner
/// dimensions. Requires the source shape to be already reduced, i.e. without
/// unit dims.
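///
/// Example (illustrative only; SSA names and types are made up):
///
///   %v = vector.transfer_read %m[%i, %c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<2x3x4xf32>, vector<3x4xf32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>
///   %fr = vector.transfer_read %cs[%i, %c0], %pad {in_bounds = [true]}
///       : memref<2x12xf32>, vector<12xf32>
///   %v = vector.shape_cast %fr : vector<12xf32> to vector<3x4xf32>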
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The strides-based contiguity check below only applies to memrefs; bail
    // out on tensors.
    if (!sourceType)
      return failure();
    // Already 0D/1D, nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    int64_t firstContiguousInnerDim =
        getContiguousInnerDim(sourceType, vectorType.getNumElements());
    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().dyn_cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting a
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write writes a 1-D vector to the collapsed inner
/// dimensions. Requires the source shape to be already reduced, i.e. without
/// unit dims.
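///
/// Example (illustrative only; SSA names and types are made up):
///
///   vector.transfer_write %v, %m[%i, %c0, %c0] {in_bounds = [true, true]}
///       : vector<3x4xf32>, memref<2x3x4xf32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>
///   %fv = vector.shape_cast %v : vector<3x4xf32> to vector<12xf32>
///   vector.transfer_write %fv, %cs[%i, %c0] {in_bounds = [true]}
///       : vector<12xf32>, memref<2x12xf32>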
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The strides-based contiguity check below only applies to memrefs; bail
    // out on tensors.
    if (!sourceType)
      return failure();
    // Already 0D/1D, nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    int64_t firstContiguousInnerDim =
        getContiguousInnerDim(sourceType, vectorType.getNumElements());
    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}
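
// Example usage from a pass (illustrative sketch only; the surrounding pass
// class and `getOperation()` are assumed and not part of this file):
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//     vector::populateFlattenVectorTransferPatterns(patterns);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//     vector::transferOpflowOpt(getOperation());
//   }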