//===- VectorTransferOpTransforms.cpp - transfer op transforms -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor of `op` that lives directly in `region`, or nullptr if
/// `region` is not an ancestor of `op`.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For one transfer_write to fully overwrite another, the overwriting
/// transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider
/// the transfer_write dominated by all the other candidates as this will be
/// the last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
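///
/// As an illustrative sketch (the buffer, indices, types and padding value
/// below are made up for the example and are not taken from a test),
/// forwarding rewrites IR such as
///
///   vector.transfer_write %v, %buf[%c0, %c0]
///     : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true]}
///     : memref<4x4xf32>, vector<4xf32>
///
/// into a direct use of %v: all uses of %r are replaced by %v and the
/// transfer_read is erased, provided no other write to %buf may intervene
/// between the two transfers.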
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
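///
/// As a minimal sketch (illustrative types; the layout of the rank-reduced
/// result is omitted and assumed to canonicalize away), for a source with
/// unit dims the helper emits something like:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 8, 1] [1, 1, 1]
///     : memref<1x8x1xf32> to memref<8xf32>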
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
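///
/// Illustrative sketch (made-up types; layout and in_bounds attributes are
/// omitted for brevity, but the original write must have all dims in-bounds
/// and all-zero indices):
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///     : vector<3x2xf32>, memref<1x3x2xf32>
///
/// becomes a write through a rank-reduced subview:
///
///   %sv = memref.subview %dst[0, 0, 0] [1, 3, 2] [1, 1, 1]
///     : memref<1x3x2xf32> to memref<3x2xf32>
///   vector.transfer_write %v, %sv[%c0, %c0]
///     : vector<3x2xf32>, memref<3x2xf32>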
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
  Value rankReducedInput =
      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
  ShapedType rankReducedInputType =
      rankReducedInput.getType().cast<ShapedType>();
  if (rankReducedInputType.getRank() == 1)
    return rankReducedInput;
  ReassociationIndices indices;
  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
    indices.push_back(i);
  return rewriter.create<memref::CollapseShapeOp>(
      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail out on tensor
    // sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value read1d = rewriter.create<vector::TransferReadOp>(
        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), read1d);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail out on tensor
    // sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value vector1d =
        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
                                             ValueRange{c0}, identityMap1D);
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}
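
// A minimal usage sketch (illustrative, not part of this file's API): the
// pattern populators above are typically driven by the greedy rewriter from
// a pass, e.g. assuming `funcOp` is the root operation of interest:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));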