1 //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with optimizing transfer_read and
10 // transfer_write ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Vector/IR/VectorOps.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Dominance.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "vector-transfer-opt"
26 
27 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
28 
29 using namespace mlir;
30 
/// Return the ancestor of `op` that lives directly in `region`, or nullptr if
/// `region` is not an ancestor of the op.
33 static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
34   for (; op != nullptr && op->getParentRegion() != region;
35        op = op->getParentOp())
36     ;
37   return op;
38 }
39 
40 namespace {
41 
42 class TransferOptimization {
43 public:
44   TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
45   void deadStoreOp(vector::TransferWriteOp);
46   void storeToLoadForwarding(vector::TransferReadOp);
47   void removeDeadOp() {
48     for (Operation *op : opToErase)
49       op->erase();
50     opToErase.clear();
51   }
52 
53 private:
54   bool isReachable(Operation *start, Operation *dest);
55   DominanceInfo dominators;
56   PostDominanceInfo postDominators;
57   std::vector<Operation *> opToErase;
58 };
59 
60 /// Return true if there is a path from start operation to dest operation,
61 /// otherwise return false. The operations have to be in the same region.
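/// For illustration: in a CFG where ^bb0 branches to ^bb1 and ^bb2, and both
/// branch to ^bb3, an op in ^bb1 can reach an op in ^bb3 but not an op in
/// ^bb2 (the block names are placeholders, not taken from this file).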
62 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
63   assert(start->getParentRegion() == dest->getParentRegion() &&
64          "This function only works for ops i the same region");
  // Simple case where the start op dominates the destination.
66   if (dominators.dominates(start, dest))
67     return true;
68   Block *startBlock = start->getBlock();
69   Block *destBlock = dest->getBlock();
70   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
71                                     startBlock->succ_end());
72   SmallPtrSet<Block *, 32> visited;
73   while (!worklist.empty()) {
74     Block *bb = worklist.pop_back_val();
75     if (!visited.insert(bb).second)
76       continue;
77     if (dominators.dominates(bb, destBlock))
78       return true;
79     worklist.append(bb->succ_begin(), bb->succ_end());
80   }
81   return false;
82 }
83 
/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// it is the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
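///
/// An illustrative sketch (all SSA names and types below are placeholders,
/// not taken from this file):
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<4x8xf32>
///   // ... no read of the written part of %m in between ...
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<4x8xf32>
/// Here the first transfer_write is dead: it is fully overwritten by the
/// second one before any read can observe its effect.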
95 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
96   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
97                     << "\n");
98   llvm::SmallVector<Operation *, 8> reads;
99   Operation *firstOverwriteCandidate = nullptr;
100   for (auto *user : write.getSource().getUsers()) {
101     if (user == write.getOperation())
102       continue;
103     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
105       if (checkSameValueWAW(nextWrite, write) &&
106           postDominators.postDominates(nextWrite, write)) {
107         if (firstOverwriteCandidate == nullptr ||
108             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
109           firstOverwriteCandidate = nextWrite;
110         else
111           assert(
112               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
113       }
114     } else {
115       if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
116         // Don't need to consider disjoint reads.
117         if (vector::isDisjointTransferSet(
118                 cast<VectorTransferOpInterface>(write.getOperation()),
119                 cast<VectorTransferOpInterface>(read.getOperation())))
120           continue;
121       }
122       reads.push_back(user);
123     }
124   }
125   if (firstOverwriteCandidate == nullptr)
126     return;
127   Region *topRegion = firstOverwriteCandidate->getParentRegion();
128   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
129   assert(writeAncestor &&
130          "write op should be recursively part of the top region");
131 
132   for (Operation *read : reads) {
133     Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
134     // TODO: if the read and write have the same ancestor we could recurse in
135     // the region to know if the read is reachable with more precision.
136     if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
137       continue;
138     if (!dominators.dominates(firstOverwriteCandidate, read)) {
139       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
140                         << "\n");
141       return;
142     }
143   }
144   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
145                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
146   opToErase.push_back(write.getOperation());
147 }
148 
/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as it is the last
/// transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
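///
/// An illustrative sketch (all SSA names and types below are placeholders,
/// not taken from this file):
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<8xf32>, memref<4x8xf32>
///   // ... no aliasing write to the read part of %m in between ...
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x8xf32>, vector<8xf32>
/// Here all uses of %r can be replaced by %v and the transfer_read erased.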
160 void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
161   if (read.hasOutOfBoundsDim())
162     return;
163   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
164                     << "\n");
165   SmallVector<Operation *, 8> blockingWrites;
166   vector::TransferWriteOp lastwrite = nullptr;
167   for (Operation *user : read.getSource().getUsers()) {
168     if (isa<vector::TransferReadOp>(user))
169       continue;
170     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
173       if (vector::isDisjointTransferSet(
174               cast<VectorTransferOpInterface>(write.getOperation()),
175               cast<VectorTransferOpInterface>(read.getOperation())))
176         continue;
177       if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
178         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
179           lastwrite = write;
180         else
181           assert(dominators.dominates(write, lastwrite));
182         continue;
183       }
184     }
185     blockingWrites.push_back(user);
186   }
187 
188   if (lastwrite == nullptr)
189     return;
190 
191   Region *topRegion = lastwrite->getParentRegion();
192   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
193   assert(readAncestor &&
194          "read op should be recursively part of the top region");
195 
196   for (Operation *write : blockingWrites) {
197     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
198     // TODO: if the store and read have the same ancestor we could recurse in
199     // the region to know if the read is reachable with more precision.
200     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
201       continue;
202     if (!postDominators.postDominates(lastwrite, write)) {
203       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
204                         << *write << "\n");
205       return;
206     }
207   }
208 
209   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
210                     << " to: " << *read.getOperation() << "\n");
211   read.replaceAllUsesWith(lastwrite.getVector());
212   opToErase.push_back(read.getOperation());
213 }
214 
215 /// Drops unit dimensions from the input MemRefType.
216 static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
217                                ArrayRef<int64_t> sizes,
218                                ArrayRef<int64_t> strides) {
219   SmallVector<int64_t> targetShape = llvm::to_vector(
220       llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
221   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
222       targetShape, inputType, offsets, sizes, strides);
223   return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
224 }
225 
/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
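/// For illustration (the SSA name %input is a placeholder): for an %input of
/// type memref<1x1x16x8xf32>, the op created is roughly
///   memref.subview %input[0, 0, 0, 0] [1, 1, 16, 8] [1, 1, 1, 1]
///       : memref<1x1x16x8xf32> to memref<16x8xf32>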
228 static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
229                                                  mlir::Location loc,
230                                                  Value input) {
231   MemRefType inputType = input.getType().cast<MemRefType>();
232   assert(inputType.hasStaticShape());
233   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
234   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
235   ArrayRef<int64_t> subViewSizes = inputType.getShape();
236   MemRefType resultType =
237       dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
238   if (canonicalizeStridedLayout(resultType) ==
239       canonicalizeStridedLayout(inputType))
240     return input;
241   return rewriter.create<memref::SubViewOp>(
242       loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
243 }
244 
245 /// Returns the number of dims that aren't unit dims.
246 static int getReducedRank(ArrayRef<int64_t> shape) {
247   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
248 }
249 
/// Returns true if the value is an `arith.constant 0 : index`.
251 static bool isZero(Value v) {
252   auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
253   return cst && cst.value() == 0;
254 }
255 
256 /// Rewrites vector.transfer_read ops where the source has unit dims, by
257 /// inserting a memref.subview dropping those unit dims.
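///
/// An illustrative sketch (SSA names and types are placeholders): a read like
///   %v = vector.transfer_read %m[%c0, %c0, %c0, %c0], %pad
///       {in_bounds = [true, true]} : memref<1x1x16x8xf32>, vector<16x8xf32>
/// is rewritten roughly into
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 1, 16, 8] [1, 1, 1, 1]
///       : memref<1x1x16x8xf32> to memref<16x8xf32>
///   %v = vector.transfer_read %sv[%c0, %c0], %pad
///       : memref<16x8xf32>, vector<16x8xf32>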
258 class TransferReadDropUnitDimsPattern
259     : public OpRewritePattern<vector::TransferReadOp> {
260   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
261 
262   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
263                                 PatternRewriter &rewriter) const override {
264     auto loc = transferReadOp.getLoc();
265     Value vector = transferReadOp.getVector();
266     VectorType vectorType = vector.getType().cast<VectorType>();
267     Value source = transferReadOp.getSource();
268     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
269     // TODO: support tensor types.
270     if (!sourceType || !sourceType.hasStaticShape())
271       return failure();
272     if (sourceType.getNumElements() != vectorType.getNumElements())
273       return failure();
274     // TODO: generalize this pattern, relax the requirements here.
275     if (transferReadOp.hasOutOfBoundsDim())
276       return failure();
277     if (!transferReadOp.getPermutationMap().isMinorIdentity())
278       return failure();
279     int reducedRank = getReducedRank(sourceType.getShape());
280     if (reducedRank == sourceType.getRank())
281       return failure(); // The source shape can't be further reduced.
282     if (reducedRank != vectorType.getRank())
283       return failure(); // This pattern requires the vector shape to match the
284                         // reduced source shape.
285     if (llvm::any_of(transferReadOp.getIndices(),
286                      [](Value v) { return !isZero(v); }))
287       return failure();
288     Value reducedShapeSource =
289         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
290     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
291     SmallVector<Value> zeros(reducedRank, c0);
292     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
293     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
294         transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
295     return success();
296   }
297 };
298 
299 /// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
300 /// unit dims, by inserting a memref.subview dropping those unit dims.
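///
/// An illustrative sketch (SSA names and types are placeholders): a write like
///   vector.transfer_write %v, %m[%c0, %c0, %c0, %c0]
///       {in_bounds = [true, true]} : vector<16x8xf32>, memref<1x1x16x8xf32>
/// is rewritten roughly into
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 1, 16, 8] [1, 1, 1, 1]
///       : memref<1x1x16x8xf32> to memref<16x8xf32>
///   vector.transfer_write %v, %sv[%c0, %c0]
///       : vector<16x8xf32>, memref<16x8xf32>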
301 class TransferWriteDropUnitDimsPattern
302     : public OpRewritePattern<vector::TransferWriteOp> {
303   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
304 
305   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
306                                 PatternRewriter &rewriter) const override {
307     auto loc = transferWriteOp.getLoc();
308     Value vector = transferWriteOp.getVector();
309     VectorType vectorType = vector.getType().cast<VectorType>();
310     Value source = transferWriteOp.getSource();
311     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
312     // TODO: support tensor type.
313     if (!sourceType || !sourceType.hasStaticShape())
314       return failure();
315     if (sourceType.getNumElements() != vectorType.getNumElements())
316       return failure();
317     // TODO: generalize this pattern, relax the requirements here.
318     if (transferWriteOp.hasOutOfBoundsDim())
319       return failure();
320     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
321       return failure();
322     int reducedRank = getReducedRank(sourceType.getShape());
323     if (reducedRank == sourceType.getRank())
324       return failure(); // The source shape can't be further reduced.
325     if (reducedRank != vectorType.getRank())
326       return failure(); // This pattern requires the vector shape to match the
327                         // reduced source shape.
328     if (llvm::any_of(transferWriteOp.getIndices(),
329                      [](Value v) { return !isZero(v); }))
330       return failure();
331     Value reducedShapeSource =
332         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
333     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
334     SmallVector<Value> zeros(reducedRank, c0);
335     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
336     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
337         transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
338     return success();
339   }
340 };
341 
/// Returns the position of the innermost dimension `d` such that dimensions
/// `d .. rank-1` have contiguous layout and together contain at least
/// `requiredContiguousSize` elements.
344 /// When such a dimension is found, the return value satisfies:
345 ///   0 <= return_value <= memrefType.getRank() - 1.
346 /// When no such dimension is found, the return value is memrefType.getRank().
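/// For illustration: for memref<2x3x4xf32> with the identity layout (strides
/// [12, 4, 1]), a requiredContiguousSize of 24 yields 0 and a
/// requiredContiguousSize of 12 yields 1; a memref whose innermost stride is
/// not 1 yields 3 (the rank).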
347 static int64_t getContiguousInnerDim(MemRefType memrefType,
348                                      int64_t requiredContiguousSize) {
349   auto shape = memrefType.getShape();
350   SmallVector<int64_t> strides;
351   int64_t offset;
352   int64_t innerDim = shape.size();
353   if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
354     int64_t innerSize = 1;
355     while (true) {
356       if (innerDim == 0)
357         break;
358       const int64_t nextDim = innerDim - 1;
359       if (shape[nextDim] == ShapedType::kDynamicSize)
360         break;
361       if (strides[nextDim] != innerSize)
362         break;
363       innerSize *= shape[nextDim];
364       innerDim = nextDim;
365       if (innerSize >= requiredContiguousSize)
366         break;
367     }
368   }
369   return innerDim;
370 }
371 
372 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
373 /// input starting at `firstDimToCollapse`.
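/// For illustration (%input is a placeholder): with firstDimToCollapse = 1 and
/// an input of type memref<2x3x4xf32>, the reassociation is [[0], [1, 2]] and
/// the op created is roughly
///   memref.collapse_shape %input [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>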
374 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
375                                Value input, int64_t firstDimToCollapse) {
376   ShapedType inputType = input.getType().cast<ShapedType>();
377   if (inputType.getRank() == 1)
378     return input;
379   SmallVector<ReassociationIndices> reassociation;
380   for (int64_t i = 0; i < firstDimToCollapse; ++i)
381     reassociation.push_back(ReassociationIndices{i});
382   ReassociationIndices collapsedIndices;
383   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
384     collapsedIndices.push_back(i);
385   reassociation.push_back(collapsedIndices);
386   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
387 }
388 
389 /// Checks that the indices corresponding to dimensions starting at
390 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
391 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
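/// For illustration (%i, %j and %c0 are placeholders): with indices
/// (%i, %c0, %c0) and firstDimToCollapse = 1, the check succeeds and
/// outIndices becomes (%i, %c0); with indices (%i, %j, %c0) it fails because
/// the index at position 1 is not a constant zero.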
392 static LogicalResult
393 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
394                                  SmallVector<Value> &outIndices) {
395   int64_t rank = indices.size();
396   if (firstDimToCollapse >= rank)
397     return failure();
398   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
399     arith::ConstantIndexOp cst =
400         indices[i].getDefiningOp<arith::ConstantIndexOp>();
401     if (!cst || cst.value() != 0)
402       return failure();
403   }
  outIndices.assign(indices.begin(), indices.end());
  outIndices.resize(firstDimToCollapse + 1);
406   return success();
407 }
408 
409 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
410 /// memref.collapse_shape on the source so that the resulting
411 /// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
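///
/// An illustrative sketch (SSA names and types are placeholders): a read like
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<2x3x4xf32>, vector<2x3x4xf32>
/// is rewritten roughly into
///   %cs = memref.collapse_shape %m [[0, 1, 2]]
///       : memref<2x3x4xf32> into memref<24xf32>
///   %f = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///       : memref<24xf32>, vector<24xf32>
///   %v = vector.shape_cast %f : vector<24xf32> to vector<2x3x4xf32>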
413 class FlattenContiguousRowMajorTransferReadPattern
414     : public OpRewritePattern<vector::TransferReadOp> {
415   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
416 
417   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
418                                 PatternRewriter &rewriter) const override {
419     auto loc = transferReadOp.getLoc();
420     Value vector = transferReadOp.getVector();
421     VectorType vectorType = vector.getType().cast<VectorType>();
422     Value source = transferReadOp.getSource();
423     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below requires memref strides; bail out on
    // non-memref (e.g. tensor) sources.
425     if (!sourceType)
426       return failure();
427     if (vectorType.getRank() <= 1)
428       // Already 0D/1D, nothing to do.
429       return failure();
430     int64_t firstContiguousInnerDim =
431         getContiguousInnerDim(sourceType, vectorType.getNumElements());
432     if (firstContiguousInnerDim >= sourceType.getRank() - 1)
433       return failure();
434     // TODO: generalize this pattern, relax the requirements here.
435     if (transferReadOp.hasOutOfBoundsDim())
436       return failure();
437     if (!transferReadOp.getPermutationMap().isMinorIdentity())
438       return failure();
439     if (transferReadOp.getMask())
440       return failure();
441     SmallVector<Value> collapsedIndices;
442     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
443                                                 firstContiguousInnerDim,
444                                                 collapsedIndices)))
445       return failure();
446     Value collapsedSource =
447         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
448     MemRefType collapsedSourceType =
449         collapsedSource.getType().dyn_cast<MemRefType>();
450     int64_t collapsedRank = collapsedSourceType.getRank();
451     assert(collapsedRank == firstContiguousInnerDim + 1);
452     SmallVector<AffineExpr, 1> dimExprs{
453         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
454     auto collapsedMap =
455         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
456     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
457                                                 vectorType.getElementType());
458     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
459         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
460     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
461     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
462         transferReadOp, vector.getType().cast<VectorType>(), flatRead);
463     return success();
464   }
465 };
466 
467 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
468 /// memref.collapse_shape on the source so that the resulting
469 /// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
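///
/// An illustrative sketch (SSA names and types are placeholders): a write like
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<2x3x4xf32>, memref<2x3x4xf32>
/// is rewritten roughly into
///   %cs = memref.collapse_shape %m [[0, 1, 2]]
///       : memref<2x3x4xf32> into memref<24xf32>
///   %f = vector.shape_cast %v : vector<2x3x4xf32> to vector<24xf32>
///   vector.transfer_write %f, %cs[%c0] {in_bounds = [true]}
///       : vector<24xf32>, memref<24xf32>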
471 class FlattenContiguousRowMajorTransferWritePattern
472     : public OpRewritePattern<vector::TransferWriteOp> {
473   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
474 
475   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
476                                 PatternRewriter &rewriter) const override {
477     auto loc = transferWriteOp.getLoc();
478     Value vector = transferWriteOp.getVector();
479     VectorType vectorType = vector.getType().cast<VectorType>();
480     Value source = transferWriteOp.getSource();
481     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below requires memref strides; bail out on
    // non-memref (e.g. tensor) sources.
483     if (!sourceType)
484       return failure();
485     if (vectorType.getRank() <= 1)
486       // Already 0D/1D, nothing to do.
487       return failure();
488     int64_t firstContiguousInnerDim =
489         getContiguousInnerDim(sourceType, vectorType.getNumElements());
490     if (firstContiguousInnerDim >= sourceType.getRank() - 1)
491       return failure();
492     // TODO: generalize this pattern, relax the requirements here.
493     if (transferWriteOp.hasOutOfBoundsDim())
494       return failure();
495     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
496       return failure();
497     if (transferWriteOp.getMask())
498       return failure();
499     SmallVector<Value> collapsedIndices;
500     if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
501                                                 firstContiguousInnerDim,
502                                                 collapsedIndices)))
503       return failure();
504     Value collapsedSource =
505         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
506     MemRefType collapsedSourceType =
507         collapsedSource.getType().cast<MemRefType>();
508     int64_t collapsedRank = collapsedSourceType.getRank();
509     assert(collapsedRank == firstContiguousInnerDim + 1);
510     SmallVector<AffineExpr, 1> dimExprs{
511         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
512     auto collapsedMap =
513         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
514     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
515                                                 vectorType.getElementType());
516     Value flatVector =
517         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
518     vector::TransferWriteOp flatWrite =
519         rewriter.create<vector::TransferWriteOp>(
520             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
521     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
522     rewriter.eraseOp(transferWriteOp);
523     return success();
524   }
525 };
526 
527 } // namespace
528 
529 void mlir::vector::transferOpflowOpt(Operation *rootOp) {
530   TransferOptimization opt(rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
533   rootOp->walk([&](vector::TransferReadOp read) {
534     if (read.getShapedType().isa<MemRefType>())
535       opt.storeToLoadForwarding(read);
536   });
537   opt.removeDeadOp();
538   rootOp->walk([&](vector::TransferWriteOp write) {
539     if (write.getShapedType().isa<MemRefType>())
540       opt.deadStoreOp(write);
541   });
542   opt.removeDeadOp();
543 }
544 
545 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
546     RewritePatternSet &patterns) {
547   patterns
548       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
549           patterns.getContext());
550   populateShapeCastFoldingPatterns(patterns);
551 }
552 
553 void mlir::vector::populateFlattenVectorTransferPatterns(
554     RewritePatternSet &patterns) {
555   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
556                FlattenContiguousRowMajorTransferWritePattern>(
557       patterns.getContext());
558   populateShapeCastFoldingPatterns(patterns);
559 }
560