199ef9eebSMatthias Springer //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements functions concerned with optimizing transfer_read and
1099ef9eebSMatthias Springer // transfer_write ops.
1199ef9eebSMatthias Springer //
1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
136a8ba318SRiver Riddle
14eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1999ef9eebSMatthias Springer #include "mlir/IR/BuiltinOps.h"
2099ef9eebSMatthias Springer #include "mlir/IR/Dominance.h"
2199ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
2299ef9eebSMatthias Springer #include "llvm/ADT/StringRef.h"
2399ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
2499ef9eebSMatthias Springer
2599ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-opt"
2699ef9eebSMatthias Springer
2799ef9eebSMatthias Springer #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
2899ef9eebSMatthias Springer
2999ef9eebSMatthias Springer using namespace mlir;
3099ef9eebSMatthias Springer
3199ef9eebSMatthias Springer /// Return the ancestor op in the region or nullptr if the region is not
3299ef9eebSMatthias Springer /// an ancestor of the op.
findAncestorOpInRegion(Region * region,Operation * op)3399ef9eebSMatthias Springer static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
3499ef9eebSMatthias Springer for (; op != nullptr && op->getParentRegion() != region;
3599ef9eebSMatthias Springer op = op->getParentOp())
3699ef9eebSMatthias Springer ;
3799ef9eebSMatthias Springer return op;
3899ef9eebSMatthias Springer }
3999ef9eebSMatthias Springer
4099ef9eebSMatthias Springer namespace {
4199ef9eebSMatthias Springer
/// Helper driving dead-store elimination and store-to-load forwarding for
/// vector transfer ops, using dominance/post-dominance info built once from
/// the root operation.
class TransferOptimization {
public:
  /// Builds dominance and post-dominance info rooted at `op`.
  TransferOptimization(Operation *op) : dominators(op), postDominators(op) {}
  /// Queues `write` for erasure if it is proven to be a dead store.
  void deadStoreOp(vector::TransferWriteOp);
  /// Replaces the read's uses with a forwarded written value when provably
  /// safe, and queues the read for erasure.
  void storeToLoadForwarding(vector::TransferReadOp);
  /// Erases all ops queued by the two transformations above and clears the
  /// queue. Erasure is deferred until after analysis.
  void removeDeadOp() {
    for (Operation *op : opToErase)
      op->erase();
    opToErase.clear();
  }

private:
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  // Ops queued for deferred erasure by removeDeadOp().
  std::vector<Operation *> opToErase;
};
5999ef9eebSMatthias Springer
6099ef9eebSMatthias Springer /// Return true if there is a path from start operation to dest operation,
6199ef9eebSMatthias Springer /// otherwise return false. The operations have to be in the same region.
isReachable(Operation * start,Operation * dest)6299ef9eebSMatthias Springer bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
6399ef9eebSMatthias Springer assert(start->getParentRegion() == dest->getParentRegion() &&
6499ef9eebSMatthias Springer "This function only works for ops i the same region");
6599ef9eebSMatthias Springer // Simple case where the start op dominate the destination.
6699ef9eebSMatthias Springer if (dominators.dominates(start, dest))
6799ef9eebSMatthias Springer return true;
6899ef9eebSMatthias Springer Block *startBlock = start->getBlock();
6999ef9eebSMatthias Springer Block *destBlock = dest->getBlock();
7099ef9eebSMatthias Springer SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
7199ef9eebSMatthias Springer startBlock->succ_end());
7299ef9eebSMatthias Springer SmallPtrSet<Block *, 32> visited;
7399ef9eebSMatthias Springer while (!worklist.empty()) {
7499ef9eebSMatthias Springer Block *bb = worklist.pop_back_val();
7599ef9eebSMatthias Springer if (!visited.insert(bb).second)
7699ef9eebSMatthias Springer continue;
7799ef9eebSMatthias Springer if (dominators.dominates(bb, destBlock))
7899ef9eebSMatthias Springer return true;
7999ef9eebSMatthias Springer worklist.append(bb->succ_begin(), bb->succ_end());
8099ef9eebSMatthias Springer }
8199ef9eebSMatthias Springer return false;
8299ef9eebSMatthias Springer }
8399ef9eebSMatthias Springer
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we found such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> reads;
  Operation *firstOverwriteCandidate = nullptr;
  // Classify every user of the written memref as either an overwrite
  // candidate or a potentially-aliasing access that must be checked.
  for (auto *user : write.getSource().getUsers()) {
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
      if (checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        // Keep the candidate post-dominated by all candidates seen so far,
        // i.e. the first overwriting write executed after `write`.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
      }
    } else {
      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
        // Don't need to consider disjoint reads.
        if (vector::isDisjointTransferSet(
                cast<VectorTransferOpInterface>(write.getOperation()),
                cast<VectorTransferOpInterface>(read.getOperation())))
          continue;
      }
      reads.push_back(user);
    }
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  // Compare ops at the level of the overwrite candidate's region by mapping
  // each access to its ancestor in that region.
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  // The store is dead only if every read reachable from it is dominated by
  // the overwriting write (and therefore never observes the old value).
  for (Operation *read : reads) {
    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
    // TODO: if the read and write have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, read)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
                        << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
14899ef9eebSMatthias Springer
/// A transfer_write candidate to storeToLoad forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we found such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  // Out-of-bounds reads may observe padding, not just the written value.
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  // Classify every user of the memref as either a forwarding candidate or a
  // potentially blocking (aliasing) access.
  for (Operation *user : read.getSource().getUsers()) {
    if (isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        // Keep the candidate dominated by all candidates seen so far, i.e.
        // the last matching write executed before the read.
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  // Compare ops at the level of the candidate write's region by mapping each
  // access to its ancestor in that region.
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  // Forwarding is valid only if no blocking write can execute between the
  // candidate write and the read: every blocking write that can reach the
  // read must be post-dominated by the candidate write.
  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
21499ef9eebSMatthias Springer
21599ef9eebSMatthias Springer /// Drops unit dimensions from the input MemRefType.
dropUnitDims(MemRefType inputType,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)21699ef9eebSMatthias Springer static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
21799ef9eebSMatthias Springer ArrayRef<int64_t> sizes,
21899ef9eebSMatthias Springer ArrayRef<int64_t> strides) {
2196c3c5f80SMatthias Springer SmallVector<int64_t> targetShape = llvm::to_vector(
2206c3c5f80SMatthias Springer llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
22199ef9eebSMatthias Springer Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
2226c3c5f80SMatthias Springer targetShape, inputType, offsets, sizes, strides);
22399ef9eebSMatthias Springer return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
22499ef9eebSMatthias Springer }
22599ef9eebSMatthias Springer
22699ef9eebSMatthias Springer /// Creates a rank-reducing memref.subview op that drops unit dims from its
22799ef9eebSMatthias Springer /// input. Or just returns the input if it was already without unit dims.
rankReducingSubviewDroppingUnitDims(PatternRewriter & rewriter,mlir::Location loc,Value input)22899ef9eebSMatthias Springer static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
22999ef9eebSMatthias Springer mlir::Location loc,
23099ef9eebSMatthias Springer Value input) {
23199ef9eebSMatthias Springer MemRefType inputType = input.getType().cast<MemRefType>();
23299ef9eebSMatthias Springer assert(inputType.hasStaticShape());
23399ef9eebSMatthias Springer SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
23499ef9eebSMatthias Springer SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
23599ef9eebSMatthias Springer ArrayRef<int64_t> subViewSizes = inputType.getShape();
23699ef9eebSMatthias Springer MemRefType resultType =
23799ef9eebSMatthias Springer dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
23899ef9eebSMatthias Springer if (canonicalizeStridedLayout(resultType) ==
23999ef9eebSMatthias Springer canonicalizeStridedLayout(inputType))
24099ef9eebSMatthias Springer return input;
24199ef9eebSMatthias Springer return rewriter.create<memref::SubViewOp>(
24299ef9eebSMatthias Springer loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
24399ef9eebSMatthias Springer }
24499ef9eebSMatthias Springer
24599ef9eebSMatthias Springer /// Returns the number of dims that aren't unit dims.
getReducedRank(ArrayRef<int64_t> shape)24699ef9eebSMatthias Springer static int getReducedRank(ArrayRef<int64_t> shape) {
24799ef9eebSMatthias Springer return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
24899ef9eebSMatthias Springer }
24999ef9eebSMatthias Springer
25099ef9eebSMatthias Springer /// Returns true if all values are `arith.constant 0 : index`
isZero(Value v)25199ef9eebSMatthias Springer static bool isZero(Value v) {
25299ef9eebSMatthias Springer auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
25399ef9eebSMatthias Springer return cst && cst.value() == 0;
25499ef9eebSMatthias Springer }
25599ef9eebSMatthias Springer
/// Rewrites vector.transfer_read ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    // The read must cover all elements of the source.
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    // Only all-zero indices are supported: the subview below always starts
    // at offset 0.
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    // Read from a rank-reduced subview with zero indices and an identity map.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};
29899ef9eebSMatthias Springer
/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
/// unit dims, by inserting a memref.subview dropping those unit dims.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // TODO: support tensor type.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    // The write must cover all elements of the destination.
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure(); // The source shape can't be further reduced.
    if (reducedRank != vectorType.getRank())
      return failure(); // This pattern requires the vector shape to match the
                        // reduced source shape.
    // Only all-zero indices are supported: the subview below always starts
    // at offset 0.
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    // Write to a rank-reduced subview with zero indices and an identity map.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};
34199ef9eebSMatthias Springer
/// Returns the position of the first inner dimension that has contiguous layout
/// with at least `requiredContiguousSize` contiguous elements.
/// When such a dimension is found, the return value satisfies:
/// 0 <= return_value <= memrefType.getRank() - 1.
/// When no such dimension is found (including when the strides cannot be
/// computed), the return value is memrefType.getRank().
static int64_t getContiguousInnerDim(MemRefType memrefType,
                                     int64_t requiredContiguousSize) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  int64_t innerDim = shape.size();
  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
    // Walk dims from innermost outward while they stay contiguous, i.e. while
    // each dim's stride equals the product of the sizes of the dims inside it
    // (`innerSize`).
    int64_t innerSize = 1;
    while (true) {
      if (innerDim == 0)
        break;
      const int64_t nextDim = innerDim - 1;
      // A dynamic size makes the running element count unknown; stop here.
      if (shape[nextDim] == ShapedType::kDynamicSize)
        break;
      if (strides[nextDim] != innerSize)
        break;
      innerSize *= shape[nextDim];
      innerDim = nextDim;
      // Stop as soon as enough contiguous elements have been accumulated.
      if (innerSize >= requiredContiguousSize)
        break;
    }
  }
  return innerDim;
}
371*f4ac9509SBenoit Jacob
372*f4ac9509SBenoit Jacob /// Creates a memref.collapse_shape collapsing all inner dimensions of the
373*f4ac9509SBenoit Jacob /// input starting at `firstDimToCollapse`.
collapseInnerDims(PatternRewriter & rewriter,mlir::Location loc,Value input,int64_t firstDimToCollapse)374*f4ac9509SBenoit Jacob static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
375*f4ac9509SBenoit Jacob Value input, int64_t firstDimToCollapse) {
376*f4ac9509SBenoit Jacob ShapedType inputType = input.getType().cast<ShapedType>();
377*f4ac9509SBenoit Jacob if (inputType.getRank() == 1)
378*f4ac9509SBenoit Jacob return input;
379*f4ac9509SBenoit Jacob SmallVector<ReassociationIndices> reassociation;
380*f4ac9509SBenoit Jacob for (int64_t i = 0; i < firstDimToCollapse; ++i)
381*f4ac9509SBenoit Jacob reassociation.push_back(ReassociationIndices{i});
382*f4ac9509SBenoit Jacob ReassociationIndices collapsedIndices;
383*f4ac9509SBenoit Jacob for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
384*f4ac9509SBenoit Jacob collapsedIndices.push_back(i);
385*f4ac9509SBenoit Jacob reassociation.push_back(collapsedIndices);
386*f4ac9509SBenoit Jacob return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
387*f4ac9509SBenoit Jacob }
388*f4ac9509SBenoit Jacob
389*f4ac9509SBenoit Jacob /// Checks that the indices corresponding to dimensions starting at
390*f4ac9509SBenoit Jacob /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
391*f4ac9509SBenoit Jacob /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
392*f4ac9509SBenoit Jacob static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices,int64_t firstDimToCollapse,SmallVector<Value> & outIndices)393*f4ac9509SBenoit Jacob checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
394*f4ac9509SBenoit Jacob SmallVector<Value> &outIndices) {
395*f4ac9509SBenoit Jacob int64_t rank = indices.size();
396*f4ac9509SBenoit Jacob if (firstDimToCollapse >= rank)
397*f4ac9509SBenoit Jacob return failure();
398*f4ac9509SBenoit Jacob for (int64_t i = firstDimToCollapse; i < rank; ++i) {
399*f4ac9509SBenoit Jacob arith::ConstantIndexOp cst =
400*f4ac9509SBenoit Jacob indices[i].getDefiningOp<arith::ConstantIndexOp>();
401*f4ac9509SBenoit Jacob if (!cst || cst.value() != 0)
402*f4ac9509SBenoit Jacob return failure();
403*f4ac9509SBenoit Jacob }
404*f4ac9509SBenoit Jacob outIndices = indices;
405*f4ac9509SBenoit Jacob outIndices.resize(firstDimToCollapse + 1);
406*f4ac9509SBenoit Jacob return success();
40799ef9eebSMatthias Springer }
40899ef9eebSMatthias Springer
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    // Contiguity can currently only be checked on memref sources; bail on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    // Find the outermost dim from which the source is contiguous for at
    // least the number of elements being read; all dims inside it can be
    // collapsed into one.
    int64_t firstContiguousInnerDim =
        getContiguousInnerDim(sourceType, vectorType.getNumElements());
    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    // Indices in the collapsed dims must be 0 so the flat read starts exactly
    // where the original read did.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    // Collapse the source, emit an in-bounds 1-D read of all elements, then
    // shape_cast the flat vector back to the original vector type.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().dyn_cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    // Minor-identity map reading along the single collapsed innermost dim.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};
