//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region, or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one of them must be post-dominated by
/// all the others since they all post-dominate the same transfer_write. We
/// only consider the transfer_write post-dominated by all the other candidates
/// as this will be the first transfer_write executed after the potentially
/// dead transfer_write.
/// If we find such an overwriting transfer_write, we know the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
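///
/// Illustrative example (not from the original code; SSA names and types are
/// assumptions): the first transfer_write below is dead because the second
/// one overwrites the same <4xf32> slice of %buf and dominates the only read:
///
///   vector.transfer_write %v0, %buf[%c0, %c0]
///     : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0]
///     : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst
///     : memref<4x4xf32>, vector<4xf32>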
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.source().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one of them must be dominated by all
/// the others since they all dominate the same transfer_read. We only consider
/// the transfer_write dominated by all the other candidates as this will be
/// the last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
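///
/// Illustrative example (not from the original code; SSA names and types are
/// assumptions): the transfer_read below can be replaced by %v because the
/// dominating transfer_write stores the same <4xf32> slice and no aliasing
/// write can intervene:
///
///   vector.transfer_write %v, %buf[%c0, %c0]
///     : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst
///     : memref<4x4xf32>, vector<4xf32>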
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.source().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) &&
          checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.vector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      0, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims to drop.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}
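
// Illustrative example of the helper above (not from the original code; the
// types are assumptions): for an input of type memref<1x8x1x4xf32> it would
// produce roughly:
//
//   %0 = memref.subview %arg[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
//     : memref<1x8x1x4xf32> to memref<8x4xf32>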

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.permutation_map().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};
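
// Illustrative before/after for the pattern above (not from the original
// code; SSA names and types are assumptions):
//
//   %v = vector.transfer_read %src[%c0, %c0, %c0], %cst
//     : memref<1x1x4xf32>, vector<4xf32>
// becomes:
//   %0 = memref.subview %src[0, 0, 0] [1, 1, 4] [1, 1, 1]
//     : memref<1x1x4xf32> to memref<4xf32>
//   %v = vector.transfer_read %0[%c0], %cst : memref<4xf32>, vector<4xf32>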

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.permutation_map().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
  Value rankReducedInput =
      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
  ShapedType rankReducedInputType =
      rankReducedInput.getType().cast<ShapedType>();
  if (rankReducedInputType.getRank() == 1)
    return rankReducedInput;
  ReassociationIndices indices;
  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
    indices.push_back(i);
  return rewriter.create<memref::CollapseShapeOp>(
      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}
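
// Illustrative example of the helper above (not from the original code; the
// types are assumptions): a contiguous row-major memref<4x8xf32> would be
// collapsed with a single reassociation group, roughly:
//
//   %0 = memref.collapse_shape %arg [[0, 1]]
//     : memref<4x8xf32> into memref<32xf32>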

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.permutation_map().isMinorIdentity())
      return failure();
    if (transferReadOp.mask())
      return failure();
    if (llvm::any_of(transferReadOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value read1d = rewriter.create<vector::TransferReadOp>(
        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), read1d);
    return success();
  }
};
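
// Illustrative before/after for the pattern above (not from the original
// code; SSA names and types are assumptions):
//
//   %v = vector.transfer_read %src[%c0, %c0], %cst
//     : memref<4x8xf32>, vector<4x8xf32>
// becomes:
//   %0 = memref.collapse_shape %src [[0, 1]]
//     : memref<4x8xf32> into memref<32xf32>
//   %1 = vector.transfer_read %0[%c0], %cst : memref<32xf32>, vector<32xf32>
//   %v = vector.shape_cast %1 : vector<32xf32> to vector<4x8xf32>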

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.vector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.source();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.permutation_map().isMinorIdentity())
      return failure();
    if (transferWriteOp.mask())
      return failure();
    if (llvm::any_of(transferWriteOp.indices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value vector1d =
        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
                                             ValueRange{c0}, identityMap1D);
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}
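
// Illustrative usage sketch (not part of this file; the driver shown is an
// assumption): a pass that wants these rewrites could combine the pattern
// populators with the flow optimization roughly as follows:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
//   vector::transferOpflowOpt(rootOp);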