1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include <type_traits>
14
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/Passes.h"
28
29 using namespace mlir;
30 using vector::TransferReadOp;
31 using vector::TransferWriteOp;
32
33 namespace {
34
35 /// Attribute name used for labeling transfer ops during progressive lowering.
36 static const char kPassLabel[] = "__vector_to_scf_lowering__";
37
38 /// Patterns that inherit from this struct have access to
39 /// VectorTransferToSCFOptions.
40 template <typename OpTy>
41 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorToSCFPattern__anon4d9edda10111::VectorToSCFPattern42 explicit VectorToSCFPattern(MLIRContext *context,
43 VectorTransferToSCFOptions opt)
44 : OpRewritePattern<OpTy>(context), options(opt) {}
45
46 VectorTransferToSCFOptions options;
47 };
48
49 /// Given a vector transfer op, calculate which dimension of the `source`
50 /// memref should be unpacked in the next application of TransferOpConversion.
51 /// A return value of None indicates a broadcast.
52 template <typename OpTy>
unpackedDim(OpTy xferOp)53 static Optional<int64_t> unpackedDim(OpTy xferOp) {
54 // TODO: support 0-d corner case.
55 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
56 auto map = xferOp.getPermutationMap();
57 if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
58 return expr.getPosition();
59 }
60 assert(xferOp.isBroadcastDim(0) &&
61 "Expected AffineDimExpr or AffineConstantExpr");
62 return None;
63 }
64
65 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
66 /// map is identical to the current permutation map, but the first result is
67 /// omitted.
68 template <typename OpTy>
unpackedPermutationMap(OpBuilder & b,OpTy xferOp)69 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70 // TODO: support 0-d corner case.
71 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
72 auto map = xferOp.getPermutationMap();
73 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
74 b.getContext());
75 }
76
77 /// Calculate the indices for the new vector transfer op.
78 ///
79 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
80 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
81 /// ^^^^^^
82 /// `iv` is the iteration variable of the (new) surrounding loop.
83 template <typename OpTy>
getXferIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & indices)84 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
85 SmallVector<Value, 8> &indices) {
86 typename OpTy::Adaptor adaptor(xferOp);
87 // Corresponding memref dim of the vector dim that is unpacked.
88 auto dim = unpackedDim(xferOp);
89 auto prevIndices = adaptor.getIndices();
90 indices.append(prevIndices.begin(), prevIndices.end());
91
92 Location loc = xferOp.getLoc();
93 bool isBroadcast = !dim.has_value();
94 if (!isBroadcast) {
95 AffineExpr d0, d1;
96 bindDims(xferOp.getContext(), d0, d1);
97 Value offset = adaptor.getIndices()[dim.value()];
98 indices[dim.value()] =
99 makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
100 }
101 }
102
maybeYieldValue(OpBuilder & b,Location loc,bool hasRetVal,Value value)103 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
104 Value value) {
105 if (hasRetVal) {
106 assert(value && "Expected non-empty value");
107 b.create<scf::YieldOp>(loc, value);
108 } else {
109 b.create<scf::YieldOp>(loc);
110 }
111 }
112
113 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
114 /// is set to true. No such check is generated under following circumstances:
115 /// * xferOp does not have a mask.
116 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
117 /// computed and attached to the new transfer op in the pattern.)
118 /// * The to-be-unpacked dim of xferOp is a broadcast.
119 template <typename OpTy>
generateMaskCheck(OpBuilder & b,OpTy xferOp,Value iv)120 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
121 if (!xferOp.getMask())
122 return Value();
123 if (xferOp.getMaskType().getRank() != 1)
124 return Value();
125 if (xferOp.isBroadcastDim(0))
126 return Value();
127
128 Location loc = xferOp.getLoc();
129 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
130 }
131
/// Helper function TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more complex
/// check also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    // Compare the memref dim size against the access index (base index + iv):
    // in-bounds iff dimSize > base + iv.
    Value memrefDim =
        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    // AND with the in-bounds condition if both checks are present.
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            // No out-of-bounds handler: yield nothing.
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}
211
212 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
213 /// a return value. Consequently, this function does not have a return value.
214 template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,function_ref<void (OpBuilder &,Location)> inBoundsCase,function_ref<void (OpBuilder &,Location)> outOfBoundsCase=nullptr)215 static void generateInBoundsCheck(
216 OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
217 function_ref<void(OpBuilder &, Location)> inBoundsCase,
218 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
219 generateInBoundsCheck(
220 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
221 /*inBoundsCase=*/
222 [&](OpBuilder &b, Location loc) {
223 inBoundsCase(b, loc);
224 return Value();
225 },
226 /*outOfBoundsCase=*/
227 [&](OpBuilder &b, Location loc) {
228 if (outOfBoundsCase)
229 outOfBoundsCase(b, loc);
230 return Value();
231 });
232 }
233
234 /// Given an ArrayAttr, return a copy where the first element is dropped.
dropFirstElem(OpBuilder & b,ArrayAttr attr)235 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
236 if (!attr)
237 return attr;
238 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
239 }
240
241 /// Add the pass label to a vector transfer op if its rank is not the target
242 /// rank.
243 template <typename OpTy>
maybeApplyPassLabel(OpBuilder & b,OpTy newXferOp,unsigned targetRank)244 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
245 unsigned targetRank) {
246 if (newXferOp.getVectorType().getRank() > targetRank)
247 newXferOp->setAttr(kPassLabel, b.getUnitAttr());
248 }
249
250 /// Return true if this transfer op operates on a source tensor.
251 template <typename OpTy>
isTensorOp(OpTy xferOp)252 static bool isTensorOp(OpTy xferOp) {
253 if (xferOp.getShapedType().template isa<RankedTensorType>()) {
254 if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
255 // TransferWriteOps on tensors have a result.
256 assert(xferOp->getNumResults() > 0);
257 }
258 return true;
259 }
260 return false;
261 }
262
263 namespace lowering_n_d {
264
/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  /// Alloca holding the transferred vector data.
  Value dataBuffer;
  /// Value loaded back from the mask alloca; null if the op has no mask.
  Value maskBuffer;
};
270
271 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
getAutomaticAllocationScope(Operation * op)272 static Operation *getAutomaticAllocationScope(Operation *op) {
273 Operation *scope =
274 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
275 assert(scope && "Expected op to be inside automatic allocation scope");
276 return scope;
277 }
278
/// Allocate temporary buffers for data (vector) and mask (if present).
/// The allocas are created at the start of the closest enclosing
/// AutomaticAllocationScope region. If the op has a mask, the mask is stored
/// into its buffer right before `xferOp` and immediately re-loaded, so that
/// the op's mask operand can later be traced back to the buffer (see
/// getMaskBuffer).
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  // Hoist the allocas to the beginning of the allocation scope.
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  // 0-d memref wrapping the full vector type.
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    // Store/load the mask at the transfer op itself (not in the allocation
    // scope), so the mask operand's defining op is a LoadOp from the buffer.
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}
303
304 /// Given a MemRefType with VectorType element type, unpack one dimension from
305 /// the VectorType into the MemRefType.
306 ///
307 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
unpackOneDim(MemRefType type)308 static MemRefType unpackOneDim(MemRefType type) {
309 auto vectorType = type.getElementType().dyn_cast<VectorType>();
310 auto memrefShape = type.getShape();
311 SmallVector<int64_t, 8> newMemrefShape;
312 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
313 newMemrefShape.push_back(vectorType.getDimSize(0));
314 return MemRefType::get(newMemrefShape,
315 VectorType::get(vectorType.getShape().drop_front(),
316 vectorType.getElementType()));
317 }
318
319 /// Given a transfer op, find the memref from which the mask is loaded. This
320 /// is similar to Strategy<TransferWriteOp>::getBuffer.
321 template <typename OpTy>
getMaskBuffer(OpTy xferOp)322 static Value getMaskBuffer(OpTy xferOp) {
323 assert(xferOp.getMask() && "Expected that transfer op has mask");
324 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
325 assert(loadOp && "Expected transfer op mask produced by LoadOp");
326 return loadOp.getMemRef();
327 }
328
/// Codegen strategy, depending on the operation. Specialized below for
/// TransferReadOp and TransferWriteOp.
template <typename OpTy>
struct Strategy;
332
/// Code strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    // Store into the buffer at the old indices, extended by `iv` for the
    // newly unpacked dimension.
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // New (N-1)-d transfer; mask is intentionally left empty here.
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getSource(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    // Splat the padding value into a whole (N-1)-d vector.
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    // TransferReadOp on memrefs carries no loop state; nothing to return.
    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
445
/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    // Load from the buffer at the old indices, extended by `iv` for the
    // newly unpacked dimension.
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // On tensors, each iteration writes into the tensor produced by the
    // previous iteration (carried in `loopState`).
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  /// For writes this is a no-op; on tensors, forward the unchanged tensor.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      // Replace the op with the loop's yielded tensor.
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};
528
529 template <typename OpTy>
checkPrepareXferOp(OpTy xferOp,VectorTransferToSCFOptions options)530 LogicalResult checkPrepareXferOp(OpTy xferOp,
531 VectorTransferToSCFOptions options) {
532 if (xferOp->hasAttr(kPassLabel))
533 return failure();
534 if (xferOp.getVectorType().getRank() <= options.targetRank)
535 return failure();
536 if (isTensorOp(xferOp) && !options.lowerTensors)
537 return failure();
538 // Transfer ops that modify the element type are not supported atm.
539 if (xferOp.getVectorType().getElementType() !=
540 xferOp.getShapedType().getElementType())
541 return failure();
542 return success();
543 }
544
545 /// Prepare a TransferReadOp for progressive lowering.
546 ///
547 /// 1. Allocate a temporary buffer.
548 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
549 /// 3. Store the result of the TransferReadOp into the temporary buffer.
550 /// 4. Load the result from the temporary buffer and replace all uses of the
551 /// original TransferReadOp with this load.
552 ///
553 /// E.g.:
554 /// ```
555 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
556 /// : vector<5x4xf32>, memref<?x?x?xf32>
557 /// ```
558 /// is rewritten to:
559 /// ```
560 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
561 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
562 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
563 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
564 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
565 /// ```
566 ///
567 /// Note: A second temporary buffer may be allocated for the `mask` operand.
568 struct PrepareTransferReadConversion
569 : public VectorToSCFPattern<TransferReadOp> {
570 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
571
matchAndRewrite__anon4d9edda10111::lowering_n_d::PrepareTransferReadConversion572 LogicalResult matchAndRewrite(TransferReadOp xferOp,
573 PatternRewriter &rewriter) const override {
574 if (checkPrepareXferOp(xferOp, options).failed())
575 return failure();
576
577 auto buffers = allocBuffers(rewriter, xferOp);
578 auto *newXfer = rewriter.clone(*xferOp.getOperation());
579 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
580 if (xferOp.getMask()) {
581 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
582 buffers.maskBuffer);
583 }
584
585 Location loc = xferOp.getLoc();
586 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
587 buffers.dataBuffer);
588 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
589
590 return success();
591 }
592 };
593
/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    // Route the written vector through the buffer, then rewrite the op
    // in place to consume the loaded value and carry the pass label.
    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.getVectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      // Redirect the mask operand to the mask buffer load.
      rewriter.updateRootInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};
