1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Dialect/Vector/VectorUtils.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "mlir/Transforms/Passes.h"
30 
31 using namespace mlir;
32 using vector::TransferReadOp;
33 using vector::TransferWriteOp;
34 
35 namespace {
36 
37 /// Attribute name used for labeling transfer ops during progressive lowering.
38 static const char kPassLabel[] = "__vector_to_scf_lowering__";
39 
40 /// Patterns that inherit from this struct have access to
41 /// VectorTransferToSCFOptions.
42 template <typename OpTy>
43 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
44   explicit VectorToSCFPattern(MLIRContext *context,
45                               VectorTransferToSCFOptions opt)
46       : OpRewritePattern<OpTy>(context), options(opt) {}
47 
48   VectorTransferToSCFOptions options;
49 };
50 
51 /// Given a vector transfer op, calculate which dimension of the `source`
52 /// memref should be unpacked in the next application of TransferOpConversion.
53 /// A return value of None indicates a broadcast.
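///
/// E.g. (illustrative values, not taken from the surrounding code): for a
/// permutation map `(d0, d1, d2) -> (d2, d1)` the unpacked dim is 2; for
/// `(d0, d1) -> (0, d1)` the first result is the constant 0, i.e. a
/// broadcast, and None is returned.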
54 template <typename OpTy>
55 static Optional<int64_t> unpackedDim(OpTy xferOp) {
56   auto map = xferOp.permutation_map();
57   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
58     return expr.getPosition();
59   }
60   assert(xferOp.isBroadcastDim(0) &&
61          "Expected AffineDimExpr or AffineConstantExpr");
62   return None;
63 }
64 
65 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
66 /// map is identical to the current permutation map, but the first result is
67 /// omitted.
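///
/// E.g. (sketch): `(d0, d1, d2) -> (d2, d0)` becomes `(d0, d1, d2) -> (d0)`.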
68 template <typename OpTy>
69 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70   auto map = xferOp.permutation_map();
71   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
72                         b.getContext());
73 }
74 
75 /// Calculate the indices for the new vector transfer op.
76 ///
77 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
79 ///                                 ^^^^^^
80 ///              `iv` is the iteration variable of the (new) surrounding loop.
81 template <typename OpTy>
82 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
83                            SmallVector<Value, 8> &indices) {
84   typename OpTy::Adaptor adaptor(xferOp);
85   // Corresponding memref dim of the vector dim that is unpacked.
86   auto dim = unpackedDim(xferOp);
87   auto prevIndices = adaptor.indices();
88   indices.append(prevIndices.begin(), prevIndices.end());
89 
90   Location loc = xferOp.getLoc();
91   bool isBroadcast = !dim.hasValue();
92   if (!isBroadcast) {
93     AffineExpr d0, d1;
94     bindDims(xferOp.getContext(), d0, d1);
95     Value offset = adaptor.indices()[dim.getValue()];
96     indices[dim.getValue()] =
97         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
98   }
99 }
100 
101 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
102                             Value value) {
103   if (hasRetVal) {
104     assert(value && "Expected non-empty value");
105     b.create<scf::YieldOp>(loc, value);
106   } else {
107     b.create<scf::YieldOp>(loc);
108   }
109 }
110 
/// Generates a boolean Value that is true if the iv-th element of xferOp's
/// mask is set to true. No such check is generated under the following
/// circumstances:
113 /// * xferOp does not have a mask.
114 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
115 ///   computed and attached to the new transfer op in the pattern.)
116 /// * The to-be-unpacked dim of xferOp is a broadcast.
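///
/// Otherwise, IR along the following lines is generated (a sketch; value
/// names are illustrative only):
/// ```
/// %ivI32 = arith.index_cast %iv : index to i32
/// %cond = vector.extractelement %mask[%ivI32 : i32] : vector<5xi1>
/// ```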
117 template <typename OpTy>
118 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
119   if (!xferOp.mask())
120     return Value();
121   if (xferOp.getMaskType().getRank() != 1)
122     return Value();
123   if (xferOp.isBroadcastDim(0))
124     return Value();
125 
126   Location loc = xferOp.getLoc();
127   Value ivI32 = b.create<arith::IndexCastOp>(
128       loc, IntegerType::get(b.getContext(), 32), iv);
129   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
130 }
131 
/// Helper function for TransferOpConversion and TransferOp1dConversion.
133 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
134 /// specified dimension `dim` with the loop iteration variable `iv`.
135 /// E.g., when unpacking dimension 0 from:
136 /// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
139 /// ```
140 /// An if check similar to this will be generated inside the loop:
141 /// ```
142 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
143 /// if (%a + iv < %d) {
144 ///   (in-bounds case)
145 /// } else {
146 ///   (out-of-bounds case)
147 /// }
148 /// ```
149 ///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
152 ///
153 /// This function variant returns the value returned by `inBoundsCase` or
154 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
155 /// `resultTypes`.
156 template <typename OpTy>
157 static Value generateInBoundsCheck(
158     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
159     TypeRange resultTypes,
160     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
161     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
162   bool hasRetVal = !resultTypes.empty();
163   Value cond; // Condition to be built...
164 
165   // Condition check 1: Access in-bounds?
166   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
167   Location loc = xferOp.getLoc();
168   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
169   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
170     Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
171     AffineExpr d0, d1;
172     bindDims(xferOp.getContext(), d0, d1);
173     Value base = xferOp.indices()[dim.getValue()];
174     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
175     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
176                                     memrefIdx);
177   }
178 
179   // Condition check 2: Masked in?
180   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
181     if (cond)
182       cond = lb.create<arith::AndIOp>(cond, maskCond);
183     else
184       cond = maskCond;
185   }
186 
187   // If the condition is non-empty, generate an SCF::IfOp.
188   if (cond) {
189     auto check = lb.create<scf::IfOp>(
190         resultTypes, cond,
191         /*thenBuilder=*/
192         [&](OpBuilder &b, Location loc) {
193           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
194         },
195         /*elseBuilder=*/
196         [&](OpBuilder &b, Location loc) {
197           if (outOfBoundsCase) {
198             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
199           } else {
200             b.create<scf::YieldOp>(loc);
201           }
202         });
203 
204     return hasRetVal ? check.getResult(0) : Value();
205   }
206 
207   // Condition is empty, no need for an SCF::IfOp.
208   return inBoundsCase(b, loc);
209 }
210 
211 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
212 /// a return value. Consequently, this function does not have a return value.
213 template <typename OpTy>
214 static void generateInBoundsCheck(
215     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
216     function_ref<void(OpBuilder &, Location)> inBoundsCase,
217     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
218   generateInBoundsCheck(
219       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
220       /*inBoundsCase=*/
221       [&](OpBuilder &b, Location loc) {
222         inBoundsCase(b, loc);
223         return Value();
224       },
225       /*outOfBoundsCase=*/
226       [&](OpBuilder &b, Location loc) {
227         if (outOfBoundsCase)
228           outOfBoundsCase(b, loc);
229         return Value();
230       });
231 }
232 
233 /// Given an ArrayAttr, return a copy where the first element is dropped.
234 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
235   if (!attr)
236     return attr;
237   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
238 }
239 
/// Add the pass label to a vector transfer op if its vector rank is still
/// greater than the target rank.
242 template <typename OpTy>
243 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
244                                 unsigned targetRank) {
245   if (newXferOp.getVectorType().getRank() > targetRank)
246     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
247 }
248 
249 /// Return true if this transfer op operates on a source tensor.
250 template <typename OpTy>
251 static bool isTensorOp(OpTy xferOp) {
252   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
253     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
254       // TransferWriteOps on tensors have a result.
255       assert(xferOp->getNumResults() > 0);
256     }
257     return true;
258   }
259   return false;
260 }
261 
262 namespace lowering_n_d {
263 
264 /// Helper data structure for data and mask buffers.
265 struct BufferAllocs {
266   Value dataBuffer;
267   Value maskBuffer;
268 };
269 
270 /// Allocate temporary buffers for data (vector) and mask (if present).
271 /// TODO: Parallelism and threadlocal considerations.
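///
/// The allocas are placed at the beginning of the closest enclosing
/// AutomaticAllocationScope. A sketch of the generated IR for an op with
/// vector type vector<5x4xf32> and a mask (names are illustrative only):
/// ```
/// %buf = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
/// %loadedMask = memref.load %maskBuf[] : memref<vector<5x4xi1>>
/// ```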
272 template <typename OpTy>
273 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
274   Location loc = xferOp.getLoc();
275   OpBuilder::InsertionGuard guard(b);
276   Operation *scope =
277       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
278   assert(scope && "Expected op to be inside automatic allocation scope");
279   b.setInsertionPointToStart(&scope->getRegion(0).front());
280 
281   BufferAllocs result;
282   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
283   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
284 
285   if (xferOp.mask()) {
286     auto maskType = MemRefType::get({}, xferOp.mask().getType());
287     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
288     b.setInsertionPoint(xferOp);
289     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
290     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
291   }
292 
293   return result;
294 }
295 
296 /// Given a MemRefType with VectorType element type, unpack one dimension from
297 /// the VectorType into the MemRefType.
298 ///
299 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
300 static MemRefType unpackOneDim(MemRefType type) {
301   auto vectorType = type.getElementType().dyn_cast<VectorType>();
302   auto memrefShape = type.getShape();
303   SmallVector<int64_t, 8> newMemrefShape;
304   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
305   newMemrefShape.push_back(vectorType.getDimSize(0));
306   return MemRefType::get(newMemrefShape,
307                          VectorType::get(vectorType.getShape().drop_front(),
308                                          vectorType.getElementType()));
309 }
310 
311 /// Given a transfer op, find the memref from which the mask is loaded. This
312 /// is similar to Strategy<TransferWriteOp>::getBuffer.
313 template <typename OpTy>
314 static Value getMaskBuffer(OpTy xferOp) {
315   assert(xferOp.mask() && "Expected that transfer op has mask");
316   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
317   assert(loadOp && "Expected transfer op mask produced by LoadOp");
318   return loadOp.getMemRef();
319 }
320 
321 /// Codegen strategy, depending on the operation.
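///
/// Each specialization below provides the hooks that TransferOpConversion
/// relies on:
/// * getBuffer / getBufferIndices: locate the temporary buffer and the
///   indices into it for the given transfer op.
/// * rewriteOp / handleOutOfBoundsDim: generate the in-bounds and
///   out-of-bounds code for one unpacked dimension.
/// * cleanup: erase or replace the original transfer op.
/// * initialLoopState: the initial iteration argument of the generated
///   scf.for loop, if any.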
322 template <typename OpTy>
323 struct Strategy;
324 
/// Codegen strategy for vector TransferReadOp.
326 template <>
327 struct Strategy<TransferReadOp> {
328   /// Find the StoreOp that is used for writing the current TransferReadOp's
329   /// result to the temporary buffer allocation.
330   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
331     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
332     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
333     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
334     return storeOp;
335   }
336 
337   /// Find the temporary buffer allocation. All labeled TransferReadOps are
338   /// used like this, where %buf is either the buffer allocation or a type cast
339   /// of the buffer allocation:
340   /// ```
341   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
342   /// memref.store %vec, %buf[...] ...
343   /// ```
344   static Value getBuffer(TransferReadOp xferOp) {
345     return getStoreOp(xferOp).getMemRef();
346   }
347 
348   /// Retrieve the indices of the current StoreOp that stores into the buffer.
349   static void getBufferIndices(TransferReadOp xferOp,
350                                SmallVector<Value, 8> &indices) {
351     auto storeOp = getStoreOp(xferOp);
352     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
353     indices.append(prevIndices.begin(), prevIndices.end());
354   }
355 
356   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
357   /// accesses on the to-be-unpacked dimension.
358   ///
359   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
360   ///    variable `iv`.
361   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
362   ///
363   /// E.g.:
364   /// ```
365   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
366   ///     : memref<?x?x?xf32>, vector<4x3xf32>
367   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
368   /// ```
369   /// Is rewritten to:
370   /// ```
371   /// %casted = vector.type_cast %buf
372   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
373   /// for %j = 0 to 4 {
374   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
375   ///       : memref<?x?x?xf32>, vector<3xf32>
376   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
377   /// }
378   /// ```
379   ///
380   /// Note: The loop and type cast are generated in TransferOpConversion.
381   ///       The original TransferReadOp and store op are deleted in `cleanup`.
382   /// Note: The `mask` operand is set in TransferOpConversion.
383   static TransferReadOp rewriteOp(OpBuilder &b,
384                                   VectorTransferToSCFOptions options,
385                                   TransferReadOp xferOp, Value buffer, Value iv,
386                                   ValueRange /*loopState*/) {
387     SmallVector<Value, 8> storeIndices;
388     getBufferIndices(xferOp, storeIndices);
389     storeIndices.push_back(iv);
390 
391     SmallVector<Value, 8> xferIndices;
392     getXferIndices(b, xferOp, iv, xferIndices);
393 
394     Location loc = xferOp.getLoc();
395     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
396     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
397     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
398     auto newXferOp = b.create<vector::TransferReadOp>(
399         loc, vecType, xferOp.source(), xferIndices,
400         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
401         Value(), inBoundsAttr);
402 
403     maybeApplyPassLabel(b, newXferOp, options.targetRank);
404 
405     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
406     return newXferOp;
407   }
408 
409   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
410   /// padding value to the temporary buffer.
411   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
412                                     Value buffer, Value iv,
413                                     ValueRange /*loopState*/) {
414     SmallVector<Value, 8> storeIndices;
415     getBufferIndices(xferOp, storeIndices);
416     storeIndices.push_back(iv);
417 
418     Location loc = xferOp.getLoc();
419     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
420     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
421     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
422     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
423 
424     return Value();
425   }
426 
427   /// Cleanup after rewriting the op.
428   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
429                       scf::ForOp /*forOp*/) {
430     rewriter.eraseOp(getStoreOp(xferOp));
431     rewriter.eraseOp(xferOp);
432   }
433 
434   /// Return the initial loop state for the generated scf.for loop.
435   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
436 };
437 
438 /// Codegen strategy for vector TransferWriteOp.
439 template <>
440 struct Strategy<TransferWriteOp> {
441   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
442   /// used like this, where %buf is either the buffer allocation or a type cast
443   /// of the buffer allocation:
444   /// ```
445   /// %vec = memref.load %buf[...] ...
446   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
447   /// ```
448   static Value getBuffer(TransferWriteOp xferOp) {
449     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
450     assert(loadOp && "Expected transfer op vector produced by LoadOp");
451     return loadOp.getMemRef();
452   }
453 
454   /// Retrieve the indices of the current LoadOp that loads from the buffer.
455   static void getBufferIndices(TransferWriteOp xferOp,
456                                SmallVector<Value, 8> &indices) {
457     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
458     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
459     indices.append(prevIndices.begin(), prevIndices.end());
460   }
461 
462   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
463   /// accesses on the to-be-unpacked dimension.
464   ///
465   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
466   ///    using the loop iteration variable `iv`.
467   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
468   ///    to memory.
469   ///
470   /// Note: For more details, see comments on Strategy<TransferReadOp>.
471   static TransferWriteOp rewriteOp(OpBuilder &b,
472                                    VectorTransferToSCFOptions options,
473                                    TransferWriteOp xferOp, Value buffer,
474                                    Value iv, ValueRange loopState) {
475     SmallVector<Value, 8> loadIndices;
476     getBufferIndices(xferOp, loadIndices);
477     loadIndices.push_back(iv);
478 
479     SmallVector<Value, 8> xferIndices;
480     getXferIndices(b, xferOp, iv, xferIndices);
481 
482     Location loc = xferOp.getLoc();
483     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
484     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
485     auto source = loopState.empty() ? xferOp.source() : loopState[0];
486     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
487     auto newXferOp = b.create<vector::TransferWriteOp>(
488         loc, type, vec, source, xferIndices,
489         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
490         inBoundsAttr);
491 
492     maybeApplyPassLabel(b, newXferOp, options.targetRank);
493 
494     return newXferOp;
495   }
496 
497   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
498   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
499                                     Value buffer, Value iv,
500                                     ValueRange loopState) {
501     return isTensorOp(xferOp) ? loopState[0] : Value();
502   }
503 
504   /// Cleanup after rewriting the op.
505   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
506                       scf::ForOp forOp) {
507     if (isTensorOp(xferOp)) {
508       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
509       rewriter.replaceOp(xferOp, forOp->getResult(0));
510     } else {
511       rewriter.eraseOp(xferOp);
512     }
513   }
514 
515   /// Return the initial loop state for the generated scf.for loop.
516   static Value initialLoopState(TransferWriteOp xferOp) {
517     return isTensorOp(xferOp) ? xferOp.source() : Value();
518   }
519 };
520 
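/// Check whether the given transfer op is eligible for preparation: it must
/// not be labeled yet, its vector rank must be greater than the target rank,
/// tensor sources are only accepted if `lowerTensors` is set, and the element
/// type must not change between the vector and the source.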
521 template <typename OpTy>
522 LogicalResult checkPrepareXferOp(OpTy xferOp,
523                                  VectorTransferToSCFOptions options) {
524   if (xferOp->hasAttr(kPassLabel))
525     return failure();
526   if (xferOp.getVectorType().getRank() <= options.targetRank)
527     return failure();
528   if (isTensorOp(xferOp) && !options.lowerTensors)
529     return failure();
530   // Transfer ops that modify the element type are not supported atm.
531   if (xferOp.getVectorType().getElementType() !=
532       xferOp.getShapedType().getElementType())
533     return failure();
534   return success();
535 }
536 
537 /// Prepare a TransferReadOp for progressive lowering.
538 ///
539 /// 1. Allocate a temporary buffer.
540 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
541 /// 3. Store the result of the TransferReadOp into the temporary buffer.
542 /// 4. Load the result from the temporary buffer and replace all uses of the
543 ///    original TransferReadOp with this load.
544 ///
545 /// E.g.:
546 /// ```
547 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
549 /// ```
550 /// is rewritten to:
551 /// ```
552 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
553 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
555 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
556 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
557 /// ```
558 ///
559 /// Note: A second temporary buffer may be allocated for the `mask` operand.
560 struct PrepareTransferReadConversion
561     : public VectorToSCFPattern<TransferReadOp> {
562   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
563 
564   LogicalResult matchAndRewrite(TransferReadOp xferOp,
565                                 PatternRewriter &rewriter) const override {
566     if (checkPrepareXferOp(xferOp, options).failed())
567       return failure();
568 
569     auto buffers = allocBuffers(rewriter, xferOp);
570     auto *newXfer = rewriter.clone(*xferOp.getOperation());
571     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
572     if (xferOp.mask()) {
573       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
574           buffers.maskBuffer);
575     }
576 
577     Location loc = xferOp.getLoc();
578     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
579                                      buffers.dataBuffer);
580     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
581 
582     return success();
583   }
584 };
585 
586 /// Prepare a TransferWriteOp for progressive lowering.
587 ///
588 /// 1. Allocate a temporary buffer.
589 /// 2. Store the vector into the buffer.
590 /// 3. Load the vector from the buffer again.
591 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
592 ///    marking it eligible for progressive lowering via TransferOpConversion.
593 ///
594 /// E.g.:
595 /// ```
596 /// vector.transfer_write %vec, %A[%a, %b, %c]
597 ///     : vector<5x4xf32>, memref<?x?x?xf32>
598 /// ```
599 /// is rewritten to:
600 /// ```
601 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
602 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
603 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
604 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
605 ///     : vector<5x4xf32>, memref<?x?x?xf32>
606 /// ```
607 ///
608 /// Note: A second temporary buffer may be allocated for the `mask` operand.
609 struct PrepareTransferWriteConversion
610     : public VectorToSCFPattern<TransferWriteOp> {
611   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
612 
613   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
614                                 PatternRewriter &rewriter) const override {
615     if (checkPrepareXferOp(xferOp, options).failed())
616       return failure();
617 
618     Location loc = xferOp.getLoc();
619     auto buffers = allocBuffers(rewriter, xferOp);
620     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
621     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
622     rewriter.updateRootInPlace(xferOp, [&]() {
623       xferOp.vectorMutable().assign(loadedVec);
624       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
625     });
626 
627     if (xferOp.mask()) {
628       rewriter.updateRootInPlace(
629           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
630     }
631 
632     return success();
633   }
634 };
635 
636 /// Progressive lowering of vector transfer ops: Unpack one dimension.
637 ///
638 /// 1. Unpack one dimension from the current buffer type and cast the buffer
639 ///    to that new type. E.g.:
640 ///    ```
641 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
642 ///    vector.transfer_write %vec ...
643 ///    ```
644 ///    The following cast is generated:
645 ///    ```
646 ///    %casted = vector.type_cast %0
647 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
648 ///    ```
649 /// 2. Generate a for loop and rewrite the transfer op according to the
650 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
651 ///    out-of-bounds, generate an if-check and handle both cases separately.
652 /// 3. Clean up according to the corresponding Strategy<OpTy>.
653 ///
654 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
655 /// source (as opposed to a memref source), then each iteration of the generated
656 /// scf.for loop yields the new tensor value. E.g.:
657 /// ```
658 /// %result = scf.for i = 0 to 5 {
659 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
660 ///   %1 = vector.transfer_write %0, %source[...]
661 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
662 ///   scf.yield %1 : tensor<5x4x3xf32>
663 /// }
664 /// ```
665 template <typename OpTy>
666 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
667   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
668 
669   void initialize() {
670     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
672     this->setHasBoundedRewriteRecursion();
673   }
674 
675   LogicalResult matchAndRewrite(OpTy xferOp,
676                                 PatternRewriter &rewriter) const override {
677     if (!xferOp->hasAttr(kPassLabel))
678       return failure();
679 
680     // Find and cast data buffer. How the buffer can be found depends on OpTy.
681     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
682     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
683     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
684     auto castedDataType = unpackOneDim(dataBufferType);
685     auto castedDataBuffer =
686         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
687 
688     // If the xferOp has a mask: Find and cast mask buffer.
689     Value castedMaskBuffer;
690     if (xferOp.mask()) {
691       auto maskBuffer = getMaskBuffer(xferOp);
692       auto maskBufferType =
693           maskBuffer.getType().template dyn_cast<MemRefType>();
694       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
695         // Do not unpack a dimension of the mask, if:
696         // * To-be-unpacked transfer op dimension is a broadcast.
697         // * Mask is 1D, i.e., the mask cannot be further unpacked.
698         //   (That means that all remaining dimensions of the transfer op must
699         //   be broadcasted.)
700         castedMaskBuffer = maskBuffer;
701       } else {
702         auto castedMaskType = unpackOneDim(maskBufferType);
703         castedMaskBuffer =
704             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
705       }
706     }
707 
708     // Loop bounds and step.
709     auto lb = locB.create<arith::ConstantIndexOp>(0);
710     auto ub = locB.create<arith::ConstantIndexOp>(
711         castedDataType.getDimSize(castedDataType.getRank() - 1));
712     auto step = locB.create<arith::ConstantIndexOp>(1);
713     // TransferWriteOps that operate on tensors return the modified tensor and
714     // require a loop state.
715     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
716 
717     // Generate for loop.
718     auto result = locB.create<scf::ForOp>(
719         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
720         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
721           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
722 
723           auto result = generateInBoundsCheck(
724               b, xferOp, iv, unpackedDim(xferOp),
725               stateType ? TypeRange(stateType) : TypeRange(),
726               /*inBoundsCase=*/
727               [&](OpBuilder &b, Location loc) {
728                 // Create new transfer op.
729                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
730                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
731 
                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
737                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
738                                       xferOp.getMaskType().getRank() > 1)) {
739                   OpBuilder::InsertionGuard guard(b);
740                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
741 
742                   SmallVector<Value, 8> loadIndices;
743                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
744                   // In case of broadcast: Use same indices to load from memref
745                   // as before.
746                   if (!xferOp.isBroadcastDim(0))
747                     loadIndices.push_back(iv);
748 
749                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
750                                                        loadIndices);
751                   rewriter.updateRootInPlace(
752                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
753                 }
754 
755                 return loopState.empty() ? Value() : newXfer->getResult(0);
756               },
757               /*outOfBoundsCase=*/
758               [&](OpBuilder &b, Location /*loc*/) {
759                 return Strategy<OpTy>::handleOutOfBoundsDim(
760                     b, xferOp, castedDataBuffer, iv, loopState);
761               });
762 
763           maybeYieldValue(b, loc, !loopState.empty(), result);
764         });
765 
766     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
767     return success();
768   }
769 };
770 
771 } // namespace lowering_n_d
772 
773 namespace lowering_n_d_unrolled {
774 
775 /// If the original transfer op has a mask, compute the mask of the new transfer
776 /// op (for the current iteration `i`) and assign it.
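///
/// For a mask of rank > 1, one dimension of the mask is unpacked. E.g.
/// (sketch, with illustrative names), for iteration i = 3:
/// ```
/// %newMask = vector.extract %mask[3] : vector<5x4xi1>
/// ```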
777 template <typename OpTy>
778 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
779                             int64_t i) {
780   if (!xferOp.mask())
781     return;
782 
783   if (xferOp.isBroadcastDim(0)) {
784     // To-be-unpacked dimension is a broadcast, which does not have a
785     // corresponding mask dimension. Mask attribute remains unchanged.
786     newXferOp.maskMutable().assign(xferOp.mask());
787     return;
788   }
789 
790   if (xferOp.getMaskType().getRank() > 1) {
791     // Unpack one dimension of the mask.
792     OpBuilder::InsertionGuard guard(b);
793     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
794 
795     llvm::SmallVector<int64_t, 1> indices({i});
796     Location loc = xferOp.getLoc();
797     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
798     newXferOp.maskMutable().assign(newMask);
799   }
800 
801   // If we end up here: The mask of the old transfer op is 1D and the unpacked
802   // dim is not a broadcast, so no mask is needed on the new transfer op.
803   // `generateInBoundsCheck` will have evaluated the mask already.
804 }
805 
806 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
807 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
808 /// memref buffer is allocated and the SCF loop is fully unrolled.
809 ///
/// E.g.:
/// ```
813 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
814 ///     : memref<?x?x?xf32>, vector<5x4xf32>
815 /// ```
816 /// is rewritten to IR such as (simplified):
817 /// ```
818 /// %v_init = splat %padding : vector<5x4xf32>
819 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
820 ///     : memref<?x?x?xf32>, vector<4xf32>
821 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
822 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
823 ///     : memref<?x?x?xf32>, vector<4xf32>
824 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
825 /// ...
826 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
827 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
829 /// ```
830 ///
831 /// Note: As an optimization, if the result of the original TransferReadOp
832 /// was directly inserted into another vector, no new %v_init vector is created.
833 /// Instead, the new TransferReadOp results are inserted into that vector.
834 struct UnrollTransferReadConversion
835     : public VectorToSCFPattern<TransferReadOp> {
836   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
837 
838   void initialize() {
839     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
841     setHasBoundedRewriteRecursion();
842   }
843 
844   /// Return the vector into which the newly created TransferReadOp results
845   /// are inserted.
846   Value getResultVector(TransferReadOp xferOp,
847                         PatternRewriter &rewriter) const {
848     if (auto insertOp = getInsertOp(xferOp))
849       return insertOp.dest();
850     Location loc = xferOp.getLoc();
851     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
852                                     xferOp.padding());
853   }
854 
855   /// If the result of the TransferReadOp has exactly one user, which is a
856   /// vector::InsertOp, return that operation.
857   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
858     if (xferOp->hasOneUse()) {
859       Operation *xferOpUser = *xferOp->getUsers().begin();
860       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
861         return insertOp;
862     }
863 
864     return vector::InsertOp();
865   }
866 
867   /// If the result of the TransferReadOp has exactly one user, which is a
868   /// vector::InsertOp, return that operation's indices.
869   void getInsertionIndices(TransferReadOp xferOp,
870                            SmallVector<int64_t, 8> &indices) const {
871     if (auto insertOp = getInsertOp(xferOp)) {
872       llvm::for_each(insertOp.position(), [&](Attribute attr) {
873         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
874       });
875     }
876   }
877 
878   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
879   /// accesses, and broadcasts and transposes in permutation maps.
880   LogicalResult matchAndRewrite(TransferReadOp xferOp,
881                                 PatternRewriter &rewriter) const override {
882     if (xferOp.getVectorType().getRank() <= options.targetRank)
883       return failure();
884     if (isTensorOp(xferOp) && !options.lowerTensors)
885       return failure();
886     // Transfer ops that modify the element type are not supported atm.
887     if (xferOp.getVectorType().getElementType() !=
888         xferOp.getShapedType().getElementType())
889       return failure();
890 
891     auto insertOp = getInsertOp(xferOp);
892     auto vec = getResultVector(xferOp, rewriter);
893     auto vecType = vec.getType().dyn_cast<VectorType>();
894     auto xferVecType = xferOp.getVectorType();
895     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
896                                           xferVecType.getElementType());
897     int64_t dimSize = xferVecType.getShape()[0];
898 
899     // Generate fully unrolled loop of transfer ops.
900     Location loc = xferOp.getLoc();
901     for (int64_t i = 0; i < dimSize; ++i) {
902       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
903 
904       vec = generateInBoundsCheck(
905           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
906           /*inBoundsCase=*/
907           [&](OpBuilder &b, Location loc) {
908             // Indices for the new transfer op.
909             SmallVector<Value, 8> xferIndices;
910             getXferIndices(b, xferOp, iv, xferIndices);
911 
912             // Indices for the new vector.insert op.
913             SmallVector<int64_t, 8> insertionIndices;
914             getInsertionIndices(xferOp, insertionIndices);
915             insertionIndices.push_back(i);
916 
917             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
918             auto newXferOp = b.create<vector::TransferReadOp>(
919                 loc, newXferVecType, xferOp.source(), xferIndices,
920                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
921                 xferOp.padding(), Value(), inBoundsAttr);
922             maybeAssignMask(b, xferOp, newXferOp, i);
923             return b.create<vector::InsertOp>(loc, newXferOp, vec,
924                                               insertionIndices);
925           },
926           /*outOfBoundsCase=*/
927           [&](OpBuilder &b, Location loc) {
            // Pass through the original (unmodified) vector.
929             return vec;
930           });
931     }
932 
933     if (insertOp) {
934       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
935       rewriter.replaceOp(insertOp, vec);
936       rewriter.eraseOp(xferOp);
937     } else {
938       rewriter.replaceOp(xferOp, vec);
939     }
940 
941     return success();
942   }
943 };
944 
945 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
946 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
947 /// memref buffer is allocated and the SCF loop is fully unrolled.
948 ///
/// E.g.:
/// ```
952 /// vector.transfer_write %vec, %A[%a, %b, %c]
953 ///     : vector<5x4xf32>, memref<?x?x?xf32>
954 /// ```
955 /// is rewritten to IR such as (simplified):
956 /// ```
957 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
958 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
959 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
960 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
961 /// ...
962 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
963 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
964 /// ```
965 ///
966 /// Note: As an optimization, if the vector of the original TransferWriteOp
967 /// was directly extracted from another vector via an ExtractOp `a`, extract
968 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
969 /// doing so, `a` may become dead, and the number of ExtractOps generated during
970 /// recursive application of this pattern will be minimal.
971 struct UnrollTransferWriteConversion
972     : public VectorToSCFPattern<TransferWriteOp> {
973   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
974 
975   void initialize() {
976     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
978     setHasBoundedRewriteRecursion();
979   }
980 
/// Return the vector from which newly generated ExtractOps will extract.
982   Value getDataVector(TransferWriteOp xferOp) const {
983     if (auto extractOp = getExtractOp(xferOp))
984       return extractOp.vector();
985     return xferOp.vector();
986   }
987 
988   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
989   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
990     if (auto *op = xferOp.vector().getDefiningOp())
991       return dyn_cast<vector::ExtractOp>(op);
992     return vector::ExtractOp();
993   }
994 
995   /// If the input of the given TransferWriteOp is an ExtractOp, return its
996   /// indices.
997   void getExtractionIndices(TransferWriteOp xferOp,
998                             SmallVector<int64_t, 8> &indices) const {
999     if (auto extractOp = getExtractOp(xferOp)) {
1000       llvm::for_each(extractOp.position(), [&](Attribute attr) {
1001         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
1002       });
1003     }
1004   }
1005 
1006   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1007   /// accesses, and broadcasts and transposes in permutation maps.
1008   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1009                                 PatternRewriter &rewriter) const override {
1010     if (xferOp.getVectorType().getRank() <= options.targetRank)
1011       return failure();
1012     if (isTensorOp(xferOp) && !options.lowerTensors)
1013       return failure();
1014     // Transfer ops that modify the element type are not supported atm.
1015     if (xferOp.getVectorType().getElementType() !=
1016         xferOp.getShapedType().getElementType())
1017       return failure();
1018 
1019     auto vec = getDataVector(xferOp);
1020     auto xferVecType = xferOp.getVectorType();
1021     int64_t dimSize = xferVecType.getShape()[0];
1022     auto source = xferOp.source(); // memref or tensor to be written to.
1023     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1024 
1025     // Generate fully unrolled loop of transfer ops.
1026     Location loc = xferOp.getLoc();
1027     for (int64_t i = 0; i < dimSize; ++i) {
1028       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1029 
1030       auto updatedSource = generateInBoundsCheck(
1031           rewriter, xferOp, iv, unpackedDim(xferOp),
1032           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1033           /*inBoundsCase=*/
1034           [&](OpBuilder &b, Location loc) {
1035             // Indices for the new transfer op.
1036             SmallVector<Value, 8> xferIndices;
1037             getXferIndices(b, xferOp, iv, xferIndices);
1038 
1039             // Indices for the new vector.extract op.
1040             SmallVector<int64_t, 8> extractionIndices;
1041             getExtractionIndices(xferOp, extractionIndices);
1042             extractionIndices.push_back(i);
1043 
1044             auto extracted =
1045                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1046             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
1047             auto newXferOp = b.create<vector::TransferWriteOp>(
1048                 loc, sourceType, extracted, source, xferIndices,
1049                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1050                 inBoundsAttr);
1051 
1052             maybeAssignMask(b, xferOp, newXferOp, i);
1053 
1054             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1055           },
1056           /*outOfBoundsCase=*/
1057           [&](OpBuilder &b, Location loc) {
1058             return isTensorOp(xferOp) ? source : Value();
1059           });
1060 
1061       if (isTensorOp(xferOp))
1062         source = updatedSource;
1063     }
1064 
1065     if (isTensorOp(xferOp))
1066       rewriter.replaceOp(xferOp, source);
1067     else
1068       rewriter.eraseOp(xferOp);
1069 
1070     return success();
1071   }
1072 };
1073 
1074 } // namespace lowering_n_d_unrolled
1075 
1076 namespace lowering_1_d {
1077 
1078 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1079 /// part of TransferOp1dConversion. Return the memref dimension on which
1080 /// the transfer is operating. A return value of None indicates a broadcast.
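///
/// E.g. (sketch): for a transfer op with indices `[%a, %b]` and permutation
/// map `(d0, d1) -> (d0)`, the computed memref indices are `[%a + iv, %b]`
/// and the returned dimension is 0.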
1081 template <typename OpTy>
1082 static Optional<int64_t>
1083 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1084                    SmallVector<Value, 8> &memrefIndices) {
1085   auto indices = xferOp.indices();
1086   auto map = xferOp.permutation_map();
1087 
1088   memrefIndices.append(indices.begin(), indices.end());
1089   assert(map.getNumResults() == 1 &&
1090          "Expected 1 permutation map result for 1D transfer");
1091   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1092     Location loc = xferOp.getLoc();
1093     auto dim = expr.getPosition();
1094     AffineExpr d0, d1;
1095     bindDims(xferOp.getContext(), d0, d1);
1096     Value offset = memrefIndices[dim];
1097     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1098     return dim;
1099   }
1100 
1101   assert(xferOp.isBroadcastDim(0) &&
1102          "Expected AffineDimExpr or AffineConstantExpr");
1103   return None;
1104 }
1105 
1106 /// Codegen strategy for TransferOp1dConversion, depending on the
1107 /// operation.
1108 template <typename OpTy>
1109 struct Strategy1d;
1110 
1111 /// Codegen strategy for TransferReadOp.
1112 template <>
1113 struct Strategy1d<TransferReadOp> {
1114   static void generateForLoopBody(OpBuilder &b, Location loc,
1115                                   TransferReadOp xferOp, Value iv,
1116                                   ValueRange loopState) {
1117     SmallVector<Value, 8> indices;
1118     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1119     Value ivI32 = b.create<arith::IndexCastOp>(
1120         loc, IntegerType::get(b.getContext(), 32), iv);
1121     auto vec = loopState[0];
1122 
1123     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1124     // padding value).
1125     auto nextVec = generateInBoundsCheck(
1126         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1127         /*inBoundsCase=*/
1128         [&](OpBuilder &b, Location loc) {
1129           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1130           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
1131         },
1132         /*outOfBoundsCase=*/
1133         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1134     b.create<scf::YieldOp>(loc, nextVec);
1135   }
1136 
1137   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
1139     Location loc = xferOp.getLoc();
1140     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
1141   }
1142 };
1143 
1144 /// Codegen strategy for TransferWriteOp.
1145 template <>
1146 struct Strategy1d<TransferWriteOp> {
1147   static void generateForLoopBody(OpBuilder &b, Location loc,
1148                                   TransferWriteOp xferOp, Value iv,
1149                                   ValueRange /*loopState*/) {
1150     SmallVector<Value, 8> indices;
1151     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1152     Value ivI32 = b.create<arith::IndexCastOp>(
1153         loc, IntegerType::get(b.getContext(), 32), iv);
1154 
1155     // Nothing to do in case of out-of-bounds access.
1156     generateInBoundsCheck(
1157         b, xferOp, iv, dim,
1158         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1159           auto val =
1160               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
1161           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
1162         });
1163     b.create<scf::YieldOp>(loc);
1164   }
1165 
1166   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1167     return Value();
1168   }
1169 };
1170 
1171 /// Return true if the last dimension of the MemRefType has unit stride.
1172 static bool isLastMemrefDimUnitStride(MemRefType type) {
1173   int64_t offset;
1174   SmallVector<int64_t, 4> strides;
1175   auto successStrides = getStridesAndOffset(type, strides, offset);
1176   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1177 }
1178 
1179 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1180 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1181 /// vector load/stores due to non-unit strides or broadcasts:
1182 ///
1183 /// * Transfer dimension is not the last memref dimension
1184 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1185 /// * Memref has a layout map with non-unit stride on the last dimension
1186 ///
1187 /// This pattern generates IR as follows:
1188 ///
1189 /// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1191 ///    depending on OpTy.
1192 ///
1193 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1194 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1195 ///       to ConvertVectorToLLVM.
1196 ///
1197 /// E.g.:
1198 /// ```
1199 /// vector.transfer_write %vec, %A[%a, %b]
1200 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1201 ///    : vector<9xf32>, memref<?x?xf32>
1202 /// ```
1203 /// Is rewritten to approximately the following pseudo-IR:
1204 /// ```
1205 /// for i = 0 to 9 {
1206 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1207 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1208 /// }
1209 /// ```
1210 template <typename OpTy>
1211 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1212   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1213 
1214   LogicalResult matchAndRewrite(OpTy xferOp,
1215                                 PatternRewriter &rewriter) const override {
1216     auto map = xferOp.permutation_map();
1217     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1218 
1219     if (!memRefType)
1220       return failure();
1221     if (xferOp.getVectorType().getRank() != 1)
1222       return failure();
1223     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1224       return failure(); // Handled by ConvertVectorToLLVM
1225 
1226     // Loop bounds, step, state...
1227     Location loc = xferOp.getLoc();
1228     auto vecType = xferOp.getVectorType();
1229     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1230     auto ub =
1231         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1232     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1233     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1234 
1235     // Generate for loop.
1236     rewriter.replaceOpWithNewOp<scf::ForOp>(
1237         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1238         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1239           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1240         });
1241 
1242     return success();
1243   }
1244 };
1245 
1246 } // namespace lowering_1_d
1247 } // namespace
1248 
1249 namespace mlir {
1250 
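/// Populate `patterns` with the progressive lowering patterns selected by
/// `options`: either the fully unrolled variants or the buffer-based
/// variants, plus the scalar 1-D lowering when the target rank is 1.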
1251 void populateVectorToSCFConversionPatterns(
1252     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1253   if (options.unroll) {
1254     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1255                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1256         patterns.getContext(), options);
1257   } else {
1258     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1259                  lowering_n_d::PrepareTransferWriteConversion,
1260                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1261                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1262         patterns.getContext(), options);
1263   }
1264 
1265   if (options.targetRank == 1) {
1266     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1267                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1268         patterns.getContext(), options);
1269   }
1270 }
1271 
1272 } // namespace mlir
1273 
1274 namespace {
1275 
1276 struct ConvertVectorToSCFPass
1277     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1278   ConvertVectorToSCFPass() = default;
1279   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1280     this->fullUnroll = options.unroll;
1281     this->targetRank = options.targetRank;
1282     this->lowerPermutationMaps = options.lowerPermutationMaps;
1283     this->lowerTensors = options.lowerTensors;
1284   }
1285 
1286   void runOnFunction() override {
1287     VectorTransferToSCFOptions options;
1288     options.unroll = fullUnroll;
1289     options.targetRank = targetRank;
1290     options.lowerPermutationMaps = lowerPermutationMaps;
1291     options.lowerTensors = lowerTensors;
1292 
1293     // Lower permutation maps first.
1294     if (lowerPermutationMaps) {
1295       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1296       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1297           lowerTransferPatterns);
1298       (void)applyPatternsAndFoldGreedily(getFunction(),
1299                                          std::move(lowerTransferPatterns));
1300     }
1301 
1302     RewritePatternSet patterns(getFunction().getContext());
1303     populateVectorToSCFConversionPatterns(patterns, options);
1304     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
1305   }
1306 };
1307 
1308 } // namespace
1309 
1310 std::unique_ptr<Pass>
1311 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1312   return std::make_unique<ConvertVectorToSCFPass>(options);
1313 }
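
// A minimal usage sketch (assumptions: a FuncOp-nested pass pipeline and an
// already-loaded module; the exact pass manager setup depends on the
// surrounding project):
//
//   VectorTransferToSCFOptions options;
//   options.unroll = true;  // use the unrolled lowering patterns
//   PassManager pm(module.getContext());
//   pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass(options));
//   if (failed(pm.run(module)))
//     /* handle the error */;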
1314