//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
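  // When both operands are known constants and v < ub, the `sle` comparison is
  // trivially true; return a null Value so no check needs to be emitted.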
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function and large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp; this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
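  // For each transferred dimension, the intersection size is
  // min(dimMemRef - index, dimAlloc): what remains of the source view past the
  // transfer indices, clamped to the size of the transient buffer.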
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
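        // Slow path: clone the original (out-of-bounds, padded) transfer_read
        // and store its result vector into the transient buffer `alloc`.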
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
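            // Slow path: yield the transient buffer, indexed at [0 ... 0], as
            // the location to write the full vector to.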
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      memref.copy %3, %4
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
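    // Copy the intersection of the transient buffer with the original view
    // back into the destination view.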
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
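    // Load the full vector from the transient buffer and replay the original
    // (potentially out-of-bounds) transfer_write against the destination view.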
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function and large
/// enough to hold one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function and large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

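  // An `in_bounds` attribute marking every transfer dimension as in-bounds; it
  // is attached to whichever transfer is guaranteed in-bounds below.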
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
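    // Alloca a buffer that holds exactly one vector, with an explicit
    // alignment of 32.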
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

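  // The scf.if returns the memref to transfer from/into plus one index per
  // transfer dimension.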
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Rewire the existing read op to the results of the scf.if and set it to
    // in-bounds: it now always reads from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
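  // The split modifies `xferOp` in place; bracket it with start/finalize/
  // cancelRootUpdate so the rewriter is notified of the in-place mutation.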
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}