//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

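/// If `v` is defined by an `arith.constant` of index type or by an
/// `affine.apply` whose map folds to a single constant, return that constant;
/// otherwise return None. Illustrative (hypothetical values):
/// ```
///   %c4 = arith.constant 4 : index  // extractConstantIndex(%c4) == 4
/// ```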
static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
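// For example, `createFoldedSLE(b, %c3, %c8)` returns a null Value (the
// comparison is trivially true) instead of materializing
// `arith.cmpi sle, %c3, %c8`.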
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
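/// For example, for a transfer of vector<8xf32> from `%A : memref<?xf32>` at
/// index `%i` (a sketch, assuming no dimension is statically in-bounds), this
/// builds:
/// ```
///   %end = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
///   %dim = memref.dim %A, %c0 : memref<?xf32>
///   %cond = arith.cmpi sle, %end, %dim : index
/// ```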
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer, alloca'ed at the top of the function, that holds
/// one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must be
///  equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an IfOp; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
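/// For example (a sketch; the types are not cast-compatible because of the
/// static 8 vs. 16 mismatch):
///   aT  = memref<4x8xf32,  offset: 0, strides: [8, 1]>
///   bT  = memref<4x16xf32, offset: 0, strides: [16, 1]>
///   res = memref<4x?xf32,  offset: 0, strides: [?, 1]>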
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
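/// For a transfer of vector<4xf32> from `%A : memref<?xf32>` at index `%i`
/// with `%alloc : memref<4xf32>`, the clamped size along the single dimension
/// is built as (a sketch of the generated IR):
/// ```
///   %dA = memref.dim %A, %c0 : memref<?xf32>
///   %da = memref.dim %alloc, %c0 : memref<4xf32>
///   %sz = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dA, %i, %da)
/// ```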
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
        // Take a partial subview of the memref; this guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      memref.store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the results of the produced scf::IfOp.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.indices().begin(),
                                  xferOp.indices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = memref.load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc, b.create<vector::TypeCastOp>(
                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
    mapping.map(xferOp.vector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer, alloca'ed at the top of the function, that holds
/// one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                               true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer, alloca'ed at the top of the function, that holds
/// one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must be
///  equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't accidentally used further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.mask())
      return failure();
    if (xferReadOp && xferReadOp.mask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // `alloc` is allocated at the top of the function for transient storage.
  Value alloc;
  {
    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
    RewriterBase::InsertionGuard guard(b);
    b.setInsertionPointToStart(&funcOp.getRegion().front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
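    // For example, a transfer of vector<4x8xf32> gets a 32-byte-aligned
    // alloca of type memref<4x8xf32> (the shape mirrors the transfer's
    // vector type).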
    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;
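  // For a 2-D transfer this is, e.g., {compatibleMemRefType, index, index}.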

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

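/// Apply `splitFullAndPartialTransfer` as a greedy rewrite pattern. Typical
/// usage (a sketch; exact registration may differ in a given pipeline):
/// ```
///   RewritePatternSet patterns(ctx);
///   patterns.add<vector::VectorTransferFullPartialRewriter>(ctx, options);
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
/// ```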
LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}