645
/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // Only ops labeled by a Prepare*Conversion pattern are handled here.
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step: iterate over the unpacked dimension's extent.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op (the in-bounds check already evaluated
                // the mask bit).
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                // On tensors, the new transfer op's result is the loop-carried
                // tensor value.
                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
781
782 } // namespace lowering_n_d
783
784 namespace lowering_n_d_unrolled {
785
786 /// If the original transfer op has a mask, compute the mask of the new transfer
787 /// op (for the current iteration `i`) and assign it.
788 template <typename OpTy>
maybeAssignMask(OpBuilder & b,OpTy xferOp,OpTy newXferOp,int64_t i)789 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
790 int64_t i) {
791 if (!xferOp.getMask())
792 return;
793
794 if (xferOp.isBroadcastDim(0)) {
795 // To-be-unpacked dimension is a broadcast, which does not have a
796 // corresponding mask dimension. Mask attribute remains unchanged.
797 newXferOp.getMaskMutable().assign(xferOp.getMask());
798 return;
799 }
800
801 if (xferOp.getMaskType().getRank() > 1) {
802 // Unpack one dimension of the mask.
803 OpBuilder::InsertionGuard guard(b);
804 b.setInsertionPoint(newXferOp); // Insert load before newXfer.
805
806 llvm::SmallVector<int64_t, 1> indices({i});
807 Location loc = xferOp.getLoc();
808 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
809 newXferOp.getMaskMutable().assign(newMask);
810 }
811
812 // If we end up here: The mask of the old transfer op is 1D and the unpacked
813 // dim is not a broadcast, so no mask is needed on the new transfer op.
814 // `generateInBoundsCheck` will have evaluated the mask already.
815 }
816
817 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
818 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
819 /// memref buffer is allocated and the SCF loop is fully unrolled.
820 ///
821 /// ```
822 /// E.g.:
823 /// ```
824 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
825 /// : memref<?x?x?xf32>, vector<5x4xf32>
826 /// ```
827 /// is rewritten to IR such as (simplified):
828 /// ```
829 /// %v_init = splat %padding : vector<5x4xf32>
830 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
831 /// : memref<?x?x?xf32>, vector<4xf32>
832 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
833 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
834 /// : memref<?x?x?xf32>, vector<4xf32>
835 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
836 /// ...
837 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
838 /// : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
840 /// ```
841 ///
842 /// Note: As an optimization, if the result of the original TransferReadOp
843 /// was directly inserted into another vector, no new %v_init vector is created.
844 /// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted. If the read's single user is a vector.insert, reuse that
  /// user's destination vector; otherwise splat a fresh vector with the
  /// padding value.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation; otherwise return a null op.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, append that operation's indices to `indices`;
  /// otherwise leave `indices` unchanged.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      for (Attribute attr : insertOp.getPosition())
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Only unpack while the vector rank exceeds the configured target rank.
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    // The new (unpacked) transfer ops read vectors of one rank less.
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.getSource(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};
954
955 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
956 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
957 /// memref buffer is allocated and the SCF loop is fully unrolled.
958 ///
959 /// ```
960 /// E.g.:
961 /// ```
962 /// vector.transfer_write %vec, %A[%a, %b, %c]
963 /// : vector<5x4xf32>, memref<?x?x?xf32>
964 /// ```
965 /// is rewritten to IR such as (simplified):
966 /// ```
967 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
968 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
969 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
970 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
971 /// ...
972 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
973 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
974 /// ```
975 ///
976 /// Note: As an optimization, if the vector of the original TransferWriteOp
977 /// was directly extracted from another vector via an ExtractOp `a`, extract
978 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
979 /// doing so, `a` may become dead, and the number of ExtractOps generated during
980 /// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  /// If the written vector itself comes from an ExtractOp, extract from that
  /// op's source instead, so the intermediate ExtractOp may become dead.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it;
  /// otherwise return a null op.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, append its
  /// indices to `indices`; otherwise leave `indices` unchanged.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      for (Attribute attr : extractOp.getPosition())
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Only unpack while the vector rank exceeds the configured target rank.
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.getSource(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      // For tensor sources each iteration yields the updated tensor; memref
      // writes happen in place and yield nothing.
      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Out of bounds: nothing is written; the source is unchanged.
            return isTensorOp(xferOp) ? source : Value();
          });

      // Thread the updated tensor into the next iteration.
      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
