//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor of `op` that lies directly in `region`, or nullptr if
/// the region is not an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if the candidate can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
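/// As an illustrative sketch (abridged IR; %v, %m, %c0 and %pad are
/// placeholder names, types are example-only, and in_bounds and other
/// attributes are elided), forwarding rewrites
///   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>
/// so that all uses of %r are replaced by %v, provided no other write to %m
/// can execute between the two ops.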
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims.
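/// For instance (a sketch with example types; %input and %sv are placeholder
/// names), a memref<1x8x1xf32> input roughly yields
///   %sv = memref.subview %input[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>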
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
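/// As a sketch (example types and SSA names only; in_bounds and other
/// attributes elided), the rewrite turns
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       : vector<8xf32>, memref<1x1x8xf32>
/// into roughly
///   %sv = memref.subview %dst[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   vector.transfer_write %v, %sv[%c0] : vector<8xf32>, memref<8xf32>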
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor type.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Returns the position of the first inner dimension that has contiguous
/// layout with at least `requiredContiguousSize` contiguous elements.
/// When such a dimension is found, the return value satisfies:
///   0 <= return_value <= memrefType.getRank() - 1.
/// When no such dimension is found, the return value is memrefType.getRank().
static int64_t getContiguousInnerDim(MemRefType memrefType,
                                     int64_t requiredContiguousSize) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  int64_t innerDim = shape.size();
  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
    int64_t innerSize = 1;
    while (true) {
      if (innerDim == 0)
        break;
      const int64_t nextDim = innerDim - 1;
      if (shape[nextDim] == ShapedType::kDynamicSize)
        break;
      if (strides[nextDim] != innerSize)
        break;
      innerSize *= shape[nextDim];
      innerDim = nextDim;
      if (innerSize >= requiredContiguousSize)
        break;
    }
  }
  return innerDim;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
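/// For instance (a sketch with example types; %input and %collapsed are
/// placeholder names), with firstDimToCollapse = 1 a memref<2x3x4xf32> input
/// roughly becomes
///   %collapsed = memref.collapse_shape %input [[0], [1, 2]]
///       : memref<2x3x4xf32> into memref<2x12xf32>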
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = input.getType().cast<ShapedType>();
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    arith::ConstantIndexOp cst =
        indices[i].getDefiningOp<arith::ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read reads a 1-D vector from the collapsed source.
/// Requires the source shape to be already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // Contiguity can only be checked on memrefs; bail out on tensor sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    int64_t firstContiguousInnerDim =
        getContiguousInnerDim(sourceType, vectorType.getNumElements());
    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
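    // The pattern currently only handles in-bounds, unmasked transfers with a
    // minor identity permutation map and zero indices in the dimensions being
    // collapsed, as checked below.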
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().dyn_cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write writes a 1-D vector to the collapsed source.
/// Requires the source shape to be already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // Contiguity can only be checked on memrefs; bail out on tensor sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    int64_t firstContiguousInnerDim =
        getContiguousInnerDim(sourceType, vectorType.getNumElements());
    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
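    // Same restrictions as in the read-flattening pattern above: in-bounds,
    // unmasked, minor identity map, and zero indices in the collapsed
    // dimensions.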
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}