199ef9eebSMatthias Springer //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements functions concerned with optimizing transfer_read and
1099ef9eebSMatthias Springer // transfer_write ops.
1199ef9eebSMatthias Springer //
1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
136a8ba318SRiver Riddle 
14eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1999ef9eebSMatthias Springer #include "mlir/IR/BuiltinOps.h"
2099ef9eebSMatthias Springer #include "mlir/IR/Dominance.h"
2199ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
2299ef9eebSMatthias Springer #include "llvm/ADT/StringRef.h"
2399ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
2499ef9eebSMatthias Springer 
2599ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-opt"
2699ef9eebSMatthias Springer 
2799ef9eebSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
2899ef9eebSMatthias Springer 
2999ef9eebSMatthias Springer using namespace mlir;
3099ef9eebSMatthias Springer 
3199ef9eebSMatthias Springer /// Return the ancestor op in the region or nullptr if the region is not
3299ef9eebSMatthias Springer /// an ancestor of the op.
findAncestorOpInRegion(Region * region,Operation * op)3399ef9eebSMatthias Springer static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
3499ef9eebSMatthias Springer   for (; op != nullptr && op->getParentRegion() != region;
3599ef9eebSMatthias Springer        op = op->getParentOp())
3699ef9eebSMatthias Springer     ;
3799ef9eebSMatthias Springer   return op;
3899ef9eebSMatthias Springer }
3999ef9eebSMatthias Springer 
4099ef9eebSMatthias Springer namespace {
4199ef9eebSMatthias Springer 
4299ef9eebSMatthias Springer class TransferOptimization {
4399ef9eebSMatthias Springer public:
TransferOptimization(Operation * op)44171850c5SRiver Riddle   TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
4599ef9eebSMatthias Springer   void deadStoreOp(vector::TransferWriteOp);
4699ef9eebSMatthias Springer   void storeToLoadForwarding(vector::TransferReadOp);
removeDeadOp()4799ef9eebSMatthias Springer   void removeDeadOp() {
4899ef9eebSMatthias Springer     for (Operation *op : opToErase)
4999ef9eebSMatthias Springer       op->erase();
5099ef9eebSMatthias Springer     opToErase.clear();
5199ef9eebSMatthias Springer   }
5299ef9eebSMatthias Springer 
5399ef9eebSMatthias Springer private:
5499ef9eebSMatthias Springer   bool isReachable(Operation *start, Operation *dest);
5599ef9eebSMatthias Springer   DominanceInfo dominators;
5699ef9eebSMatthias Springer   PostDominanceInfo postDominators;
5799ef9eebSMatthias Springer   std::vector<Operation *> opToErase;
5899ef9eebSMatthias Springer };
5999ef9eebSMatthias Springer 
6099ef9eebSMatthias Springer /// Return true if there is a path from start operation to dest operation,
6199ef9eebSMatthias Springer /// otherwise return false. The operations have to be in the same region.
isReachable(Operation * start,Operation * dest)6299ef9eebSMatthias Springer bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
6399ef9eebSMatthias Springer   assert(start->getParentRegion() == dest->getParentRegion() &&
6499ef9eebSMatthias Springer          "This function only works for ops i the same region");
6599ef9eebSMatthias Springer   // Simple case where the start op dominate the destination.
6699ef9eebSMatthias Springer   if (dominators.dominates(start, dest))
6799ef9eebSMatthias Springer     return true;
6899ef9eebSMatthias Springer   Block *startBlock = start->getBlock();
6999ef9eebSMatthias Springer   Block *destBlock = dest->getBlock();
7099ef9eebSMatthias Springer   SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
7199ef9eebSMatthias Springer                                     startBlock->succ_end());
7299ef9eebSMatthias Springer   SmallPtrSet<Block *, 32> visited;
7399ef9eebSMatthias Springer   while (!worklist.empty()) {
7499ef9eebSMatthias Springer     Block *bb = worklist.pop_back_val();
7599ef9eebSMatthias Springer     if (!visited.insert(bb).second)
7699ef9eebSMatthias Springer       continue;
7799ef9eebSMatthias Springer     if (dominators.dominates(bb, destBlock))
7899ef9eebSMatthias Springer       return true;
7999ef9eebSMatthias Springer     worklist.append(bb->succ_begin(), bb->succ_end());
8099ef9eebSMatthias Springer   }
8199ef9eebSMatthias Springer   return false;
8299ef9eebSMatthias Springer }
8399ef9eebSMatthias Springer 
8499ef9eebSMatthias Springer /// For transfer_write to overwrite fully another transfer_write must:
8599ef9eebSMatthias Springer /// 1. Access the same memref with the same indices and vector type.
8699ef9eebSMatthias Springer /// 2. Post-dominate the other transfer_write operation.
8799ef9eebSMatthias Springer /// If several candidates are available, one must be post-dominated by all the
8899ef9eebSMatthias Springer /// others since they are all post-dominating the same transfer_write. We only
8999ef9eebSMatthias Springer /// consider the transfer_write post-dominated by all the other candidates as
9099ef9eebSMatthias Springer /// this will be the first transfer_write executed after the potentially dead
9199ef9eebSMatthias Springer /// transfer_write.
9299ef9eebSMatthias Springer /// If we found such an overwriting transfer_write we know that the original
9399ef9eebSMatthias Springer /// transfer_write is dead if all reads that can be reached from the potentially
9499ef9eebSMatthias Springer /// dead transfer_write are dominated by the overwriting transfer_write.
deadStoreOp(vector::TransferWriteOp write)9599ef9eebSMatthias Springer void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
9699ef9eebSMatthias Springer   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
9799ef9eebSMatthias Springer                     << "\n");
9899ef9eebSMatthias Springer   llvm::SmallVector<Operation *, 8> reads;
9999ef9eebSMatthias Springer   Operation *firstOverwriteCandidate = nullptr;
1007c38fd60SJacques Pienaar   for (auto *user : write.getSource().getUsers()) {
10199ef9eebSMatthias Springer     if (user == write.getOperation())
10299ef9eebSMatthias Springer       continue;
10399ef9eebSMatthias Springer     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
10499ef9eebSMatthias Springer       // Check candidate that can override the store.
10599ef9eebSMatthias Springer       if (checkSameValueWAW(nextWrite, write) &&
10699ef9eebSMatthias Springer           postDominators.postDominates(nextWrite, write)) {
10799ef9eebSMatthias Springer         if (firstOverwriteCandidate == nullptr ||
10899ef9eebSMatthias Springer             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
10999ef9eebSMatthias Springer           firstOverwriteCandidate = nextWrite;
11099ef9eebSMatthias Springer         else
11199ef9eebSMatthias Springer           assert(
11299ef9eebSMatthias Springer               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
11399ef9eebSMatthias Springer       }
11499ef9eebSMatthias Springer     } else {
11599ef9eebSMatthias Springer       if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
11699ef9eebSMatthias Springer         // Don't need to consider disjoint reads.
11799ef9eebSMatthias Springer         if (vector::isDisjointTransferSet(
11899ef9eebSMatthias Springer                 cast<VectorTransferOpInterface>(write.getOperation()),
11999ef9eebSMatthias Springer                 cast<VectorTransferOpInterface>(read.getOperation())))
12099ef9eebSMatthias Springer           continue;
12199ef9eebSMatthias Springer       }
12299ef9eebSMatthias Springer       reads.push_back(user);
12399ef9eebSMatthias Springer     }
12499ef9eebSMatthias Springer   }
12599ef9eebSMatthias Springer   if (firstOverwriteCandidate == nullptr)
12699ef9eebSMatthias Springer     return;
12799ef9eebSMatthias Springer   Region *topRegion = firstOverwriteCandidate->getParentRegion();
12899ef9eebSMatthias Springer   Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
12999ef9eebSMatthias Springer   assert(writeAncestor &&
13099ef9eebSMatthias Springer          "write op should be recursively part of the top region");
13199ef9eebSMatthias Springer 
13299ef9eebSMatthias Springer   for (Operation *read : reads) {
13399ef9eebSMatthias Springer     Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
13499ef9eebSMatthias Springer     // TODO: if the read and write have the same ancestor we could recurse in
13599ef9eebSMatthias Springer     // the region to know if the read is reachable with more precision.
13699ef9eebSMatthias Springer     if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
13799ef9eebSMatthias Springer       continue;
13899ef9eebSMatthias Springer     if (!dominators.dominates(firstOverwriteCandidate, read)) {
13999ef9eebSMatthias Springer       LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
14099ef9eebSMatthias Springer                         << "\n");
14199ef9eebSMatthias Springer       return;
14299ef9eebSMatthias Springer     }
14399ef9eebSMatthias Springer   }
14499ef9eebSMatthias Springer   LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
14599ef9eebSMatthias Springer                     << " overwritten by: " << *firstOverwriteCandidate << "\n");
14699ef9eebSMatthias Springer   opToErase.push_back(write.getOperation());
14799ef9eebSMatthias Springer }
14899ef9eebSMatthias Springer 
14999ef9eebSMatthias Springer /// A transfer_write candidate to storeToLoad forwarding must:
15099ef9eebSMatthias Springer /// 1. Access the same memref with the same indices and vector type as the
15199ef9eebSMatthias Springer /// transfer_read.
15299ef9eebSMatthias Springer /// 2. Dominate the transfer_read operation.
15399ef9eebSMatthias Springer /// If several candidates are available, one must be dominated by all the others
15499ef9eebSMatthias Springer /// since they are all dominating the same transfer_read. We only consider the
15599ef9eebSMatthias Springer /// transfer_write dominated by all the other candidates as this will be the
15699ef9eebSMatthias Springer /// last transfer_write executed before the transfer_read.
15799ef9eebSMatthias Springer /// If we found such a candidate we can do the forwarding if all the other
15899ef9eebSMatthias Springer /// potentially aliasing ops that may reach the transfer_read are post-dominated
15999ef9eebSMatthias Springer /// by the transfer_write.
storeToLoadForwarding(vector::TransferReadOp read)16099ef9eebSMatthias Springer void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
16199ef9eebSMatthias Springer   if (read.hasOutOfBoundsDim())
16299ef9eebSMatthias Springer     return;
16399ef9eebSMatthias Springer   LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
16499ef9eebSMatthias Springer                     << "\n");
16599ef9eebSMatthias Springer   SmallVector<Operation *, 8> blockingWrites;
16699ef9eebSMatthias Springer   vector::TransferWriteOp lastwrite = nullptr;
1677c38fd60SJacques Pienaar   for (Operation *user : read.getSource().getUsers()) {
16899ef9eebSMatthias Springer     if (isa<vector::TransferReadOp>(user))
16999ef9eebSMatthias Springer       continue;
17099ef9eebSMatthias Springer     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
17199ef9eebSMatthias Springer       // If there is a write, but we can prove that it is disjoint we can ignore
17299ef9eebSMatthias Springer       // the write.
17399ef9eebSMatthias Springer       if (vector::isDisjointTransferSet(
17499ef9eebSMatthias Springer               cast<VectorTransferOpInterface>(write.getOperation()),
17599ef9eebSMatthias Springer               cast<VectorTransferOpInterface>(read.getOperation())))
17699ef9eebSMatthias Springer         continue;
17799ef9eebSMatthias Springer       if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
17899ef9eebSMatthias Springer         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
17999ef9eebSMatthias Springer           lastwrite = write;
18099ef9eebSMatthias Springer         else
18199ef9eebSMatthias Springer           assert(dominators.dominates(write, lastwrite));
18299ef9eebSMatthias Springer         continue;
18399ef9eebSMatthias Springer       }
18499ef9eebSMatthias Springer     }
18599ef9eebSMatthias Springer     blockingWrites.push_back(user);
18699ef9eebSMatthias Springer   }
18799ef9eebSMatthias Springer 
18899ef9eebSMatthias Springer   if (lastwrite == nullptr)
18999ef9eebSMatthias Springer     return;
19099ef9eebSMatthias Springer 
19199ef9eebSMatthias Springer   Region *topRegion = lastwrite->getParentRegion();
19299ef9eebSMatthias Springer   Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
19399ef9eebSMatthias Springer   assert(readAncestor &&
19499ef9eebSMatthias Springer          "read op should be recursively part of the top region");
19599ef9eebSMatthias Springer 
19699ef9eebSMatthias Springer   for (Operation *write : blockingWrites) {
19799ef9eebSMatthias Springer     Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
19899ef9eebSMatthias Springer     // TODO: if the store and read have the same ancestor we could recurse in
19999ef9eebSMatthias Springer     // the region to know if the read is reachable with more precision.
20099ef9eebSMatthias Springer     if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
20199ef9eebSMatthias Springer       continue;
20299ef9eebSMatthias Springer     if (!postDominators.postDominates(lastwrite, write)) {
20399ef9eebSMatthias Springer       LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
20499ef9eebSMatthias Springer                         << *write << "\n");
20599ef9eebSMatthias Springer       return;
20699ef9eebSMatthias Springer     }
20799ef9eebSMatthias Springer   }
20899ef9eebSMatthias Springer 
20999ef9eebSMatthias Springer   LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
21099ef9eebSMatthias Springer                     << " to: " << *read.getOperation() << "\n");
2117c38fd60SJacques Pienaar   read.replaceAllUsesWith(lastwrite.getVector());
21299ef9eebSMatthias Springer   opToErase.push_back(read.getOperation());
21399ef9eebSMatthias Springer }
21499ef9eebSMatthias Springer 
21599ef9eebSMatthias Springer /// Drops unit dimensions from the input MemRefType.
dropUnitDims(MemRefType inputType,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)21699ef9eebSMatthias Springer static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
21799ef9eebSMatthias Springer                                ArrayRef<int64_t> sizes,
21899ef9eebSMatthias Springer                                ArrayRef<int64_t> strides) {
2196c3c5f80SMatthias Springer   SmallVector<int64_t> targetShape = llvm::to_vector(
2206c3c5f80SMatthias Springer       llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
22199ef9eebSMatthias Springer   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
2226c3c5f80SMatthias Springer       targetShape, inputType, offsets, sizes, strides);
22399ef9eebSMatthias Springer   return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
22499ef9eebSMatthias Springer }
22599ef9eebSMatthias Springer 
22699ef9eebSMatthias Springer /// Creates a rank-reducing memref.subview op that drops unit dims from its
22799ef9eebSMatthias Springer /// input. Or just returns the input if it was already without unit dims.
rankReducingSubviewDroppingUnitDims(PatternRewriter & rewriter,mlir::Location loc,Value input)22899ef9eebSMatthias Springer static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
22999ef9eebSMatthias Springer                                                  mlir::Location loc,
23099ef9eebSMatthias Springer                                                  Value input) {
23199ef9eebSMatthias Springer   MemRefType inputType = input.getType().cast<MemRefType>();
23299ef9eebSMatthias Springer   assert(inputType.hasStaticShape());
23399ef9eebSMatthias Springer   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
23499ef9eebSMatthias Springer   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
23599ef9eebSMatthias Springer   ArrayRef<int64_t> subViewSizes = inputType.getShape();
23699ef9eebSMatthias Springer   MemRefType resultType =
23799ef9eebSMatthias Springer       dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
23899ef9eebSMatthias Springer   if (canonicalizeStridedLayout(resultType) ==
23999ef9eebSMatthias Springer       canonicalizeStridedLayout(inputType))
24099ef9eebSMatthias Springer     return input;
24199ef9eebSMatthias Springer   return rewriter.create<memref::SubViewOp>(
24299ef9eebSMatthias Springer       loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
24399ef9eebSMatthias Springer }
24499ef9eebSMatthias Springer 
24599ef9eebSMatthias Springer /// Returns the number of dims that aren't unit dims.
getReducedRank(ArrayRef<int64_t> shape)24699ef9eebSMatthias Springer static int getReducedRank(ArrayRef<int64_t> shape) {
24799ef9eebSMatthias Springer   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
24899ef9eebSMatthias Springer }
24999ef9eebSMatthias Springer 
25099ef9eebSMatthias Springer /// Returns true if all values are `arith.constant 0 : index`
isZero(Value v)25199ef9eebSMatthias Springer static bool isZero(Value v) {
25299ef9eebSMatthias Springer   auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
25399ef9eebSMatthias Springer   return cst && cst.value() == 0;
25499ef9eebSMatthias Springer }
25599ef9eebSMatthias Springer 
25699ef9eebSMatthias Springer /// Rewrites vector.transfer_read ops where the source has unit dims, by
25799ef9eebSMatthias Springer /// inserting a memref.subview dropping those unit dims.
25899ef9eebSMatthias Springer class TransferReadDropUnitDimsPattern
25999ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
26099ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
26199ef9eebSMatthias Springer 
matchAndRewrite(vector::TransferReadOp transferReadOp,PatternRewriter & rewriter) const26299ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
26399ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
26499ef9eebSMatthias Springer     auto loc = transferReadOp.getLoc();
2657c38fd60SJacques Pienaar     Value vector = transferReadOp.getVector();
26699ef9eebSMatthias Springer     VectorType vectorType = vector.getType().cast<VectorType>();
2677c38fd60SJacques Pienaar     Value source = transferReadOp.getSource();
26899ef9eebSMatthias Springer     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
26999ef9eebSMatthias Springer     // TODO: support tensor types.
27099ef9eebSMatthias Springer     if (!sourceType || !sourceType.hasStaticShape())
27199ef9eebSMatthias Springer       return failure();
27299ef9eebSMatthias Springer     if (sourceType.getNumElements() != vectorType.getNumElements())
27399ef9eebSMatthias Springer       return failure();
27499ef9eebSMatthias Springer     // TODO: generalize this pattern, relax the requirements here.
27599ef9eebSMatthias Springer     if (transferReadOp.hasOutOfBoundsDim())
27699ef9eebSMatthias Springer       return failure();
2777c38fd60SJacques Pienaar     if (!transferReadOp.getPermutationMap().isMinorIdentity())
27899ef9eebSMatthias Springer       return failure();
27999ef9eebSMatthias Springer     int reducedRank = getReducedRank(sourceType.getShape());
28099ef9eebSMatthias Springer     if (reducedRank == sourceType.getRank())
28199ef9eebSMatthias Springer       return failure(); // The source shape can't be further reduced.
28299ef9eebSMatthias Springer     if (reducedRank != vectorType.getRank())
28399ef9eebSMatthias Springer       return failure(); // This pattern requires the vector shape to match the
28499ef9eebSMatthias Springer                         // reduced source shape.
2857c38fd60SJacques Pienaar     if (llvm::any_of(transferReadOp.getIndices(),
28699ef9eebSMatthias Springer                      [](Value v) { return !isZero(v); }))
28799ef9eebSMatthias Springer       return failure();
28899ef9eebSMatthias Springer     Value reducedShapeSource =
28999ef9eebSMatthias Springer         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
29099ef9eebSMatthias Springer     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
29199ef9eebSMatthias Springer     SmallVector<Value> zeros(reducedRank, c0);
29299ef9eebSMatthias Springer     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
29399ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
29499ef9eebSMatthias Springer         transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
29599ef9eebSMatthias Springer     return success();
29699ef9eebSMatthias Springer   }
29799ef9eebSMatthias Springer };
29899ef9eebSMatthias Springer 
29999ef9eebSMatthias Springer /// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
30099ef9eebSMatthias Springer /// unit dims, by inserting a memref.subview dropping those unit dims.
30199ef9eebSMatthias Springer class TransferWriteDropUnitDimsPattern
30299ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
30399ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
30499ef9eebSMatthias Springer 
matchAndRewrite(vector::TransferWriteOp transferWriteOp,PatternRewriter & rewriter) const30599ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
30699ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
30799ef9eebSMatthias Springer     auto loc = transferWriteOp.getLoc();
3087c38fd60SJacques Pienaar     Value vector = transferWriteOp.getVector();
30999ef9eebSMatthias Springer     VectorType vectorType = vector.getType().cast<VectorType>();
3107c38fd60SJacques Pienaar     Value source = transferWriteOp.getSource();
31199ef9eebSMatthias Springer     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
31299ef9eebSMatthias Springer     // TODO: support tensor type.
31399ef9eebSMatthias Springer     if (!sourceType || !sourceType.hasStaticShape())
31499ef9eebSMatthias Springer       return failure();
31599ef9eebSMatthias Springer     if (sourceType.getNumElements() != vectorType.getNumElements())
31699ef9eebSMatthias Springer       return failure();
31799ef9eebSMatthias Springer     // TODO: generalize this pattern, relax the requirements here.
31899ef9eebSMatthias Springer     if (transferWriteOp.hasOutOfBoundsDim())
31999ef9eebSMatthias Springer       return failure();
3207c38fd60SJacques Pienaar     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
32199ef9eebSMatthias Springer       return failure();
32299ef9eebSMatthias Springer     int reducedRank = getReducedRank(sourceType.getShape());
32399ef9eebSMatthias Springer     if (reducedRank == sourceType.getRank())
32499ef9eebSMatthias Springer       return failure(); // The source shape can't be further reduced.
32599ef9eebSMatthias Springer     if (reducedRank != vectorType.getRank())
32699ef9eebSMatthias Springer       return failure(); // This pattern requires the vector shape to match the
32799ef9eebSMatthias Springer                         // reduced source shape.
3287c38fd60SJacques Pienaar     if (llvm::any_of(transferWriteOp.getIndices(),
32999ef9eebSMatthias Springer                      [](Value v) { return !isZero(v); }))
33099ef9eebSMatthias Springer       return failure();
33199ef9eebSMatthias Springer     Value reducedShapeSource =
33299ef9eebSMatthias Springer         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
33399ef9eebSMatthias Springer     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
33499ef9eebSMatthias Springer     SmallVector<Value> zeros(reducedRank, c0);
33599ef9eebSMatthias Springer     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
33699ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
33799ef9eebSMatthias Springer         transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
33899ef9eebSMatthias Springer     return success();
33999ef9eebSMatthias Springer   }
34099ef9eebSMatthias Springer };
34199ef9eebSMatthias Springer 
342*f4ac9509SBenoit Jacob /// Returns the position of the first inner dimension that has contiguous layout
343*f4ac9509SBenoit Jacob /// with at least `requiredContiguousSize` contiguous elements.
344*f4ac9509SBenoit Jacob /// When such a dimension is found, the return value satisfies:
345*f4ac9509SBenoit Jacob ///   0 <= return_value <= memrefType.getRank() - 1.
346*f4ac9509SBenoit Jacob /// When no such dimension is found, the return value is memrefType.getRank().
getContiguousInnerDim(MemRefType memrefType,int64_t requiredContiguousSize)347*f4ac9509SBenoit Jacob static int64_t getContiguousInnerDim(MemRefType memrefType,
348*f4ac9509SBenoit Jacob                                      int64_t requiredContiguousSize) {
349*f4ac9509SBenoit Jacob   auto shape = memrefType.getShape();
350*f4ac9509SBenoit Jacob   SmallVector<int64_t> strides;
351*f4ac9509SBenoit Jacob   int64_t offset;
352*f4ac9509SBenoit Jacob   int64_t innerDim = shape.size();
353*f4ac9509SBenoit Jacob   if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
354*f4ac9509SBenoit Jacob     int64_t innerSize = 1;
355*f4ac9509SBenoit Jacob     while (true) {
356*f4ac9509SBenoit Jacob       if (innerDim == 0)
357*f4ac9509SBenoit Jacob         break;
358*f4ac9509SBenoit Jacob       const int64_t nextDim = innerDim - 1;
359*f4ac9509SBenoit Jacob       if (shape[nextDim] == ShapedType::kDynamicSize)
360*f4ac9509SBenoit Jacob         break;
361*f4ac9509SBenoit Jacob       if (strides[nextDim] != innerSize)
362*f4ac9509SBenoit Jacob         break;
363*f4ac9509SBenoit Jacob       innerSize *= shape[nextDim];
364*f4ac9509SBenoit Jacob       innerDim = nextDim;
365*f4ac9509SBenoit Jacob       if (innerSize >= requiredContiguousSize)
366*f4ac9509SBenoit Jacob         break;
367*f4ac9509SBenoit Jacob     }
368*f4ac9509SBenoit Jacob   }
369*f4ac9509SBenoit Jacob   return innerDim;
370*f4ac9509SBenoit Jacob }
371*f4ac9509SBenoit Jacob 
372*f4ac9509SBenoit Jacob /// Creates a memref.collapse_shape collapsing all inner dimensions of the
373*f4ac9509SBenoit Jacob /// input starting at `firstDimToCollapse`.
collapseInnerDims(PatternRewriter & rewriter,mlir::Location loc,Value input,int64_t firstDimToCollapse)374*f4ac9509SBenoit Jacob static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
375*f4ac9509SBenoit Jacob                                Value input, int64_t firstDimToCollapse) {
376*f4ac9509SBenoit Jacob   ShapedType inputType = input.getType().cast<ShapedType>();
377*f4ac9509SBenoit Jacob   if (inputType.getRank() == 1)
378*f4ac9509SBenoit Jacob     return input;
379*f4ac9509SBenoit Jacob   SmallVector<ReassociationIndices> reassociation;
380*f4ac9509SBenoit Jacob   for (int64_t i = 0; i < firstDimToCollapse; ++i)
381*f4ac9509SBenoit Jacob     reassociation.push_back(ReassociationIndices{i});
382*f4ac9509SBenoit Jacob   ReassociationIndices collapsedIndices;
383*f4ac9509SBenoit Jacob   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
384*f4ac9509SBenoit Jacob     collapsedIndices.push_back(i);
385*f4ac9509SBenoit Jacob   reassociation.push_back(collapsedIndices);
386*f4ac9509SBenoit Jacob   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
387*f4ac9509SBenoit Jacob }
388*f4ac9509SBenoit Jacob 
389*f4ac9509SBenoit Jacob /// Checks that the indices corresponding to dimensions starting at
390*f4ac9509SBenoit Jacob /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
391*f4ac9509SBenoit Jacob /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
392*f4ac9509SBenoit Jacob static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices,int64_t firstDimToCollapse,SmallVector<Value> & outIndices)393*f4ac9509SBenoit Jacob checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
394*f4ac9509SBenoit Jacob                                  SmallVector<Value> &outIndices) {
395*f4ac9509SBenoit Jacob   int64_t rank = indices.size();
396*f4ac9509SBenoit Jacob   if (firstDimToCollapse >= rank)
397*f4ac9509SBenoit Jacob     return failure();
398*f4ac9509SBenoit Jacob   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
399*f4ac9509SBenoit Jacob     arith::ConstantIndexOp cst =
400*f4ac9509SBenoit Jacob         indices[i].getDefiningOp<arith::ConstantIndexOp>();
401*f4ac9509SBenoit Jacob     if (!cst || cst.value() != 0)
402*f4ac9509SBenoit Jacob       return failure();
403*f4ac9509SBenoit Jacob   }
404*f4ac9509SBenoit Jacob   outIndices = indices;
405*f4ac9509SBenoit Jacob   outIndices.resize(firstDimToCollapse + 1);
406*f4ac9509SBenoit Jacob   return success();
40799ef9eebSMatthias Springer }
40899ef9eebSMatthias Springer 
40999ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_read ops by inserting
41099ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting
41199ef9eebSMatthias Springer /// vector.transfer_read has a 1D source. Requires the source shape to be
41299ef9eebSMatthias Springer /// already reduced i.e. without unit dims.
41399ef9eebSMatthias Springer class FlattenContiguousRowMajorTransferReadPattern
41499ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
41599ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
41699ef9eebSMatthias Springer 
matchAndRewrite(vector::TransferReadOp transferReadOp,PatternRewriter & rewriter) const41799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
41899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
41999ef9eebSMatthias Springer     auto loc = transferReadOp.getLoc();
4207c38fd60SJacques Pienaar     Value vector = transferReadOp.getVector();
42199ef9eebSMatthias Springer     VectorType vectorType = vector.getType().cast<VectorType>();
4227c38fd60SJacques Pienaar     Value source = transferReadOp.getSource();
42399ef9eebSMatthias Springer     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
42499ef9eebSMatthias Springer     // Contiguity check is valid on tensors only.
42599ef9eebSMatthias Springer     if (!sourceType)
42699ef9eebSMatthias Springer       return failure();
4274a876b13Sharsh     if (vectorType.getRank() <= 1)
4284a876b13Sharsh       // Already 0D/1D, nothing to do.
42999ef9eebSMatthias Springer       return failure();
430*f4ac9509SBenoit Jacob     int64_t firstContiguousInnerDim =
431*f4ac9509SBenoit Jacob         getContiguousInnerDim(sourceType, vectorType.getNumElements());
432*f4ac9509SBenoit Jacob     if (firstContiguousInnerDim >= sourceType.getRank() - 1)
43399ef9eebSMatthias Springer       return failure();
43499ef9eebSMatthias Springer     // TODO: generalize this pattern, relax the requirements here.
43599ef9eebSMatthias Springer     if (transferReadOp.hasOutOfBoundsDim())
43699ef9eebSMatthias Springer       return failure();
4377c38fd60SJacques Pienaar     if (!transferReadOp.getPermutationMap().isMinorIdentity())
43899ef9eebSMatthias Springer       return failure();
4397c38fd60SJacques Pienaar     if (transferReadOp.getMask())
44099ef9eebSMatthias Springer       return failure();
441*f4ac9509SBenoit Jacob     SmallVector<Value> collapsedIndices;
442*f4ac9509SBenoit Jacob     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
443*f4ac9509SBenoit Jacob                                                 firstContiguousInnerDim,
444*f4ac9509SBenoit Jacob                                                 collapsedIndices)))
44599ef9eebSMatthias Springer       return failure();
446*f4ac9509SBenoit Jacob     Value collapsedSource =
447*f4ac9509SBenoit Jacob         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
448*f4ac9509SBenoit Jacob     MemRefType collapsedSourceType =
449*f4ac9509SBenoit Jacob         collapsedSource.getType().dyn_cast<MemRefType>();
450*f4ac9509SBenoit Jacob     int64_t collapsedRank = collapsedSourceType.getRank();
451*f4ac9509SBenoit Jacob     assert(collapsedRank == firstContiguousInnerDim + 1);
452*f4ac9509SBenoit Jacob     SmallVector<AffineExpr, 1> dimExprs{
453*f4ac9509SBenoit Jacob         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
454*f4ac9509SBenoit Jacob     auto collapsedMap =
455*f4ac9509SBenoit Jacob         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
456*f4ac9509SBenoit Jacob     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
457*f4ac9509SBenoit Jacob                                                 vectorType.getElementType());
458*f4ac9509SBenoit Jacob     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
459*f4ac9509SBenoit Jacob         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
460*f4ac9509SBenoit Jacob     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
46199ef9eebSMatthias Springer     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
462*f4ac9509SBenoit Jacob         transferReadOp, vector.getType().cast<VectorType>(), flatRead);
46399ef9eebSMatthias Springer     return success();
46499ef9eebSMatthias Springer   }
46599ef9eebSMatthias Springer };
46699ef9eebSMatthias Springer 
46799ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_write ops by inserting
46899ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting
46999ef9eebSMatthias Springer /// vector.transfer_write has a 1D source. Requires the source shape to be
47099ef9eebSMatthias Springer /// already reduced i.e. without unit dims.
47199ef9eebSMatthias Springer class FlattenContiguousRowMajorTransferWritePattern
47299ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
47399ef9eebSMatthias Springer   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
47499ef9eebSMatthias Springer 
matchAndRewrite(vector::TransferWriteOp transferWriteOp,PatternRewriter & rewriter) const47599ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
47699ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
47799ef9eebSMatthias Springer     auto loc = transferWriteOp.getLoc();
4787c38fd60SJacques Pienaar     Value vector = transferWriteOp.getVector();
47999ef9eebSMatthias Springer     VectorType vectorType = vector.getType().cast<VectorType>();
4807c38fd60SJacques Pienaar     Value source = transferWriteOp.getSource();
48199ef9eebSMatthias Springer     MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
48299ef9eebSMatthias Springer     // Contiguity check is valid on tensors only.
48399ef9eebSMatthias Springer     if (!sourceType)
48499ef9eebSMatthias Springer       return failure();
4854a876b13Sharsh     if (vectorType.getRank() <= 1)
4864a876b13Sharsh       // Already 0D/1D, nothing to do.
48799ef9eebSMatthias Springer       return failure();
488*f4ac9509SBenoit Jacob     int64_t firstContiguousInnerDim =
489*f4ac9509SBenoit Jacob         getContiguousInnerDim(sourceType, vectorType.getNumElements());
490*f4ac9509SBenoit Jacob     if (firstContiguousInnerDim >= sourceType.getRank() - 1)
49199ef9eebSMatthias Springer       return failure();
49299ef9eebSMatthias Springer     // TODO: generalize this pattern, relax the requirements here.
49399ef9eebSMatthias Springer     if (transferWriteOp.hasOutOfBoundsDim())
49499ef9eebSMatthias Springer       return failure();
4957c38fd60SJacques Pienaar     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
49699ef9eebSMatthias Springer       return failure();
4977c38fd60SJacques Pienaar     if (transferWriteOp.getMask())
49899ef9eebSMatthias Springer       return failure();
499*f4ac9509SBenoit Jacob     SmallVector<Value> collapsedIndices;
500*f4ac9509SBenoit Jacob     if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
501*f4ac9509SBenoit Jacob                                                 firstContiguousInnerDim,
502*f4ac9509SBenoit Jacob                                                 collapsedIndices)))
50399ef9eebSMatthias Springer       return failure();
504*f4ac9509SBenoit Jacob     Value collapsedSource =
505*f4ac9509SBenoit Jacob         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
506*f4ac9509SBenoit Jacob     MemRefType collapsedSourceType =
507*f4ac9509SBenoit Jacob         collapsedSource.getType().cast<MemRefType>();
508*f4ac9509SBenoit Jacob     int64_t collapsedRank = collapsedSourceType.getRank();
509*f4ac9509SBenoit Jacob     assert(collapsedRank == firstContiguousInnerDim + 1);
510*f4ac9509SBenoit Jacob     SmallVector<AffineExpr, 1> dimExprs{
511*f4ac9509SBenoit Jacob         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
512*f4ac9509SBenoit Jacob     auto collapsedMap =
513*f4ac9509SBenoit Jacob         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
514*f4ac9509SBenoit Jacob     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
515*f4ac9509SBenoit Jacob                                                 vectorType.getElementType());
516*f4ac9509SBenoit Jacob     Value flatVector =
517*f4ac9509SBenoit Jacob         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
518*f4ac9509SBenoit Jacob     vector::TransferWriteOp flatWrite =
519*f4ac9509SBenoit Jacob         rewriter.create<vector::TransferWriteOp>(
520*f4ac9509SBenoit Jacob             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
521*f4ac9509SBenoit Jacob     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
52299ef9eebSMatthias Springer     rewriter.eraseOp(transferWriteOp);
52399ef9eebSMatthias Springer     return success();
52499ef9eebSMatthias Springer   }
52599ef9eebSMatthias Springer };
52699ef9eebSMatthias Springer 
52799ef9eebSMatthias Springer } // namespace
52899ef9eebSMatthias Springer 
transferOpflowOpt(Operation * rootOp)529171850c5SRiver Riddle void mlir::vector::transferOpflowOpt(Operation *rootOp) {
530171850c5SRiver Riddle   TransferOptimization opt(rootOp);
53199ef9eebSMatthias Springer   // Run store to load forwarding first since it can expose more dead store
53299ef9eebSMatthias Springer   // opportunity.
533171850c5SRiver Riddle   rootOp->walk([&](vector::TransferReadOp read) {
53499ef9eebSMatthias Springer     if (read.getShapedType().isa<MemRefType>())
53599ef9eebSMatthias Springer       opt.storeToLoadForwarding(read);
53699ef9eebSMatthias Springer   });
53799ef9eebSMatthias Springer   opt.removeDeadOp();
538171850c5SRiver Riddle   rootOp->walk([&](vector::TransferWriteOp write) {
53999ef9eebSMatthias Springer     if (write.getShapedType().isa<MemRefType>())
54099ef9eebSMatthias Springer       opt.deadStoreOp(write);
54199ef9eebSMatthias Springer   });
54299ef9eebSMatthias Springer   opt.removeDeadOp();
54399ef9eebSMatthias Springer }
54499ef9eebSMatthias Springer 
populateVectorTransferDropUnitDimsPatterns(RewritePatternSet & patterns)54599ef9eebSMatthias Springer void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
54699ef9eebSMatthias Springer     RewritePatternSet &patterns) {
54799ef9eebSMatthias Springer   patterns
54899ef9eebSMatthias Springer       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
54999ef9eebSMatthias Springer           patterns.getContext());
55099ef9eebSMatthias Springer   populateShapeCastFoldingPatterns(patterns);
55199ef9eebSMatthias Springer }
55299ef9eebSMatthias Springer 
populateFlattenVectorTransferPatterns(RewritePatternSet & patterns)55399ef9eebSMatthias Springer void mlir::vector::populateFlattenVectorTransferPatterns(
55499ef9eebSMatthias Springer     RewritePatternSet &patterns) {
55599ef9eebSMatthias Springer   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
55699ef9eebSMatthias Springer                FlattenContiguousRowMajorTransferWritePattern>(
55799ef9eebSMatthias Springer       patterns.getContext());
55899ef9eebSMatthias Springer   populateShapeCastFoldingPatterns(patterns);
55999ef9eebSMatthias Springer }
560