46699ef9eebSMatthias Springer
46799ef9eebSMatthias Springer /// Rewrites contiguous row-major vector.transfer_write ops by inserting
46899ef9eebSMatthias Springer /// memref.collapse_shape on the source so that the resulting
46999ef9eebSMatthias Springer /// vector.transfer_write has a 1D source. Requires the source shape to be
47099ef9eebSMatthias Springer /// already reduced i.e. without unit dims.
47199ef9eebSMatthias Springer class FlattenContiguousRowMajorTransferWritePattern
47299ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
47399ef9eebSMatthias Springer using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
47499ef9eebSMatthias Springer
matchAndRewrite(vector::TransferWriteOp transferWriteOp,PatternRewriter & rewriter) const47599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
47699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
47799ef9eebSMatthias Springer auto loc = transferWriteOp.getLoc();
4787c38fd60SJacques Pienaar Value vector = transferWriteOp.getVector();
47999ef9eebSMatthias Springer VectorType vectorType = vector.getType().cast<VectorType>();
4807c38fd60SJacques Pienaar Value source = transferWriteOp.getSource();
48199ef9eebSMatthias Springer MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
48299ef9eebSMatthias Springer // Contiguity check is valid on tensors only.
48399ef9eebSMatthias Springer if (!sourceType)
48499ef9eebSMatthias Springer return failure();
4854a876b13Sharsh if (vectorType.getRank() <= 1)
4864a876b13Sharsh // Already 0D/1D, nothing to do.
48799ef9eebSMatthias Springer return failure();
488*f4ac9509SBenoit Jacob int64_t firstContiguousInnerDim =
489*f4ac9509SBenoit Jacob getContiguousInnerDim(sourceType, vectorType.getNumElements());
490*f4ac9509SBenoit Jacob if (firstContiguousInnerDim >= sourceType.getRank() - 1)
49199ef9eebSMatthias Springer return failure();
49299ef9eebSMatthias Springer // TODO: generalize this pattern, relax the requirements here.
49399ef9eebSMatthias Springer if (transferWriteOp.hasOutOfBoundsDim())
49499ef9eebSMatthias Springer return failure();
4957c38fd60SJacques Pienaar if (!transferWriteOp.getPermutationMap().isMinorIdentity())
49699ef9eebSMatthias Springer return failure();
4977c38fd60SJacques Pienaar if (transferWriteOp.getMask())
49899ef9eebSMatthias Springer return failure();
499*f4ac9509SBenoit Jacob SmallVector<Value> collapsedIndices;
500*f4ac9509SBenoit Jacob if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
501*f4ac9509SBenoit Jacob firstContiguousInnerDim,
502*f4ac9509SBenoit Jacob collapsedIndices)))
50399ef9eebSMatthias Springer return failure();
504*f4ac9509SBenoit Jacob Value collapsedSource =
505*f4ac9509SBenoit Jacob collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
506*f4ac9509SBenoit Jacob MemRefType collapsedSourceType =
507*f4ac9509SBenoit Jacob collapsedSource.getType().cast<MemRefType>();
508*f4ac9509SBenoit Jacob int64_t collapsedRank = collapsedSourceType.getRank();
509*f4ac9509SBenoit Jacob assert(collapsedRank == firstContiguousInnerDim + 1);
510*f4ac9509SBenoit Jacob SmallVector<AffineExpr, 1> dimExprs{
511*f4ac9509SBenoit Jacob getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
512*f4ac9509SBenoit Jacob auto collapsedMap =
513*f4ac9509SBenoit Jacob AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
514*f4ac9509SBenoit Jacob VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
515*f4ac9509SBenoit Jacob vectorType.getElementType());
516*f4ac9509SBenoit Jacob Value flatVector =
517*f4ac9509SBenoit Jacob rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
518*f4ac9509SBenoit Jacob vector::TransferWriteOp flatWrite =
519*f4ac9509SBenoit Jacob rewriter.create<vector::TransferWriteOp>(
520*f4ac9509SBenoit Jacob loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
521*f4ac9509SBenoit Jacob flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
52299ef9eebSMatthias Springer rewriter.eraseOp(transferWriteOp);
52399ef9eebSMatthias Springer return success();
52499ef9eebSMatthias Springer }
52599ef9eebSMatthias Springer };
52699ef9eebSMatthias Springer
52799ef9eebSMatthias Springer } // namespace
52899ef9eebSMatthias Springer
transferOpflowOpt(Operation * rootOp)529171850c5SRiver Riddle void mlir::vector::transferOpflowOpt(Operation *rootOp) {
530171850c5SRiver Riddle TransferOptimization opt(rootOp);
53199ef9eebSMatthias Springer // Run store to load forwarding first since it can expose more dead store
53299ef9eebSMatthias Springer // opportunity.
533171850c5SRiver Riddle rootOp->walk([&](vector::TransferReadOp read) {
53499ef9eebSMatthias Springer if (read.getShapedType().isa<MemRefType>())
53599ef9eebSMatthias Springer opt.storeToLoadForwarding(read);
53699ef9eebSMatthias Springer });
53799ef9eebSMatthias Springer opt.removeDeadOp();
538171850c5SRiver Riddle rootOp->walk([&](vector::TransferWriteOp write) {
53999ef9eebSMatthias Springer if (write.getShapedType().isa<MemRefType>())
54099ef9eebSMatthias Springer opt.deadStoreOp(write);
54199ef9eebSMatthias Springer });
54299ef9eebSMatthias Springer opt.removeDeadOp();
54399ef9eebSMatthias Springer }
54499ef9eebSMatthias Springer
populateVectorTransferDropUnitDimsPatterns(RewritePatternSet & patterns)54599ef9eebSMatthias Springer void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
54699ef9eebSMatthias Springer RewritePatternSet &patterns) {
54799ef9eebSMatthias Springer patterns
54899ef9eebSMatthias Springer .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
54999ef9eebSMatthias Springer patterns.getContext());
55099ef9eebSMatthias Springer populateShapeCastFoldingPatterns(patterns);
55199ef9eebSMatthias Springer }
55299ef9eebSMatthias Springer
populateFlattenVectorTransferPatterns(RewritePatternSet & patterns)55399ef9eebSMatthias Springer void mlir::vector::populateFlattenVectorTransferPatterns(
55499ef9eebSMatthias Springer RewritePatternSet &patterns) {
55599ef9eebSMatthias Springer patterns.add<FlattenContiguousRowMajorTransferReadPattern,
55699ef9eebSMatthias Springer FlattenContiguousRowMajorTransferWritePattern>(
55799ef9eebSMatthias Springer patterns.getContext());
55899ef9eebSMatthias Springer populateShapeCastFoldingPatterns(patterns);
55999ef9eebSMatthias Springer }
560