1082
1083 } // namespace lowering_n_d_unrolled
1084
1085 namespace lowering_1_d {
1086
1087 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1088 /// part of TransferOp1dConversion. Return the memref dimension on which
1089 /// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  // A dim expression identifies the memref dimension the transfer iterates
  // over: add the loop induction variable `iv` to that dimension's base
  // index (folded via an affine.apply where possible).
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  // Otherwise the single map result must be a constant, i.e., the transfer
  // dimension is a broadcast; the memref indices are used unmodified.
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}
1115
1116 /// Codegen strategy for TransferOp1dConversion, depending on the
1117 /// operation.
1118 template <typename OpTy>
1119 struct Strategy1d;
1120
1121 /// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  /// Emit one loop iteration of the scalar 1D read: load a single element
  /// from the memref (if in bounds) and insert it into the vector that is
  /// carried through the loop as `loopState[0]`.
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  /// Initial loop state: a vector splatted with the padding value, so that
  /// out-of-bounds lanes keep the padding.
  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};
1153
1154 /// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  /// Emit one loop iteration of the scalar 1D write: extract a single
  /// element from the vector and store it to the memref (if in bounds).
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  /// Writes carry no loop state: the memref is mutated in place.
  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};
1178
1179 /// Return true if the last dimension of the MemRefType has unit stride.
isLastMemrefDimUnitStride(MemRefType type)1180 static bool isLastMemrefDimUnitStride(MemRefType type) {
1181 int64_t offset;
1182 SmallVector<int64_t, 4> strides;
1183 auto successStrides = getStridesAndOffset(type, strides, offset);
1184 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1185 }
1186
1187 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1188 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1189 /// vector load/stores due to non-unit strides or broadcasts:
1190 ///
1191 /// * Transfer dimension is not the last memref dimension
1192 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1193 /// * Memref has a layout map with non-unit stride on the last dimension
1194 ///
1195 /// This pattern generates IR as follows:
1196 ///
1197 /// 1. Generate a for loop iterating over each vector element.
1198 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
1199 /// depending on OpTy.
1200 ///
1201 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1202 /// can be generated instead of TransferOp1dConversion. Add such a pattern
1203 /// to ConvertVectorToLLVM.
1204 ///
1205 /// E.g.:
1206 /// ```
1207 /// vector.transfer_write %vec, %A[%a, %b]
1208 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1209 /// : vector<9xf32>, memref<?x?xf32>
1210 /// ```
1211 /// Is rewritten to approximately the following pseudo-IR:
1212 /// ```
1213 /// for i = 0 to 9 {
1214 /// %t = vector.extractelement %vec[i] : vector<9xf32>
1215 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1216 /// }
1217 /// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    // Only memref sources are supported here (no tensors).
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    // Contiguous minor-identity transfers lower to vector load/stores
    // elsewhere; this scalar lowering is only for the remaining cases.
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Reads carry the partially-built vector as loop state; writes carry
    // none (Strategy1d<TransferWriteOp>::initialLoopState returns Value()).
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop: one scalar load/store per vector element.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};
1256
1257 } // namespace lowering_1_d
1258 } // namespace
1259
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    // Fully unrolled lowering: no buffer allocation, no scf.for loops.
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    // Progressive lowering via scf.for loops and memref buffers.
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  // When lowering all the way down to rank 1, also register the scalar
  // load/store fallback for 1D transfers that cannot become vector
  // load/stores (non-unit stride, broadcast, non-minor dim).
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}
1280
1281 namespace {
1282
struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  /// Mirror the given lowering options into the pass's command-line-exposed
  /// fields (declared in the generated ConvertVectorToSCFBase).
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnOperation() override {
    // Rebuild the options struct from the pass fields, so both construction
    // paths (C++ API and command line) behave identically.
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(&getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getOperation(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
1314
1315 } // namespace
1316
1317 std::unique_ptr<Pass>
createConvertVectorToSCFPass(const VectorTransferToSCFOptions & options)1318 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1319 return std::make_unique<ConvertVectorToSCFPass>(options);
1320 }
1321