//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
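///
/// For example (an illustrative sketch, not taken from the tests), in:
///   vector.transfer_write %v0, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst
///       : memref<4x4xf32>, vector<4xf32>
/// the first transfer_write is dead: the second write accesses the same
/// indices with the same vector type, post-dominates it, and dominates the
/// only read, so the first write can be erased.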
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
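///
/// For example (an illustrative sketch, not taken from the tests), in:
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %cst {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
/// with no other write to %buf in between, all uses of %r can be replaced by
/// %v and the transfer_read erased.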
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      0, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns true if the value is an `arith.constant 0 : index`.
static bool isZero(Value v) {
  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
  return cst && cst.value() == 0;
}

/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
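///
/// For example (an illustrative sketch, not taken from the tests):
///   %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///       {in_bounds = [true, true]} : memref<1x1x4x8xf32>, vector<4x8xf32>
/// becomes:
///   %sv = memref.subview %src[0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1]
///       : memref<1x1x4x8xf32> to memref<4x8xf32>
///   %v = vector.transfer_read %sv[%c0, %c0], %cst
///       : memref<4x8xf32>, vector<4x8xf32>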
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Rewrites vector.transfer_write ops where the "source" (i.e. destination)
/// has unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor type.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
  Value rankReducedInput =
      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
  ShapedType rankReducedInputType =
      rankReducedInput.getType().cast<ShapedType>();
  if (rankReducedInputType.getRank() == 1)
    return rankReducedInput;
  ReassociationIndices indices;
  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
    indices.push_back(i);
  return rewriter.create<memref::CollapseShapeOp>(
      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
}

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
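///
/// For example (an illustrative sketch, not taken from the tests):
///   %v = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// becomes:
///   %collapsed = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %flat = vector.transfer_read %collapsed[%c0], %cst
///       : memref<32xf32>, vector<32xf32>
///   %v = vector.shape_cast %flat : vector<32xf32> to vector<4x8xf32>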
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value read1d = rewriter.create<vector::TransferReadOp>(
        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), read1d);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // The contiguity check below only applies to memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!isStaticShapeAndContiguousRowMajor(sourceType))
      return failure();
    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
      // This pattern requires the source to already be rank-reduced.
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
                                              sourceType.getElementType());
    Value source1d =
        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
    Value vector1d =
        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
                                             ValueRange{c0}, identityMap1D);
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

} // namespace
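
// A minimal usage sketch (hypothetical pass body, not part of this file): the
// dataflow optimization below is invoked directly on a root op, while the
// rewrite patterns are applied through the greedy driver, e.g.:
//
//   void runOnOperation() override {
//     vector::transferOpflowOpt(getOperation());
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//     vector::populateFlattenVectorTransferPatterns(patterns);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }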
void mlir::vector::transferOpflowOpt(Operation *rootOp) {
  TransferOptimization opt(rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}