//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

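/// If `v` is statically known to be a constant index value (either an
/// arith.constant or an affine.apply with a single constant result), return
/// it; otherwise return None.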
static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
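// Returns a null Value when both operands fold to constants with `v` strictly
// less than `ub`, i.e. the `sle` comparison trivially holds and no runtime
// check needs to be emitted.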
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
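/// The condition is the conjunction, over every dimension that is not already
/// known to be in-bounds, of checks of the form
/// `index + vector_size <= memref_size`. Returns a null Value when no runtime
/// check is needed.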
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// Both vector.transfer_read and vector.transfer_write are supported.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an IfOp; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
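///
/// For example, `memref<4x8xf32>` and `memref<4x16xf32>` are combined into
/// `memref<4x?xf32>` with a dynamic stride on the outer dimension.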
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
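/// Returns the (copy source, copy destination) subview pair: for a
/// transfer_read the source is the original view and the destination is
/// `alloc`; for a transfer_write the roles are swapped.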
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are two conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///     }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                                    true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

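  // Compute a memref type that both the original source and the staging
  // alloca can be cast into; bail out if no such type exists.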
  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

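  // The scf.if returns the memref to transfer from/to followed by one index
  // per transfer dimension.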
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}