199ef9eebSMatthias Springer //===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements target-independent patterns to rewrite a vector.transfer
1099ef9eebSMatthias Springer // op into a fully in-bounds part and a partial part.
1199ef9eebSMatthias Springer //
1299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1399ef9eebSMatthias Springer
1499ef9eebSMatthias Springer #include <type_traits>
1599ef9eebSMatthias Springer
1699ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
20*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
2199ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2299ef9eebSMatthias Springer
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2499ef9eebSMatthias Springer #include "mlir/IR/Matchers.h"
2599ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h"
2699ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
2799ef9eebSMatthias Springer
2899ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h"
2999ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h"
3099ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
3199ef9eebSMatthias Springer #include "llvm/Support/CommandLine.h"
3299ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
3399ef9eebSMatthias Springer #include "llvm/Support/raw_ostream.h"
3499ef9eebSMatthias Springer
3599ef9eebSMatthias Springer #define DEBUG_TYPE "vector-transfer-split"
3699ef9eebSMatthias Springer
3799ef9eebSMatthias Springer using namespace mlir;
3899ef9eebSMatthias Springer using namespace mlir::vector;
3999ef9eebSMatthias Springer
extractConstantIndex(Value v)4099ef9eebSMatthias Springer static Optional<int64_t> extractConstantIndex(Value v) {
4199ef9eebSMatthias Springer if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
4299ef9eebSMatthias Springer return cstOp.value();
4399ef9eebSMatthias Springer if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
4499ef9eebSMatthias Springer if (affineApplyOp.getAffineMap().isSingleConstant())
4599ef9eebSMatthias Springer return affineApplyOp.getAffineMap().getSingleConstantResult();
4699ef9eebSMatthias Springer return None;
4799ef9eebSMatthias Springer }
4899ef9eebSMatthias Springer
4999ef9eebSMatthias Springer // Missing foldings of scf.if make it necessary to perform poor man's folding
5099ef9eebSMatthias Springer // eagerly, especially in the case of unrolling. In the future, this should go
5199ef9eebSMatthias Springer // away once scf.if folds properly.
createFoldedSLE(RewriterBase & b,Value v,Value ub)5299ef9eebSMatthias Springer static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
5399ef9eebSMatthias Springer auto maybeCstV = extractConstantIndex(v);
5499ef9eebSMatthias Springer auto maybeCstUb = extractConstantIndex(ub);
5599ef9eebSMatthias Springer if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
5699ef9eebSMatthias Springer return Value();
5799ef9eebSMatthias Springer return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
5899ef9eebSMatthias Springer }
5999ef9eebSMatthias Springer
6099ef9eebSMatthias Springer /// Build the condition to ensure that a particular VectorTransferOpInterface
6199ef9eebSMatthias Springer /// is in-bounds.
createInBoundsCond(RewriterBase & b,VectorTransferOpInterface xferOp)6299ef9eebSMatthias Springer static Value createInBoundsCond(RewriterBase &b,
6399ef9eebSMatthias Springer VectorTransferOpInterface xferOp) {
6499ef9eebSMatthias Springer assert(xferOp.permutation_map().isMinorIdentity() &&
6599ef9eebSMatthias Springer "Expected minor identity map");
6699ef9eebSMatthias Springer Value inBoundsCond;
6799ef9eebSMatthias Springer xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
6899ef9eebSMatthias Springer // Zip over the resulting vector shape and memref indices.
6999ef9eebSMatthias Springer // If the dimension is known to be in-bounds, it does not participate in
7099ef9eebSMatthias Springer // the construction of `inBoundsCond`.
7199ef9eebSMatthias Springer if (xferOp.isDimInBounds(resultIdx))
7299ef9eebSMatthias Springer return;
7399ef9eebSMatthias Springer // Fold or create the check that `index + vector_size` <= `memref_size`.
7499ef9eebSMatthias Springer Location loc = xferOp.getLoc();
7599ef9eebSMatthias Springer int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
7699ef9eebSMatthias Springer auto d0 = getAffineDimExpr(0, xferOp.getContext());
7799ef9eebSMatthias Springer auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
7899ef9eebSMatthias Springer Value sum =
7999ef9eebSMatthias Springer makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
8099ef9eebSMatthias Springer Value cond = createFoldedSLE(
8199ef9eebSMatthias Springer b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
8299ef9eebSMatthias Springer if (!cond)
8399ef9eebSMatthias Springer return;
8499ef9eebSMatthias Springer // Conjunction over all dims for which we are in-bounds.
8599ef9eebSMatthias Springer if (inBoundsCond)
8699ef9eebSMatthias Springer inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
8799ef9eebSMatthias Springer else
8899ef9eebSMatthias Springer inBoundsCond = cond;
8999ef9eebSMatthias Springer });
9099ef9eebSMatthias Springer return inBoundsCond;
9199ef9eebSMatthias Springer }
9299ef9eebSMatthias Springer
9399ef9eebSMatthias Springer /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
9499ef9eebSMatthias Springer /// masking) fastpath and a slowpath.
9599ef9eebSMatthias Springer /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
9699ef9eebSMatthias Springer /// newly created conditional upon function return.
9799ef9eebSMatthias Springer /// To accomodate for the fact that the original vector.transfer indexing may be
9899ef9eebSMatthias Springer /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
9999ef9eebSMatthias Springer /// scf.if op returns a view and values of type index.
10099ef9eebSMatthias Springer /// At this time, only vector.transfer_read case is implemented.
10199ef9eebSMatthias Springer ///
10299ef9eebSMatthias Springer /// Example (a 2-D vector.transfer_read):
10399ef9eebSMatthias Springer /// ```
10499ef9eebSMatthias Springer /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
10599ef9eebSMatthias Springer /// ```
10699ef9eebSMatthias Springer /// is transformed into:
10799ef9eebSMatthias Springer /// ```
10899ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
10999ef9eebSMatthias Springer /// // fastpath, direct cast
11099ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType
11199ef9eebSMatthias Springer /// scf.yield %view : compatibleMemRefType, index, index
11299ef9eebSMatthias Springer /// } else {
11399ef9eebSMatthias Springer /// // slowpath, not in-bounds vector.transfer or linalg.copy.
11499ef9eebSMatthias Springer /// memref.cast %alloc: memref<B...> to compatibleMemRefType
11599ef9eebSMatthias Springer /// scf.yield %4 : compatibleMemRefType, index, index
11699ef9eebSMatthias Springer // }
11799ef9eebSMatthias Springer /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
11899ef9eebSMatthias Springer /// ```
11999ef9eebSMatthias Springer /// where `alloc` is a top of the function alloca'ed buffer of one vector.
12099ef9eebSMatthias Springer ///
12199ef9eebSMatthias Springer /// Preconditions:
12299ef9eebSMatthias Springer /// 1. `xferOp.permutation_map()` must be a minor identity map
12399ef9eebSMatthias Springer /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
12499ef9eebSMatthias Springer /// must be equal. This will be relaxed in the future but requires
12599ef9eebSMatthias Springer /// rank-reducing subviews.
12699ef9eebSMatthias Springer static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)12799ef9eebSMatthias Springer splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
12899ef9eebSMatthias Springer // TODO: support 0-d corner case.
12999ef9eebSMatthias Springer if (xferOp.getTransferRank() == 0)
13099ef9eebSMatthias Springer return failure();
13199ef9eebSMatthias Springer
13299ef9eebSMatthias Springer // TODO: expand support to these 2 cases.
13399ef9eebSMatthias Springer if (!xferOp.permutation_map().isMinorIdentity())
13499ef9eebSMatthias Springer return failure();
13599ef9eebSMatthias Springer // Must have some out-of-bounds dimension to be a candidate for splitting.
13699ef9eebSMatthias Springer if (!xferOp.hasOutOfBoundsDim())
13799ef9eebSMatthias Springer return failure();
13899ef9eebSMatthias Springer // Don't split transfer operations directly under IfOp, this avoids applying
13999ef9eebSMatthias Springer // the pattern recursively.
14099ef9eebSMatthias Springer // TODO: improve the filtering condition to make it more applicable.
14199ef9eebSMatthias Springer if (isa<scf::IfOp>(xferOp->getParentOp()))
14299ef9eebSMatthias Springer return failure();
14399ef9eebSMatthias Springer return success();
14499ef9eebSMatthias Springer }
14599ef9eebSMatthias Springer
14699ef9eebSMatthias Springer /// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
14799ef9eebSMatthias Springer /// be cast. If the MemRefTypes don't have the same rank or are not strided,
14899ef9eebSMatthias Springer /// return null; otherwise:
14999ef9eebSMatthias Springer /// 1. if `aT` and `bT` are cast-compatible, return `aT`.
15099ef9eebSMatthias Springer /// 2. else return a new MemRefType obtained by iterating over the shape and
15199ef9eebSMatthias Springer /// strides and:
15299ef9eebSMatthias Springer /// a. keeping the ones that are static and equal across `aT` and `bT`.
15399ef9eebSMatthias Springer /// b. using a dynamic shape and/or stride for the dimensions that don't
15499ef9eebSMatthias Springer /// agree.
getCastCompatibleMemRefType(MemRefType aT,MemRefType bT)15599ef9eebSMatthias Springer static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
15699ef9eebSMatthias Springer if (memref::CastOp::areCastCompatible(aT, bT))
15799ef9eebSMatthias Springer return aT;
15899ef9eebSMatthias Springer if (aT.getRank() != bT.getRank())
15999ef9eebSMatthias Springer return MemRefType();
16099ef9eebSMatthias Springer int64_t aOffset, bOffset;
16199ef9eebSMatthias Springer SmallVector<int64_t, 4> aStrides, bStrides;
16299ef9eebSMatthias Springer if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
16399ef9eebSMatthias Springer failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
16499ef9eebSMatthias Springer aStrides.size() != bStrides.size())
16599ef9eebSMatthias Springer return MemRefType();
16699ef9eebSMatthias Springer
16799ef9eebSMatthias Springer ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
16899ef9eebSMatthias Springer int64_t resOffset;
16999ef9eebSMatthias Springer SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
17099ef9eebSMatthias Springer resStrides(bT.getRank(), 0);
17199ef9eebSMatthias Springer for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
17299ef9eebSMatthias Springer resShape[idx] =
17399ef9eebSMatthias Springer (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
17499ef9eebSMatthias Springer resStrides[idx] = (aStrides[idx] == bStrides[idx])
17599ef9eebSMatthias Springer ? aStrides[idx]
17699ef9eebSMatthias Springer : ShapedType::kDynamicStrideOrOffset;
17799ef9eebSMatthias Springer }
17899ef9eebSMatthias Springer resOffset =
17999ef9eebSMatthias Springer (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
18099ef9eebSMatthias Springer return MemRefType::get(
18199ef9eebSMatthias Springer resShape, aT.getElementType(),
18299ef9eebSMatthias Springer makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
18399ef9eebSMatthias Springer }
18499ef9eebSMatthias Springer
18599ef9eebSMatthias Springer /// Operates under a scoped context to build the intersection between the
18699ef9eebSMatthias Springer /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
18799ef9eebSMatthias Springer // TODO: view intersection/union/differences should be a proper std op.
18899ef9eebSMatthias Springer static std::pair<Value, Value>
createSubViewIntersection(RewriterBase & b,VectorTransferOpInterface xferOp,Value alloc)18999ef9eebSMatthias Springer createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
19099ef9eebSMatthias Springer Value alloc) {
19199ef9eebSMatthias Springer Location loc = xferOp.getLoc();
19299ef9eebSMatthias Springer int64_t memrefRank = xferOp.getShapedType().getRank();
19399ef9eebSMatthias Springer // TODO: relax this precondition, will require rank-reducing subviews.
19499ef9eebSMatthias Springer assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
19599ef9eebSMatthias Springer "Expected memref rank to match the alloc rank");
19699ef9eebSMatthias Springer ValueRange leadingIndices =
19799ef9eebSMatthias Springer xferOp.indices().take_front(xferOp.getLeadingShapedRank());
19899ef9eebSMatthias Springer SmallVector<OpFoldResult, 4> sizes;
19999ef9eebSMatthias Springer sizes.append(leadingIndices.begin(), leadingIndices.end());
20099ef9eebSMatthias Springer auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
20199ef9eebSMatthias Springer xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
20299ef9eebSMatthias Springer using MapList = ArrayRef<ArrayRef<AffineExpr>>;
20399ef9eebSMatthias Springer Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
20499ef9eebSMatthias Springer xferOp.source(), indicesIdx);
20599ef9eebSMatthias Springer Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
20699ef9eebSMatthias Springer Value index = xferOp.indices()[indicesIdx];
20799ef9eebSMatthias Springer AffineExpr i, j, k;
20899ef9eebSMatthias Springer bindDims(xferOp.getContext(), i, j, k);
20999ef9eebSMatthias Springer SmallVector<AffineMap, 4> maps =
21099ef9eebSMatthias Springer AffineMap::inferFromExprList(MapList{{i - j, k}});
21199ef9eebSMatthias Springer // affine_min(%dimMemRef - %index, %dimAlloc)
21299ef9eebSMatthias Springer Value affineMin = b.create<AffineMinOp>(
21399ef9eebSMatthias Springer loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
21499ef9eebSMatthias Springer sizes.push_back(affineMin);
21599ef9eebSMatthias Springer });
21699ef9eebSMatthias Springer
21799ef9eebSMatthias Springer SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
21899ef9eebSMatthias Springer xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
21999ef9eebSMatthias Springer SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
22099ef9eebSMatthias Springer SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
22199ef9eebSMatthias Springer auto copySrc = b.create<memref::SubViewOp>(
22299ef9eebSMatthias Springer loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
22399ef9eebSMatthias Springer auto copyDest = b.create<memref::SubViewOp>(
22499ef9eebSMatthias Springer loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
22599ef9eebSMatthias Springer return std::make_pair(copySrc, copyDest);
22699ef9eebSMatthias Springer }
22799ef9eebSMatthias Springer
22899ef9eebSMatthias Springer /// Given an `xferOp` for which:
22999ef9eebSMatthias Springer /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
23099ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated.
23199ef9eebSMatthias Springer /// Produce IR resembling:
23299ef9eebSMatthias Springer /// ```
23399ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
23499ef9eebSMatthias Springer /// %view = memref.cast %A: memref<A...> to compatibleMemRefType
23599ef9eebSMatthias Springer /// scf.yield %view, ... : compatibleMemRefType, index, index
23699ef9eebSMatthias Springer /// } else {
23799ef9eebSMatthias Springer /// %2 = linalg.fill(%pad, %alloc)
23899ef9eebSMatthias Springer /// %3 = subview %view [...][...][...]
23999ef9eebSMatthias Springer /// %4 = subview %alloc [0, 0] [...] [...]
24099ef9eebSMatthias Springer /// linalg.copy(%3, %4)
24199ef9eebSMatthias Springer /// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
24299ef9eebSMatthias Springer /// scf.yield %5, ... : compatibleMemRefType, index, index
24399ef9eebSMatthias Springer /// }
24499ef9eebSMatthias Springer /// ```
24599ef9eebSMatthias Springer /// Return the produced scf::IfOp.
24699ef9eebSMatthias Springer static scf::IfOp
createFullPartialLinalgCopy(RewriterBase & b,vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)24799ef9eebSMatthias Springer createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
24899ef9eebSMatthias Springer TypeRange returnTypes, Value inBoundsCond,
24999ef9eebSMatthias Springer MemRefType compatibleMemRefType, Value alloc) {
25099ef9eebSMatthias Springer Location loc = xferOp.getLoc();
25199ef9eebSMatthias Springer Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
2527c38fd60SJacques Pienaar Value memref = xferOp.getSource();
25399ef9eebSMatthias Springer return b.create<scf::IfOp>(
25499ef9eebSMatthias Springer loc, returnTypes, inBoundsCond,
25599ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
25699ef9eebSMatthias Springer Value res = memref;
25799ef9eebSMatthias Springer if (compatibleMemRefType != xferOp.getShapedType())
2583c69bc4dSRiver Riddle res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
25999ef9eebSMatthias Springer scf::ValueVector viewAndIndices{res};
2607c38fd60SJacques Pienaar viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
2617c38fd60SJacques Pienaar xferOp.getIndices().end());
26299ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
26399ef9eebSMatthias Springer },
26499ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
2657c38fd60SJacques Pienaar b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
2667294be2bSgysit ValueRange{alloc});
26799ef9eebSMatthias Springer // Take partial subview of memref which guarantees no dimension
26899ef9eebSMatthias Springer // overflows.
26999ef9eebSMatthias Springer IRRewriter rewriter(b);
27099ef9eebSMatthias Springer std::pair<Value, Value> copyArgs = createSubViewIntersection(
27199ef9eebSMatthias Springer rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
27299ef9eebSMatthias Springer alloc);
273ebc81537SAlexander Belyaev b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
27499ef9eebSMatthias Springer Value casted =
2753c69bc4dSRiver Riddle b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
27699ef9eebSMatthias Springer scf::ValueVector viewAndIndices{casted};
27799ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
27899ef9eebSMatthias Springer zero);
27999ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
28099ef9eebSMatthias Springer });
28199ef9eebSMatthias Springer }
28299ef9eebSMatthias Springer
28399ef9eebSMatthias Springer /// Given an `xferOp` for which:
28499ef9eebSMatthias Springer /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
28599ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated.
28699ef9eebSMatthias Springer /// Produce IR resembling:
28799ef9eebSMatthias Springer /// ```
28899ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
28999ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType
29099ef9eebSMatthias Springer /// scf.yield %view, ... : compatibleMemRefType, index, index
29199ef9eebSMatthias Springer /// } else {
29299ef9eebSMatthias Springer /// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
29399ef9eebSMatthias Springer /// %3 = vector.type_cast %extra_alloc :
29499ef9eebSMatthias Springer /// memref<...> to memref<vector<...>>
29599ef9eebSMatthias Springer /// store %2, %3[] : memref<vector<...>>
29699ef9eebSMatthias Springer /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
29799ef9eebSMatthias Springer /// scf.yield %4, ... : compatibleMemRefType, index, index
29899ef9eebSMatthias Springer /// }
29999ef9eebSMatthias Springer /// ```
30099ef9eebSMatthias Springer /// Return the produced scf::IfOp.
createFullPartialVectorTransferRead(RewriterBase & b,vector::TransferReadOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)30199ef9eebSMatthias Springer static scf::IfOp createFullPartialVectorTransferRead(
30299ef9eebSMatthias Springer RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
30399ef9eebSMatthias Springer Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
30499ef9eebSMatthias Springer Location loc = xferOp.getLoc();
30599ef9eebSMatthias Springer scf::IfOp fullPartialIfOp;
30699ef9eebSMatthias Springer Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
3077c38fd60SJacques Pienaar Value memref = xferOp.getSource();
30899ef9eebSMatthias Springer return b.create<scf::IfOp>(
30999ef9eebSMatthias Springer loc, returnTypes, inBoundsCond,
31099ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
31199ef9eebSMatthias Springer Value res = memref;
31299ef9eebSMatthias Springer if (compatibleMemRefType != xferOp.getShapedType())
3133c69bc4dSRiver Riddle res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
31499ef9eebSMatthias Springer scf::ValueVector viewAndIndices{res};
3157c38fd60SJacques Pienaar viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
3167c38fd60SJacques Pienaar xferOp.getIndices().end());
31799ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
31899ef9eebSMatthias Springer },
31999ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
32099ef9eebSMatthias Springer Operation *newXfer = b.clone(*xferOp.getOperation());
32199ef9eebSMatthias Springer Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
32299ef9eebSMatthias Springer b.create<memref::StoreOp>(
32399ef9eebSMatthias Springer loc, vector,
32499ef9eebSMatthias Springer b.create<vector::TypeCastOp>(
32599ef9eebSMatthias Springer loc, MemRefType::get({}, vector.getType()), alloc));
32699ef9eebSMatthias Springer
32799ef9eebSMatthias Springer Value casted =
3283c69bc4dSRiver Riddle b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
32999ef9eebSMatthias Springer scf::ValueVector viewAndIndices{casted};
33099ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
33199ef9eebSMatthias Springer zero);
33299ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
33399ef9eebSMatthias Springer });
33499ef9eebSMatthias Springer }
33599ef9eebSMatthias Springer
33699ef9eebSMatthias Springer /// Given an `xferOp` for which:
33799ef9eebSMatthias Springer /// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
33899ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated.
33999ef9eebSMatthias Springer /// Produce IR resembling:
34099ef9eebSMatthias Springer /// ```
34199ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
34299ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType
34399ef9eebSMatthias Springer /// scf.yield %view, ... : compatibleMemRefType, index, index
34499ef9eebSMatthias Springer /// } else {
34599ef9eebSMatthias Springer /// %3 = vector.type_cast %extra_alloc :
34699ef9eebSMatthias Springer /// memref<...> to memref<vector<...>>
34799ef9eebSMatthias Springer /// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
34899ef9eebSMatthias Springer /// scf.yield %4, ... : compatibleMemRefType, index, index
34999ef9eebSMatthias Springer /// }
35099ef9eebSMatthias Springer /// ```
35199ef9eebSMatthias Springer static ValueRange
getLocationToWriteFullVec(RewriterBase & b,vector::TransferWriteOp xferOp,TypeRange returnTypes,Value inBoundsCond,MemRefType compatibleMemRefType,Value alloc)35299ef9eebSMatthias Springer getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
35399ef9eebSMatthias Springer TypeRange returnTypes, Value inBoundsCond,
35499ef9eebSMatthias Springer MemRefType compatibleMemRefType, Value alloc) {
35599ef9eebSMatthias Springer Location loc = xferOp.getLoc();
35699ef9eebSMatthias Springer Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
3577c38fd60SJacques Pienaar Value memref = xferOp.getSource();
35899ef9eebSMatthias Springer return b
35999ef9eebSMatthias Springer .create<scf::IfOp>(
36099ef9eebSMatthias Springer loc, returnTypes, inBoundsCond,
36199ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
36299ef9eebSMatthias Springer Value res = memref;
36399ef9eebSMatthias Springer if (compatibleMemRefType != xferOp.getShapedType())
3643c69bc4dSRiver Riddle res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
36599ef9eebSMatthias Springer scf::ValueVector viewAndIndices{res};
36699ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(),
3677c38fd60SJacques Pienaar xferOp.getIndices().begin(),
3687c38fd60SJacques Pienaar xferOp.getIndices().end());
36999ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
37099ef9eebSMatthias Springer },
37199ef9eebSMatthias Springer [&](OpBuilder &b, Location loc) {
37299ef9eebSMatthias Springer Value casted =
3733c69bc4dSRiver Riddle b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
37499ef9eebSMatthias Springer scf::ValueVector viewAndIndices{casted};
37599ef9eebSMatthias Springer viewAndIndices.insert(viewAndIndices.end(),
37699ef9eebSMatthias Springer xferOp.getTransferRank(), zero);
37799ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, viewAndIndices);
37899ef9eebSMatthias Springer })
37999ef9eebSMatthias Springer ->getResults();
38099ef9eebSMatthias Springer }
38199ef9eebSMatthias Springer
38299ef9eebSMatthias Springer /// Given an `xferOp` for which:
38399ef9eebSMatthias Springer /// 1. `inBoundsCond` has been computed.
38499ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated.
38599ef9eebSMatthias Springer /// 3. it originally wrote to %view
38699ef9eebSMatthias Springer /// Produce IR resembling:
38799ef9eebSMatthias Springer /// ```
38899ef9eebSMatthias Springer /// %notInBounds = arith.xori %inBounds, %true
38999ef9eebSMatthias Springer /// scf.if (%notInBounds) {
39099ef9eebSMatthias Springer /// %3 = subview %alloc [...][...][...]
39199ef9eebSMatthias Springer /// %4 = subview %view [0, 0][...][...]
39299ef9eebSMatthias Springer /// linalg.copy(%3, %4)
39399ef9eebSMatthias Springer /// }
39499ef9eebSMatthias Springer /// ```
createFullPartialLinalgCopy(RewriterBase & b,vector::TransferWriteOp xferOp,Value inBoundsCond,Value alloc)39599ef9eebSMatthias Springer static void createFullPartialLinalgCopy(RewriterBase &b,
39699ef9eebSMatthias Springer vector::TransferWriteOp xferOp,
39799ef9eebSMatthias Springer Value inBoundsCond, Value alloc) {
39899ef9eebSMatthias Springer Location loc = xferOp.getLoc();
39999ef9eebSMatthias Springer auto notInBounds = b.create<arith::XOrIOp>(
40099ef9eebSMatthias Springer loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
40199ef9eebSMatthias Springer b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
40299ef9eebSMatthias Springer IRRewriter rewriter(b);
40399ef9eebSMatthias Springer std::pair<Value, Value> copyArgs = createSubViewIntersection(
40499ef9eebSMatthias Springer rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
40599ef9eebSMatthias Springer alloc);
406ebc81537SAlexander Belyaev b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
40799ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, ValueRange{});
40899ef9eebSMatthias Springer });
40999ef9eebSMatthias Springer }
41099ef9eebSMatthias Springer
41199ef9eebSMatthias Springer /// Given an `xferOp` for which:
41299ef9eebSMatthias Springer /// 1. `inBoundsCond` has been computed.
41399ef9eebSMatthias Springer /// 2. a memref of single vector `alloc` has been allocated.
41499ef9eebSMatthias Springer /// 3. it originally wrote to %view
41599ef9eebSMatthias Springer /// Produce IR resembling:
41699ef9eebSMatthias Springer /// ```
41799ef9eebSMatthias Springer /// %notInBounds = arith.xori %inBounds, %true
41899ef9eebSMatthias Springer /// scf.if (%notInBounds) {
41999ef9eebSMatthias Springer /// %2 = load %alloc : memref<vector<...>>
42099ef9eebSMatthias Springer /// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
42199ef9eebSMatthias Springer /// }
42299ef9eebSMatthias Springer /// ```
createFullPartialVectorTransferWrite(RewriterBase & b,vector::TransferWriteOp xferOp,Value inBoundsCond,Value alloc)42399ef9eebSMatthias Springer static void createFullPartialVectorTransferWrite(RewriterBase &b,
42499ef9eebSMatthias Springer vector::TransferWriteOp xferOp,
42599ef9eebSMatthias Springer Value inBoundsCond,
42699ef9eebSMatthias Springer Value alloc) {
42799ef9eebSMatthias Springer Location loc = xferOp.getLoc();
42899ef9eebSMatthias Springer auto notInBounds = b.create<arith::XOrIOp>(
42999ef9eebSMatthias Springer loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
43099ef9eebSMatthias Springer b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
43199ef9eebSMatthias Springer BlockAndValueMapping mapping;
43299ef9eebSMatthias Springer Value load = b.create<memref::LoadOp>(
4337c38fd60SJacques Pienaar loc,
4347c38fd60SJacques Pienaar b.create<vector::TypeCastOp>(
4357c38fd60SJacques Pienaar loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
4367c38fd60SJacques Pienaar mapping.map(xferOp.getVector(), load);
43799ef9eebSMatthias Springer b.clone(*xferOp.getOperation(), mapping);
43899ef9eebSMatthias Springer b.create<scf::YieldOp>(loc, ValueRange{});
43999ef9eebSMatthias Springer });
44099ef9eebSMatthias Springer }
44199ef9eebSMatthias Springer
4423c3810e7SNicolas Vasilache // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
getAutomaticAllocationScope(Operation * op)4433c3810e7SNicolas Vasilache static Operation *getAutomaticAllocationScope(Operation *op) {
4444c807f2fSAlex Zinenko // Find the closest surrounding allocation scope that is not a known looping
4454c807f2fSAlex Zinenko // construct (putting alloca's in loops doesn't always lower to deallocation
4464c807f2fSAlex Zinenko // until the end of the loop).
4474c807f2fSAlex Zinenko Operation *scope = nullptr;
4484c807f2fSAlex Zinenko for (Operation *parent = op->getParentOp(); parent != nullptr;
4494c807f2fSAlex Zinenko parent = parent->getParentOp()) {
4504c807f2fSAlex Zinenko if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
4514c807f2fSAlex Zinenko scope = parent;
4524c807f2fSAlex Zinenko if (!isa<scf::ForOp, AffineForOp>(parent))
4534c807f2fSAlex Zinenko break;
4544c807f2fSAlex Zinenko }
4553c3810e7SNicolas Vasilache assert(scope && "Expected op to be inside automatic allocation scope");
4563c3810e7SNicolas Vasilache return scope;
4573c3810e7SNicolas Vasilache }
4583c3810e7SNicolas Vasilache
45999ef9eebSMatthias Springer /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
46099ef9eebSMatthias Springer /// masking) fastpath and a slowpath.
46199ef9eebSMatthias Springer ///
46299ef9eebSMatthias Springer /// For vector.transfer_read:
46399ef9eebSMatthias Springer /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
46499ef9eebSMatthias Springer /// newly created conditional upon function return.
46599ef9eebSMatthias Springer /// To accomodate for the fact that the original vector.transfer indexing may be
46699ef9eebSMatthias Springer /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
46799ef9eebSMatthias Springer /// scf.if op returns a view and values of type index.
46899ef9eebSMatthias Springer ///
46999ef9eebSMatthias Springer /// Example (a 2-D vector.transfer_read):
47099ef9eebSMatthias Springer /// ```
47199ef9eebSMatthias Springer /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
47299ef9eebSMatthias Springer /// ```
47399ef9eebSMatthias Springer /// is transformed into:
47499ef9eebSMatthias Springer /// ```
47599ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
47699ef9eebSMatthias Springer /// // fastpath, direct cast
47799ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType
47899ef9eebSMatthias Springer /// scf.yield %view : compatibleMemRefType, index, index
47999ef9eebSMatthias Springer /// } else {
48099ef9eebSMatthias Springer /// // slowpath, not in-bounds vector.transfer or linalg.copy.
48199ef9eebSMatthias Springer /// memref.cast %alloc: memref<B...> to compatibleMemRefType
48299ef9eebSMatthias Springer /// scf.yield %4 : compatibleMemRefType, index, index
48399ef9eebSMatthias Springer // }
48499ef9eebSMatthias Springer /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
48599ef9eebSMatthias Springer /// ```
48699ef9eebSMatthias Springer /// where `alloc` is a top of the function alloca'ed buffer of one vector.
48799ef9eebSMatthias Springer ///
48899ef9eebSMatthias Springer /// For vector.transfer_write:
48999ef9eebSMatthias Springer /// There are 2 conditional blocks. First a block to decide which memref and
49099ef9eebSMatthias Springer /// indices to use for an unmasked, inbounds write. Then a conditional block to
49199ef9eebSMatthias Springer /// further copy a partial buffer into the final result in the slow path case.
49299ef9eebSMatthias Springer ///
49399ef9eebSMatthias Springer /// Example (a 2-D vector.transfer_write):
49499ef9eebSMatthias Springer /// ```
49599ef9eebSMatthias Springer /// vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
49699ef9eebSMatthias Springer /// ```
49799ef9eebSMatthias Springer /// is transformed into:
49899ef9eebSMatthias Springer /// ```
49999ef9eebSMatthias Springer /// %1:3 = scf.if (%inBounds) {
50099ef9eebSMatthias Springer /// memref.cast %A: memref<A...> to compatibleMemRefType
50199ef9eebSMatthias Springer /// scf.yield %view : compatibleMemRefType, index, index
50299ef9eebSMatthias Springer /// } else {
50399ef9eebSMatthias Springer /// memref.cast %alloc: memref<B...> to compatibleMemRefType
50499ef9eebSMatthias Springer /// scf.yield %4 : compatibleMemRefType, index, index
50599ef9eebSMatthias Springer /// }
50699ef9eebSMatthias Springer /// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
50799ef9eebSMatthias Springer /// true]}
50899ef9eebSMatthias Springer /// scf.if (%notInBounds) {
50999ef9eebSMatthias Springer /// // slowpath: not in-bounds vector.transfer or linalg.copy.
51099ef9eebSMatthias Springer /// }
51199ef9eebSMatthias Springer /// ```
51299ef9eebSMatthias Springer /// where `alloc` is a top of the function alloca'ed buffer of one vector.
51399ef9eebSMatthias Springer ///
51499ef9eebSMatthias Springer /// Preconditions:
51599ef9eebSMatthias Springer /// 1. `xferOp.permutation_map()` must be a minor identity map
51699ef9eebSMatthias Springer /// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
51799ef9eebSMatthias Springer /// must be equal. This will be relaxed in the future but requires
51899ef9eebSMatthias Springer /// rank-reducing subviews.
splitFullAndPartialTransfer(RewriterBase & b,VectorTransferOpInterface xferOp,VectorTransformsOptions options,scf::IfOp * ifOp)51999ef9eebSMatthias Springer LogicalResult mlir::vector::splitFullAndPartialTransfer(
52099ef9eebSMatthias Springer RewriterBase &b, VectorTransferOpInterface xferOp,
52199ef9eebSMatthias Springer VectorTransformsOptions options, scf::IfOp *ifOp) {
52299ef9eebSMatthias Springer if (options.vectorTransferSplit == VectorTransferSplit::None)
52399ef9eebSMatthias Springer return failure();
52499ef9eebSMatthias Springer
52599ef9eebSMatthias Springer SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
52699ef9eebSMatthias Springer auto inBoundsAttr = b.getBoolArrayAttr(bools);
52799ef9eebSMatthias Springer if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
52875044e9bSJacques Pienaar xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
52999ef9eebSMatthias Springer return success();
53099ef9eebSMatthias Springer }
53199ef9eebSMatthias Springer
53299ef9eebSMatthias Springer // Assert preconditions. Additionally, keep the variables in an inner scope to
53399ef9eebSMatthias Springer // ensure they aren't used in the wrong scopes further down.
53499ef9eebSMatthias Springer {
53599ef9eebSMatthias Springer assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
53699ef9eebSMatthias Springer "Expected splitFullAndPartialTransferPrecondition to hold");
53799ef9eebSMatthias Springer
53899ef9eebSMatthias Springer auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
53999ef9eebSMatthias Springer auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
54099ef9eebSMatthias Springer
54199ef9eebSMatthias Springer if (!(xferReadOp || xferWriteOp))
54299ef9eebSMatthias Springer return failure();
5437c38fd60SJacques Pienaar if (xferWriteOp && xferWriteOp.getMask())
54499ef9eebSMatthias Springer return failure();
5457c38fd60SJacques Pienaar if (xferReadOp && xferReadOp.getMask())
54699ef9eebSMatthias Springer return failure();
54799ef9eebSMatthias Springer }
54899ef9eebSMatthias Springer
54999ef9eebSMatthias Springer RewriterBase::InsertionGuard guard(b);
55099ef9eebSMatthias Springer b.setInsertionPoint(xferOp);
55199ef9eebSMatthias Springer Value inBoundsCond = createInBoundsCond(
55299ef9eebSMatthias Springer b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
55399ef9eebSMatthias Springer if (!inBoundsCond)
55499ef9eebSMatthias Springer return failure();
55599ef9eebSMatthias Springer
55699ef9eebSMatthias Springer // Top of the function `alloc` for transient storage.
55799ef9eebSMatthias Springer Value alloc;
55899ef9eebSMatthias Springer {
55999ef9eebSMatthias Springer RewriterBase::InsertionGuard guard(b);
5603c3810e7SNicolas Vasilache Operation *scope = getAutomaticAllocationScope(xferOp);
5613c3810e7SNicolas Vasilache assert(scope->getNumRegions() == 1 &&
5623c3810e7SNicolas Vasilache "AutomaticAllocationScope with >1 regions");
5633c3810e7SNicolas Vasilache b.setInsertionPointToStart(&scope->getRegion(0).front());
56499ef9eebSMatthias Springer auto shape = xferOp.getVectorType().getShape();
56599ef9eebSMatthias Springer Type elementType = xferOp.getVectorType().getElementType();
5663c3810e7SNicolas Vasilache alloc = b.create<memref::AllocaOp>(scope->getLoc(),
56799ef9eebSMatthias Springer MemRefType::get(shape, elementType),
56899ef9eebSMatthias Springer ValueRange{}, b.getI64IntegerAttr(32));
56999ef9eebSMatthias Springer }
57099ef9eebSMatthias Springer
57199ef9eebSMatthias Springer MemRefType compatibleMemRefType =
57299ef9eebSMatthias Springer getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
57399ef9eebSMatthias Springer alloc.getType().cast<MemRefType>());
57499ef9eebSMatthias Springer if (!compatibleMemRefType)
57599ef9eebSMatthias Springer return failure();
57699ef9eebSMatthias Springer
57799ef9eebSMatthias Springer SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
57899ef9eebSMatthias Springer b.getIndexType());
57999ef9eebSMatthias Springer returnTypes[0] = compatibleMemRefType;
58099ef9eebSMatthias Springer
58199ef9eebSMatthias Springer if (auto xferReadOp =
58299ef9eebSMatthias Springer dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
58399ef9eebSMatthias Springer // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
58499ef9eebSMatthias Springer scf::IfOp fullPartialIfOp =
58599ef9eebSMatthias Springer options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
58699ef9eebSMatthias Springer ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
58799ef9eebSMatthias Springer inBoundsCond,
58899ef9eebSMatthias Springer compatibleMemRefType, alloc)
58999ef9eebSMatthias Springer : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
59099ef9eebSMatthias Springer inBoundsCond, compatibleMemRefType,
59199ef9eebSMatthias Springer alloc);
59299ef9eebSMatthias Springer if (ifOp)
59399ef9eebSMatthias Springer *ifOp = fullPartialIfOp;
59499ef9eebSMatthias Springer
59599ef9eebSMatthias Springer // Set existing read op to in-bounds, it always reads from a full buffer.
59699ef9eebSMatthias Springer for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
59799ef9eebSMatthias Springer xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
59899ef9eebSMatthias Springer
59975044e9bSJacques Pienaar xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
60099ef9eebSMatthias Springer
60199ef9eebSMatthias Springer return success();
60299ef9eebSMatthias Springer }
60399ef9eebSMatthias Springer
60499ef9eebSMatthias Springer auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
60599ef9eebSMatthias Springer
60699ef9eebSMatthias Springer // Decide which location to write the entire vector to.
60799ef9eebSMatthias Springer auto memrefAndIndices = getLocationToWriteFullVec(
60899ef9eebSMatthias Springer b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
60999ef9eebSMatthias Springer
61099ef9eebSMatthias Springer // Do an in bounds write to either the output or the extra allocated buffer.
61199ef9eebSMatthias Springer // The operation is cloned to prevent deleting information needed for the
61299ef9eebSMatthias Springer // later IR creation.
61399ef9eebSMatthias Springer BlockAndValueMapping mapping;
6147c38fd60SJacques Pienaar mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
6157c38fd60SJacques Pienaar mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
61699ef9eebSMatthias Springer auto *clone = b.clone(*xferWriteOp, mapping);
61799ef9eebSMatthias Springer clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
61899ef9eebSMatthias Springer
61999ef9eebSMatthias Springer // Create a potential copy from the allocated buffer to the final output in
62099ef9eebSMatthias Springer // the slow path case.
62199ef9eebSMatthias Springer if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
62299ef9eebSMatthias Springer createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
62399ef9eebSMatthias Springer else
62499ef9eebSMatthias Springer createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
62599ef9eebSMatthias Springer
62699ef9eebSMatthias Springer xferOp->erase();
62799ef9eebSMatthias Springer
62899ef9eebSMatthias Springer return success();
62999ef9eebSMatthias Springer }
63099ef9eebSMatthias Springer
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const63199ef9eebSMatthias Springer LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
63299ef9eebSMatthias Springer Operation *op, PatternRewriter &rewriter) const {
63399ef9eebSMatthias Springer auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
63499ef9eebSMatthias Springer if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
63599ef9eebSMatthias Springer failed(filter(xferOp)))
63699ef9eebSMatthias Springer return failure();
63799ef9eebSMatthias Springer rewriter.startRootUpdate(xferOp);
63899ef9eebSMatthias Springer if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
63999ef9eebSMatthias Springer rewriter.finalizeRootUpdate(xferOp);
64099ef9eebSMatthias Springer return success();
64199ef9eebSMatthias Springer }
64299ef9eebSMatthias Springer rewriter.cancelRootUpdate(xferOp);
64399ef9eebSMatthias Springer return failure();
64499ef9eebSMatthias Springer }
645