//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

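/// Return the constant index value of `v` if it can be determined statically,
/// i.e. if `v` is produced by an arith.constant of index type or by an
/// affine.apply whose map folds to a single constant; return None otherwise.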
static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
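// Returns a null Value when `v <= ub` is statically provable, so that callers
// can elide the comparison entirely.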
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds. Returns a null Value if the transfer is statically known to
/// be in-bounds in every dimension.
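/// For instance, for a vector.transfer_read of vector<4x8xf32> from
/// %A : memref<?x8xf32> at indices [%i, %j] with in_bounds = [false, true],
/// the emitted condition resembles (a sketch, modulo folding):
/// ```
///    %sum = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %d0 = memref.dim %A, %c0 : memref<?x8xf32>
///    %cond = arith.cmpi sle, %sum, %d0 : index
/// ```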
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer holding one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must be
///  equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp; this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
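/// For example, for `memref<4x8xf32>` and `memref<7x8xf32>` (both with the
/// identity layout), the result has shape `?x8`, strides `[8, 1]` and offset
/// `0`, since only the leading dimension disagrees.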
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
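/// The returned pair is (copy source, copy destination): for a read, the
/// source is a subview of `xferOp.source()` and the destination a subview of
/// `alloc`; for a write, the roles are swapped. Each common size is clamped as
/// `affine_min(%dimMemRef - %index, %dimAlloc)` so no dimension overflows.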
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = memref.subview %view [...][...][...]
///      %4 = memref.subview %alloc [0, 0] [...] [...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
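/// In both branches, the yielded values are the compatible view followed by
/// `getTransferRank()` indices: the original transfer indices on the fastpath
/// and zeros (indexing into `alloc`) on the slowpath.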
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %alloc :
///        memref<...> to memref<vector<...>>
///      memref.store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %2, ... : compatibleMemRefType, index, index
///   }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = memref.subview %alloc [...][...][...]
///      %4 = memref.subview %view [0, 0][...][...]
///      memref.copy %3, %4
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %0 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
///      %1 = memref.load %0[] : memref<vector<...>>
///      vector.transfer_write %1, %view[...] : memref<A...>, vector<...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
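/// Return the closest surrounding op with the AutomaticAllocationScope trait,
/// walking past any immediately enclosing scf.for/affine.for loops so that
/// transient buffers are not alloca'ed inside a loop body.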
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer holding one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                               true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-function `alloca`'ed buffer holding one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of `xferOp.source()` and the rank of `xferOp.vector()` must be
///  equal. This will be relaxed in the future but requires rank-reducing
///  subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Create a top-of-function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}