//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

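/// Return the constant index value of `v` when it is defined by an
/// arith.constant of index type, or by an affine.apply whose map is a single
/// constant expression; otherwise return None.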
static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
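// A null Value is returned when both `v` and `ub` are constants and `v` is
// statically known to be smaller than `ub`, i.e. when the check is known to
// hold and no comparison needs to be emitted.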
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
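///
/// For example (illustrative pseudo-IR), a 1-D read of `vector<4xf32>` from
/// `%A : memref<?xf32>` at index `%i`, with no dimension statically known to
/// be in-bounds, produces a condition equivalent to:
/// ```
///    %end = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %dim = memref.dim %A, %c0 : memref<?xf32>
///    %cond = arith.cmpi sle, %end, %dim : index
/// ```
/// Checks for multiple dimensions are conjoined with arith.andi.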
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims that require a runtime in-bounds check.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must
///  be equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: support permutation maps that are not minor identities.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an IfOp; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
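///
/// For example (illustrative), combining the shapes `4x8` and `4x?` yields the
/// shape `4x?`; a static stride or offset is kept only when it is equal in
/// both types and becomes dynamic otherwise.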
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
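///
/// For example (illustrative), for a transfer at indices `[%i, %j]` with a
/// `vector<4x8xf32>`-shaped buffer `alloc`, the computed intersection sizes
/// are `affine_min(dim(%src, 0) - %i, 4)` and `affine_min(dim(%src, 1) - %j,
/// 8)`, i.e. the portion of the vector that still fits in the source view.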
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}


/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = memref.subview %view [...][...][...]
///      %4 = memref.subview %alloc [0, 0] [...] [...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
        // Take a partial subview of the memref; this guarantees that no
        // dimension overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %A[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %alloc :
///        memref<...> to memref<vector<...>>
///      memref.store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the results of the produced scf::IfOp: the memref to write to and
/// the indices at which to write.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.indices().begin(),
                                  xferOp.indices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = memref.subview %alloc [...][...][...]
///      %4 = memref.subview %view [0, 0][...][...]
///      memref.copy %3, %4
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %0 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
///      %2 = memref.load %0[] : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc, b.create<vector::TypeCastOp>(
                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
    mapping.map(xferOp.vector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope
// trait.
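/// Return the closest enclosing op that has the AutomaticAllocationScope
/// trait (typically the enclosing function); the transient vector buffer is
/// alloca'ed at the top of its region.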
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true, true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or memref.copy.
///    }
/// ```
/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must
///  be equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.mask())
      return failure();
    if (xferReadOp && xferReadOp.mask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
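    // Alloca a transient buffer of one vector at the top of the allocation
    // scope, with a 32-byte alignment hint.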
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}