1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/Vector/VectorTransforms.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/Passes.h"
29 
30 using namespace mlir;
31 using vector::TransferReadOp;
32 using vector::TransferWriteOp;
33 
34 namespace {
35 
36 /// Attribute name used for labeling transfer ops during progressive lowering.
37 static const char kPassLabel[] = "__vector_to_scf_lowering__";
38 
39 /// Patterns that inherit from this struct have access to
40 /// VectorTransferToSCFOptions.
41 template <typename OpTy>
42 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
43   explicit VectorToSCFPattern(MLIRContext *context,
44                               VectorTransferToSCFOptions opt)
45       : OpRewritePattern<OpTy>(context), options(opt) {}
46 
47   VectorTransferToSCFOptions options;
48 };
49 
50 /// Given a vector transfer op, calculate which dimension of the `source`
51 /// memref should be unpacked in the next application of TransferOpConversion.
52 /// A return value of None indicates a broadcast.
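/// E.g. (illustrative): for a transfer op with permutation map
/// (d0, d1, d2) -> (d2, d1), the first map result is d2, so dimension 2 of the
/// source is unpacked next; a leading constant-zero result (broadcast) yields
/// None.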
53 template <typename OpTy>
54 static Optional<int64_t> unpackedDim(OpTy xferOp) {
55   // TODO: support 0-d corner case.
56   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
57   auto map = xferOp.permutation_map();
58   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
59     return expr.getPosition();
60   }
61   assert(xferOp.isBroadcastDim(0) &&
62          "Expected AffineDimExpr or AffineConstantExpr");
63   return None;
64 }
65 
66 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
67 /// map is identical to the current permutation map, but the first result is
68 /// omitted.
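/// E.g. (illustrative): (d0, d1, d2) -> (d2, d1) becomes (d0, d1, d2) -> (d1).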
69 template <typename OpTy>
70 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
71   // TODO: support 0-d corner case.
72   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
73   auto map = xferOp.permutation_map();
74   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
75                         b.getContext());
76 }
77 
78 /// Calculate the indices for the new vector transfer op.
79 ///
80 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
82 ///                                 ^^^^^^
83 ///              `iv` is the iteration variable of the (new) surrounding loop.
84 template <typename OpTy>
85 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
86                            SmallVector<Value, 8> &indices) {
87   typename OpTy::Adaptor adaptor(xferOp);
88   // Corresponding memref dim of the vector dim that is unpacked.
89   auto dim = unpackedDim(xferOp);
90   auto prevIndices = adaptor.indices();
91   indices.append(prevIndices.begin(), prevIndices.end());
92 
93   Location loc = xferOp.getLoc();
94   bool isBroadcast = !dim.hasValue();
95   if (!isBroadcast) {
96     AffineExpr d0, d1;
97     bindDims(xferOp.getContext(), d0, d1);
98     Value offset = adaptor.indices()[dim.getValue()];
99     indices[dim.getValue()] =
100         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
101   }
102 }
103 
104 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
105                             Value value) {
106   if (hasRetVal) {
107     assert(value && "Expected non-empty value");
108     b.create<scf::YieldOp>(loc, value);
109   } else {
110     b.create<scf::YieldOp>(loc);
111   }
112 }
113 
/// Generates a boolean Value that is true if the iv-th element of xferOp's
/// mask is set. No such check is generated under the following circumstances:
116 /// * xferOp does not have a mask.
117 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
118 ///   computed and attached to the new transfer op in the pattern.)
119 /// * The to-be-unpacked dim of xferOp is a broadcast.
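/// E.g. (illustrative), for a 1D mask of type vector<4xi1>, IR similar to the
/// following is generated:
/// ```
/// %m = vector.extractelement %mask[%iv : index] : vector<4xi1>
/// ```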
120 template <typename OpTy>
121 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
122   if (!xferOp.mask())
123     return Value();
124   if (xferOp.getMaskType().getRank() != 1)
125     return Value();
126   if (xferOp.isBroadcastDim(0))
127     return Value();
128 
129   Location loc = xferOp.getLoc();
130   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
131 }
132 
/// Helper function for TransferOpConversion and TransferOp1dConversion.
134 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
135 /// specified dimension `dim` with the loop iteration variable `iv`.
136 /// E.g., when unpacking dimension 0 from:
137 /// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
140 /// ```
141 /// An if check similar to this will be generated inside the loop:
142 /// ```
143 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
144 /// if (%a + iv < %d) {
145 ///   (in-bounds case)
146 /// } else {
147 ///   (out-of-bounds case)
148 /// }
149 /// ```
150 ///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
153 ///
154 /// This function variant returns the value returned by `inBoundsCase` or
155 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
156 /// `resultTypes`.
157 template <typename OpTy>
158 static Value generateInBoundsCheck(
159     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
160     TypeRange resultTypes,
161     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
162     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
163   bool hasRetVal = !resultTypes.empty();
164   Value cond; // Condition to be built...
165 
166   // Condition check 1: Access in-bounds?
167   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
168   Location loc = xferOp.getLoc();
169   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
170   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
171     Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
172     AffineExpr d0, d1;
173     bindDims(xferOp.getContext(), d0, d1);
174     Value base = xferOp.indices()[dim.getValue()];
175     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
176     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
177                                     memrefIdx);
178   }
179 
180   // Condition check 2: Masked in?
181   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
182     if (cond)
183       cond = lb.create<arith::AndIOp>(cond, maskCond);
184     else
185       cond = maskCond;
186   }
187 
188   // If the condition is non-empty, generate an SCF::IfOp.
189   if (cond) {
190     auto check = lb.create<scf::IfOp>(
191         resultTypes, cond,
192         /*thenBuilder=*/
193         [&](OpBuilder &b, Location loc) {
194           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
195         },
196         /*elseBuilder=*/
197         [&](OpBuilder &b, Location loc) {
198           if (outOfBoundsCase) {
199             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
200           } else {
201             b.create<scf::YieldOp>(loc);
202           }
203         });
204 
205     return hasRetVal ? check.getResult(0) : Value();
206   }
207 
208   // Condition is empty, no need for an SCF::IfOp.
209   return inBoundsCase(b, loc);
210 }
211 
212 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
213 /// a return value. Consequently, this function does not have a return value.
214 template <typename OpTy>
215 static void generateInBoundsCheck(
216     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
217     function_ref<void(OpBuilder &, Location)> inBoundsCase,
218     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
219   generateInBoundsCheck(
220       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
221       /*inBoundsCase=*/
222       [&](OpBuilder &b, Location loc) {
223         inBoundsCase(b, loc);
224         return Value();
225       },
226       /*outOfBoundsCase=*/
227       [&](OpBuilder &b, Location loc) {
228         if (outOfBoundsCase)
229           outOfBoundsCase(b, loc);
230         return Value();
231       });
232 }
233 
234 /// Given an ArrayAttr, return a copy where the first element is dropped.
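/// E.g. (illustrative): an in_bounds attribute [true, false, true] becomes
/// [false, true].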
235 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
236   if (!attr)
237     return attr;
238   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
239 }
240 
/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
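/// E.g. (illustrative): with targetRank == 1, a newly created vector<4x3xf32>
/// transfer is labeled for further unpacking, whereas a vector<3xf32> transfer
/// is not.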
243 template <typename OpTy>
244 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
245                                 unsigned targetRank) {
246   if (newXferOp.getVectorType().getRank() > targetRank)
247     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
248 }
249 
250 /// Return true if this transfer op operates on a source tensor.
251 template <typename OpTy>
252 static bool isTensorOp(OpTy xferOp) {
253   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
254     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
255       // TransferWriteOps on tensors have a result.
256       assert(xferOp->getNumResults() > 0);
257     }
258     return true;
259   }
260   return false;
261 }
262 
263 namespace lowering_n_d {
264 
265 /// Helper data structure for data and mask buffers.
266 struct BufferAllocs {
267   Value dataBuffer;
268   Value maskBuffer;
269 };
270 
271 /// Allocate temporary buffers for data (vector) and mask (if present).
272 /// TODO: Parallelism and threadlocal considerations.
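/// E.g. (illustrative), for a transfer op with vector type vector<5x4xf32> and
/// mask type vector<5xi1>, allocations similar to the following are created at
/// the top of the enclosing allocation scope:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// ```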
273 template <typename OpTy>
274 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
275   Location loc = xferOp.getLoc();
276   OpBuilder::InsertionGuard guard(b);
277   Operation *scope =
278       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
279   assert(scope && "Expected op to be inside automatic allocation scope");
280   b.setInsertionPointToStart(&scope->getRegion(0).front());
281 
282   BufferAllocs result;
283   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
284   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
285 
286   if (xferOp.mask()) {
287     auto maskType = MemRefType::get({}, xferOp.mask().getType());
288     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
289     b.setInsertionPoint(xferOp);
290     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
291     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
292   }
293 
294   return result;
295 }
296 
297 /// Given a MemRefType with VectorType element type, unpack one dimension from
298 /// the VectorType into the MemRefType.
299 ///
300 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
301 static MemRefType unpackOneDim(MemRefType type) {
302   auto vectorType = type.getElementType().dyn_cast<VectorType>();
303   auto memrefShape = type.getShape();
304   SmallVector<int64_t, 8> newMemrefShape;
305   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
306   newMemrefShape.push_back(vectorType.getDimSize(0));
307   return MemRefType::get(newMemrefShape,
308                          VectorType::get(vectorType.getShape().drop_front(),
309                                          vectorType.getElementType()));
310 }
311 
312 /// Given a transfer op, find the memref from which the mask is loaded. This
313 /// is similar to Strategy<TransferWriteOp>::getBuffer.
314 template <typename OpTy>
315 static Value getMaskBuffer(OpTy xferOp) {
316   assert(xferOp.mask() && "Expected that transfer op has mask");
317   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
318   assert(loadOp && "Expected transfer op mask produced by LoadOp");
319   return loadOp.getMemRef();
320 }
321 
322 /// Codegen strategy, depending on the operation.
323 template <typename OpTy>
324 struct Strategy;
325 
/// Codegen strategy for vector TransferReadOp.
327 template <>
328 struct Strategy<TransferReadOp> {
329   /// Find the StoreOp that is used for writing the current TransferReadOp's
330   /// result to the temporary buffer allocation.
331   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
332     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
333     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
334     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
335     return storeOp;
336   }
337 
338   /// Find the temporary buffer allocation. All labeled TransferReadOps are
339   /// used like this, where %buf is either the buffer allocation or a type cast
340   /// of the buffer allocation:
341   /// ```
342   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
343   /// memref.store %vec, %buf[...] ...
344   /// ```
345   static Value getBuffer(TransferReadOp xferOp) {
346     return getStoreOp(xferOp).getMemRef();
347   }
348 
349   /// Retrieve the indices of the current StoreOp that stores into the buffer.
350   static void getBufferIndices(TransferReadOp xferOp,
351                                SmallVector<Value, 8> &indices) {
352     auto storeOp = getStoreOp(xferOp);
353     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
354     indices.append(prevIndices.begin(), prevIndices.end());
355   }
356 
357   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
358   /// accesses on the to-be-unpacked dimension.
359   ///
360   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
361   ///    variable `iv`.
362   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
363   ///
364   /// E.g.:
365   /// ```
366   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
367   ///     : memref<?x?x?xf32>, vector<4x3xf32>
368   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
369   /// ```
370   /// Is rewritten to:
371   /// ```
372   /// %casted = vector.type_cast %buf
373   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
374   /// for %j = 0 to 4 {
375   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
376   ///       : memref<?x?x?xf32>, vector<3xf32>
377   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
378   /// }
379   /// ```
380   ///
381   /// Note: The loop and type cast are generated in TransferOpConversion.
382   ///       The original TransferReadOp and store op are deleted in `cleanup`.
383   /// Note: The `mask` operand is set in TransferOpConversion.
384   static TransferReadOp rewriteOp(OpBuilder &b,
385                                   VectorTransferToSCFOptions options,
386                                   TransferReadOp xferOp, Value buffer, Value iv,
387                                   ValueRange /*loopState*/) {
388     SmallVector<Value, 8> storeIndices;
389     getBufferIndices(xferOp, storeIndices);
390     storeIndices.push_back(iv);
391 
392     SmallVector<Value, 8> xferIndices;
393     getXferIndices(b, xferOp, iv, xferIndices);
394 
395     Location loc = xferOp.getLoc();
396     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
397     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
398     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
399     auto newXferOp = b.create<vector::TransferReadOp>(
400         loc, vecType, xferOp.source(), xferIndices,
401         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
402         Value(), inBoundsAttr);
403 
404     maybeApplyPassLabel(b, newXferOp, options.targetRank);
405 
406     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
407     return newXferOp;
408   }
409 
410   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
411   /// padding value to the temporary buffer.
412   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
413                                     Value buffer, Value iv,
414                                     ValueRange /*loopState*/) {
415     SmallVector<Value, 8> storeIndices;
416     getBufferIndices(xferOp, storeIndices);
417     storeIndices.push_back(iv);
418 
419     Location loc = xferOp.getLoc();
420     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
421     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
422     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
423     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
424 
425     return Value();
426   }
427 
428   /// Cleanup after rewriting the op.
429   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
430                       scf::ForOp /*forOp*/) {
431     rewriter.eraseOp(getStoreOp(xferOp));
432     rewriter.eraseOp(xferOp);
433   }
434 
435   /// Return the initial loop state for the generated scf.for loop.
436   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
437 };
438 
439 /// Codegen strategy for vector TransferWriteOp.
440 template <>
441 struct Strategy<TransferWriteOp> {
442   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
443   /// used like this, where %buf is either the buffer allocation or a type cast
444   /// of the buffer allocation:
445   /// ```
446   /// %vec = memref.load %buf[...] ...
447   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
448   /// ```
449   static Value getBuffer(TransferWriteOp xferOp) {
450     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
451     assert(loadOp && "Expected transfer op vector produced by LoadOp");
452     return loadOp.getMemRef();
453   }
454 
455   /// Retrieve the indices of the current LoadOp that loads from the buffer.
456   static void getBufferIndices(TransferWriteOp xferOp,
457                                SmallVector<Value, 8> &indices) {
458     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
459     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
460     indices.append(prevIndices.begin(), prevIndices.end());
461   }
462 
463   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
464   /// accesses on the to-be-unpacked dimension.
465   ///
466   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
467   ///    using the loop iteration variable `iv`.
468   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
469   ///    to memory.
470   ///
471   /// Note: For more details, see comments on Strategy<TransferReadOp>.
472   static TransferWriteOp rewriteOp(OpBuilder &b,
473                                    VectorTransferToSCFOptions options,
474                                    TransferWriteOp xferOp, Value buffer,
475                                    Value iv, ValueRange loopState) {
476     SmallVector<Value, 8> loadIndices;
477     getBufferIndices(xferOp, loadIndices);
478     loadIndices.push_back(iv);
479 
480     SmallVector<Value, 8> xferIndices;
481     getXferIndices(b, xferOp, iv, xferIndices);
482 
483     Location loc = xferOp.getLoc();
484     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
485     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
486     auto source = loopState.empty() ? xferOp.source() : loopState[0];
487     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
488     auto newXferOp = b.create<vector::TransferWriteOp>(
489         loc, type, vec, source, xferIndices,
490         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
491         inBoundsAttr);
492 
493     maybeApplyPassLabel(b, newXferOp, options.targetRank);
494 
495     return newXferOp;
496   }
497 
498   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
499   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
500                                     Value buffer, Value iv,
501                                     ValueRange loopState) {
502     return isTensorOp(xferOp) ? loopState[0] : Value();
503   }
504 
505   /// Cleanup after rewriting the op.
506   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
507                       scf::ForOp forOp) {
508     if (isTensorOp(xferOp)) {
509       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
510       rewriter.replaceOp(xferOp, forOp->getResult(0));
511     } else {
512       rewriter.eraseOp(xferOp);
513     }
514   }
515 
516   /// Return the initial loop state for the generated scf.for loop.
517   static Value initialLoopState(TransferWriteOp xferOp) {
518     return isTensorOp(xferOp) ? xferOp.source() : Value();
519   }
520 };
521 
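/// Return success if the transfer op is eligible for preparation: it has not
/// been labeled yet, its vector rank exceeds the target rank, tensor sources
/// are allowed (if applicable), and the element type is not modified.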
522 template <typename OpTy>
523 LogicalResult checkPrepareXferOp(OpTy xferOp,
524                                  VectorTransferToSCFOptions options) {
525   if (xferOp->hasAttr(kPassLabel))
526     return failure();
527   if (xferOp.getVectorType().getRank() <= options.targetRank)
528     return failure();
529   if (isTensorOp(xferOp) && !options.lowerTensors)
530     return failure();
531   // Transfer ops that modify the element type are not supported atm.
532   if (xferOp.getVectorType().getElementType() !=
533       xferOp.getShapedType().getElementType())
534     return failure();
535   return success();
536 }
537 
538 /// Prepare a TransferReadOp for progressive lowering.
539 ///
540 /// 1. Allocate a temporary buffer.
541 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
542 /// 3. Store the result of the TransferReadOp into the temporary buffer.
543 /// 4. Load the result from the temporary buffer and replace all uses of the
544 ///    original TransferReadOp with this load.
545 ///
546 /// E.g.:
547 /// ```
548 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
550 /// ```
551 /// is rewritten to:
552 /// ```
553 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
554 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
556 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
557 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
558 /// ```
559 ///
560 /// Note: A second temporary buffer may be allocated for the `mask` operand.
561 struct PrepareTransferReadConversion
562     : public VectorToSCFPattern<TransferReadOp> {
563   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
564 
565   LogicalResult matchAndRewrite(TransferReadOp xferOp,
566                                 PatternRewriter &rewriter) const override {
567     if (checkPrepareXferOp(xferOp, options).failed())
568       return failure();
569 
570     auto buffers = allocBuffers(rewriter, xferOp);
571     auto *newXfer = rewriter.clone(*xferOp.getOperation());
572     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
573     if (xferOp.mask()) {
574       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
575           buffers.maskBuffer);
576     }
577 
578     Location loc = xferOp.getLoc();
579     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
580                                      buffers.dataBuffer);
581     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
582 
583     return success();
584   }
585 };
586 
587 /// Prepare a TransferWriteOp for progressive lowering.
588 ///
589 /// 1. Allocate a temporary buffer.
590 /// 2. Store the vector into the buffer.
591 /// 3. Load the vector from the buffer again.
592 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
593 ///    marking it eligible for progressive lowering via TransferOpConversion.
594 ///
595 /// E.g.:
596 /// ```
597 /// vector.transfer_write %vec, %A[%a, %b, %c]
598 ///     : vector<5x4xf32>, memref<?x?x?xf32>
599 /// ```
600 /// is rewritten to:
601 /// ```
602 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
603 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
604 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
605 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
606 ///     : vector<5x4xf32>, memref<?x?x?xf32>
607 /// ```
608 ///
609 /// Note: A second temporary buffer may be allocated for the `mask` operand.
610 struct PrepareTransferWriteConversion
611     : public VectorToSCFPattern<TransferWriteOp> {
612   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
613 
614   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
615                                 PatternRewriter &rewriter) const override {
616     if (checkPrepareXferOp(xferOp, options).failed())
617       return failure();
618 
619     Location loc = xferOp.getLoc();
620     auto buffers = allocBuffers(rewriter, xferOp);
621     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
622     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
623     rewriter.updateRootInPlace(xferOp, [&]() {
624       xferOp.vectorMutable().assign(loadedVec);
625       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
626     });
627 
628     if (xferOp.mask()) {
629       rewriter.updateRootInPlace(
630           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
631     }
632 
633     return success();
634   }
635 };
636 
637 /// Progressive lowering of vector transfer ops: Unpack one dimension.
638 ///
639 /// 1. Unpack one dimension from the current buffer type and cast the buffer
640 ///    to that new type. E.g.:
641 ///    ```
642 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
643 ///    vector.transfer_write %vec ...
644 ///    ```
645 ///    The following cast is generated:
646 ///    ```
647 ///    %casted = vector.type_cast %0
648 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
649 ///    ```
650 /// 2. Generate a for loop and rewrite the transfer op according to the
651 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
652 ///    out-of-bounds, generate an if-check and handle both cases separately.
653 /// 3. Clean up according to the corresponding Strategy<OpTy>.
654 ///
655 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
656 /// source (as opposed to a memref source), then each iteration of the generated
657 /// scf.for loop yields the new tensor value. E.g.:
658 /// ```
659 /// %result = scf.for i = 0 to 5 {
660 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
661 ///   %1 = vector.transfer_write %0, %source[...]
662 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
663 ///   scf.yield %1 : tensor<5x4x3xf32>
664 /// }
665 /// ```
666 template <typename OpTy>
667 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
668   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
669 
670   void initialize() {
671     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
673     this->setHasBoundedRewriteRecursion();
674   }
675 
676   LogicalResult matchAndRewrite(OpTy xferOp,
677                                 PatternRewriter &rewriter) const override {
678     if (!xferOp->hasAttr(kPassLabel))
679       return failure();
680 
681     // Find and cast data buffer. How the buffer can be found depends on OpTy.
682     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
683     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
684     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
685     auto castedDataType = unpackOneDim(dataBufferType);
686     auto castedDataBuffer =
687         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
688 
689     // If the xferOp has a mask: Find and cast mask buffer.
690     Value castedMaskBuffer;
691     if (xferOp.mask()) {
692       auto maskBuffer = getMaskBuffer(xferOp);
693       auto maskBufferType =
694           maskBuffer.getType().template dyn_cast<MemRefType>();
695       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
696         // Do not unpack a dimension of the mask, if:
697         // * To-be-unpacked transfer op dimension is a broadcast.
698         // * Mask is 1D, i.e., the mask cannot be further unpacked.
699         //   (That means that all remaining dimensions of the transfer op must
700         //   be broadcasted.)
701         castedMaskBuffer = maskBuffer;
702       } else {
703         auto castedMaskType = unpackOneDim(maskBufferType);
704         castedMaskBuffer =
705             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
706       }
707     }
708 
709     // Loop bounds and step.
710     auto lb = locB.create<arith::ConstantIndexOp>(0);
711     auto ub = locB.create<arith::ConstantIndexOp>(
712         castedDataType.getDimSize(castedDataType.getRank() - 1));
713     auto step = locB.create<arith::ConstantIndexOp>(1);
714     // TransferWriteOps that operate on tensors return the modified tensor and
715     // require a loop state.
716     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
717 
718     // Generate for loop.
719     auto result = locB.create<scf::ForOp>(
720         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
721         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
722           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
723 
724           auto result = generateInBoundsCheck(
725               b, xferOp, iv, unpackedDim(xferOp),
726               stateType ? TypeRange(stateType) : TypeRange(),
727               /*inBoundsCase=*/
728               [&](OpBuilder &b, Location loc) {
729                 // Create new transfer op.
730                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
731                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
732 
                // If the old transfer op has a mask: Set the mask on the new
                // transfer op. Special case: If the mask of the old transfer
                // op is 1D and the unpacked dim is not a broadcast, no mask is
                // needed on the new transfer op.
738                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
739                                       xferOp.getMaskType().getRank() > 1)) {
740                   OpBuilder::InsertionGuard guard(b);
741                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
742 
743                   SmallVector<Value, 8> loadIndices;
744                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
745                   // In case of broadcast: Use same indices to load from memref
746                   // as before.
747                   if (!xferOp.isBroadcastDim(0))
748                     loadIndices.push_back(iv);
749 
750                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
751                                                        loadIndices);
752                   rewriter.updateRootInPlace(
753                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
754                 }
755 
756                 return loopState.empty() ? Value() : newXfer->getResult(0);
757               },
758               /*outOfBoundsCase=*/
759               [&](OpBuilder &b, Location /*loc*/) {
760                 return Strategy<OpTy>::handleOutOfBoundsDim(
761                     b, xferOp, castedDataBuffer, iv, loopState);
762               });
763 
764           maybeYieldValue(b, loc, !loopState.empty(), result);
765         });
766 
767     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
768     return success();
769   }
770 };
771 
772 } // namespace lowering_n_d
773 
774 namespace lowering_n_d_unrolled {
775 
776 /// If the original transfer op has a mask, compute the mask of the new transfer
777 /// op (for the current iteration `i`) and assign it.
778 template <typename OpTy>
779 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
780                             int64_t i) {
781   if (!xferOp.mask())
782     return;
783 
784   if (xferOp.isBroadcastDim(0)) {
785     // To-be-unpacked dimension is a broadcast, which does not have a
786     // corresponding mask dimension. Mask attribute remains unchanged.
787     newXferOp.maskMutable().assign(xferOp.mask());
788     return;
789   }
790 
791   if (xferOp.getMaskType().getRank() > 1) {
792     // Unpack one dimension of the mask.
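    // E.g. (illustrative): for a mask of type vector<5x4xi1> and i == 2, this
    // extracts the vector<4xi1> subvector at position [2].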
793     OpBuilder::InsertionGuard guard(b);
794     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
795 
796     llvm::SmallVector<int64_t, 1> indices({i});
797     Location loc = xferOp.getLoc();
798     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
799     newXferOp.maskMutable().assign(newMask);
800   }
801 
802   // If we end up here: The mask of the old transfer op is 1D and the unpacked
803   // dim is not a broadcast, so no mask is needed on the new transfer op.
804   // `generateInBoundsCheck` will have evaluated the mask already.
805 }
806 
807 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
808 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
809 /// memref buffer is allocated and the SCF loop is fully unrolled.
810 ///
812 /// E.g.:
813 /// ```
814 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
815 ///     : memref<?x?x?xf32>, vector<5x4xf32>
816 /// ```
817 /// is rewritten to IR such as (simplified):
818 /// ```
819 /// %v_init = splat %padding : vector<5x4xf32>
820 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
821 ///     : memref<?x?x?xf32>, vector<4xf32>
822 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
823 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
824 ///     : memref<?x?x?xf32>, vector<4xf32>
825 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
826 /// ...
827 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
828 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
830 /// ```
831 ///
832 /// Note: As an optimization, if the result of the original TransferReadOp
833 /// was directly inserted into another vector, no new %v_init vector is created.
834 /// Instead, the new TransferReadOp results are inserted into that vector.
835 struct UnrollTransferReadConversion
836     : public VectorToSCFPattern<TransferReadOp> {
837   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
838 
839   void initialize() {
840     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
842     setHasBoundedRewriteRecursion();
843   }
844 
845   /// Return the vector into which the newly created TransferReadOp results
846   /// are inserted.
847   Value getResultVector(TransferReadOp xferOp,
848                         PatternRewriter &rewriter) const {
849     if (auto insertOp = getInsertOp(xferOp))
850       return insertOp.dest();
851     Location loc = xferOp.getLoc();
852     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
853                                     xferOp.padding());
854   }
855 
856   /// If the result of the TransferReadOp has exactly one user, which is a
857   /// vector::InsertOp, return that operation.
858   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
859     if (xferOp->hasOneUse()) {
860       Operation *xferOpUser = *xferOp->getUsers().begin();
861       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
862         return insertOp;
863     }
864 
865     return vector::InsertOp();
866   }
867 
868   /// If the result of the TransferReadOp has exactly one user, which is a
869   /// vector::InsertOp, return that operation's indices.
870   void getInsertionIndices(TransferReadOp xferOp,
871                            SmallVector<int64_t, 8> &indices) const {
872     if (auto insertOp = getInsertOp(xferOp)) {
873       llvm::for_each(insertOp.position(), [&](Attribute attr) {
874         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
875       });
876     }
877   }
878 
879   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
880   /// accesses, and broadcasts and transposes in permutation maps.
881   LogicalResult matchAndRewrite(TransferReadOp xferOp,
882                                 PatternRewriter &rewriter) const override {
883     if (xferOp.getVectorType().getRank() <= options.targetRank)
884       return failure();
885     if (isTensorOp(xferOp) && !options.lowerTensors)
886       return failure();
887     // Transfer ops that modify the element type are not supported atm.
888     if (xferOp.getVectorType().getElementType() !=
889         xferOp.getShapedType().getElementType())
890       return failure();
891 
892     auto insertOp = getInsertOp(xferOp);
893     auto vec = getResultVector(xferOp, rewriter);
894     auto vecType = vec.getType().dyn_cast<VectorType>();
895     auto xferVecType = xferOp.getVectorType();
896     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
897                                           xferVecType.getElementType());
898     int64_t dimSize = xferVecType.getShape()[0];
899 
900     // Generate fully unrolled loop of transfer ops.
901     Location loc = xferOp.getLoc();
902     for (int64_t i = 0; i < dimSize; ++i) {
903       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
904 
905       vec = generateInBoundsCheck(
906           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
907           /*inBoundsCase=*/
908           [&](OpBuilder &b, Location loc) {
909             // Indices for the new transfer op.
910             SmallVector<Value, 8> xferIndices;
911             getXferIndices(b, xferOp, iv, xferIndices);
912 
913             // Indices for the new vector.insert op.
914             SmallVector<int64_t, 8> insertionIndices;
915             getInsertionIndices(xferOp, insertionIndices);
916             insertionIndices.push_back(i);
917 
918             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
919             auto newXferOp = b.create<vector::TransferReadOp>(
920                 loc, newXferVecType, xferOp.source(), xferIndices,
921                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
922                 xferOp.padding(), Value(), inBoundsAttr);
923             maybeAssignMask(b, xferOp, newXferOp, i);
924             return b.create<vector::InsertOp>(loc, newXferOp, vec,
925                                               insertionIndices);
926           },
927           /*outOfBoundsCase=*/
928           [&](OpBuilder &b, Location loc) {
            // Out-of-bounds case: return the vector unmodified.
930             return vec;
931           });
932     }
933 
934     if (insertOp) {
935       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
936       rewriter.replaceOp(insertOp, vec);
937       rewriter.eraseOp(xferOp);
938     } else {
939       rewriter.replaceOp(xferOp, vec);
940     }
941 
942     return success();
943   }
944 };
945 
946 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
947 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
948 /// memref buffer is allocated and the SCF loop is fully unrolled.
949 ///
951 /// E.g.:
952 /// ```
953 /// vector.transfer_write %vec, %A[%a, %b, %c]
954 ///     : vector<5x4xf32>, memref<?x?x?xf32>
955 /// ```
956 /// is rewritten to IR such as (simplified):
957 /// ```
958 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
959 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
960 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
961 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
962 /// ...
963 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
964 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
965 /// ```
966 ///
967 /// Note: As an optimization, if the vector of the original TransferWriteOp
968 /// was directly extracted from another vector via an ExtractOp `a`, extract
969 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
970 /// doing so, `a` may become dead, and the number of ExtractOps generated during
971 /// recursive application of this pattern will be minimal.
972 struct UnrollTransferWriteConversion
973     : public VectorToSCFPattern<TransferWriteOp> {
974   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
975 
976   void initialize() {
977     // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
979     setHasBoundedRewriteRecursion();
980   }
981 
  /// Return the vector from which newly generated ExtractOps will extract.
983   Value getDataVector(TransferWriteOp xferOp) const {
984     if (auto extractOp = getExtractOp(xferOp))
985       return extractOp.vector();
986     return xferOp.vector();
987   }
988 
989   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
990   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
991     if (auto *op = xferOp.vector().getDefiningOp())
992       return dyn_cast<vector::ExtractOp>(op);
993     return vector::ExtractOp();
994   }
995 
996   /// If the input of the given TransferWriteOp is an ExtractOp, return its
997   /// indices.
998   void getExtractionIndices(TransferWriteOp xferOp,
999                             SmallVector<int64_t, 8> &indices) const {
1000     if (auto extractOp = getExtractOp(xferOp)) {
1001       llvm::for_each(extractOp.position(), [&](Attribute attr) {
1002         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
1003       });
1004     }
1005   }
1006 
1007   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1008   /// accesses, and broadcasts and transposes in permutation maps.
1009   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1010                                 PatternRewriter &rewriter) const override {
1011     if (xferOp.getVectorType().getRank() <= options.targetRank)
1012       return failure();
1013     if (isTensorOp(xferOp) && !options.lowerTensors)
1014       return failure();
1015     // Transfer ops that modify the element type are not supported atm.
1016     if (xferOp.getVectorType().getElementType() !=
1017         xferOp.getShapedType().getElementType())
1018       return failure();
1019 
1020     auto vec = getDataVector(xferOp);
1021     auto xferVecType = xferOp.getVectorType();
1022     int64_t dimSize = xferVecType.getShape()[0];
1023     auto source = xferOp.source(); // memref or tensor to be written to.
1024     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1025 
1026     // Generate fully unrolled loop of transfer ops.
1027     Location loc = xferOp.getLoc();
1028     for (int64_t i = 0; i < dimSize; ++i) {
1029       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1030 
1031       auto updatedSource = generateInBoundsCheck(
1032           rewriter, xferOp, iv, unpackedDim(xferOp),
1033           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1034           /*inBoundsCase=*/
1035           [&](OpBuilder &b, Location loc) {
1036             // Indices for the new transfer op.
1037             SmallVector<Value, 8> xferIndices;
1038             getXferIndices(b, xferOp, iv, xferIndices);
1039 
1040             // Indices for the new vector.extract op.
1041             SmallVector<int64_t, 8> extractionIndices;
1042             getExtractionIndices(xferOp, extractionIndices);
1043             extractionIndices.push_back(i);
1044 
1045             auto extracted =
1046                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1047             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
1048             auto newXferOp = b.create<vector::TransferWriteOp>(
1049                 loc, sourceType, extracted, source, xferIndices,
1050                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1051                 inBoundsAttr);
1052 
1053             maybeAssignMask(b, xferOp, newXferOp, i);
1054 
1055             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1056           },
1057           /*outOfBoundsCase=*/
1058           [&](OpBuilder &b, Location loc) {
1059             return isTensorOp(xferOp) ? source : Value();
1060           });
1061 
1062       if (isTensorOp(xferOp))
1063         source = updatedSource;
1064     }
1065 
1066     if (isTensorOp(xferOp))
1067       rewriter.replaceOp(xferOp, source);
1068     else
1069       rewriter.eraseOp(xferOp);
1070 
1071     return success();
1072   }
1073 };
1074 
1075 } // namespace lowering_n_d_unrolled
1076 
1077 namespace lowering_1_d {
1078 
1079 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1080 /// part of TransferOp1dConversion. Return the memref dimension on which
1081 /// the transfer is operating. A return value of None indicates a broadcast.
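/// E.g. (illustrative): for indices [%a, %b] and permutation map
/// (d0, d1) -> (d0), the computed indices are [%a + iv, %b] and the returned
/// dimension is 0.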
1082 template <typename OpTy>
1083 static Optional<int64_t>
1084 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1085                    SmallVector<Value, 8> &memrefIndices) {
1086   auto indices = xferOp.indices();
1087   auto map = xferOp.permutation_map();
1088   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1089 
1090   memrefIndices.append(indices.begin(), indices.end());
1091   assert(map.getNumResults() == 1 &&
1092          "Expected 1 permutation map result for 1D transfer");
1093   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1094     Location loc = xferOp.getLoc();
1095     auto dim = expr.getPosition();
1096     AffineExpr d0, d1;
1097     bindDims(xferOp.getContext(), d0, d1);
1098     Value offset = memrefIndices[dim];
1099     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1100     return dim;
1101   }
1102 
1103   assert(xferOp.isBroadcastDim(0) &&
1104          "Expected AffineDimExpr or AffineConstantExpr");
1105   return None;
1106 }
1107 
1108 /// Codegen strategy for TransferOp1dConversion, depending on the
1109 /// operation.
1110 template <typename OpTy>
1111 struct Strategy1d;
1112 
1113 /// Codegen strategy for TransferReadOp.
1114 template <>
1115 struct Strategy1d<TransferReadOp> {
1116   static void generateForLoopBody(OpBuilder &b, Location loc,
1117                                   TransferReadOp xferOp, Value iv,
1118                                   ValueRange loopState) {
1119     SmallVector<Value, 8> indices;
1120     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1121     auto vec = loopState[0];
1122 
1123     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1124     // padding value).
1125     auto nextVec = generateInBoundsCheck(
1126         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1127         /*inBoundsCase=*/
1128         [&](OpBuilder &b, Location loc) {
1129           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1130           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1131         },
1132         /*outOfBoundsCase=*/
1133         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1134     b.create<scf::YieldOp>(loc, nextVec);
1135   }
1136 
1137   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
1139     Location loc = xferOp.getLoc();
1140     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
1141   }
1142 };
1143 
1144 /// Codegen strategy for TransferWriteOp.
1145 template <>
1146 struct Strategy1d<TransferWriteOp> {
1147   static void generateForLoopBody(OpBuilder &b, Location loc,
1148                                   TransferWriteOp xferOp, Value iv,
1149                                   ValueRange /*loopState*/) {
1150     SmallVector<Value, 8> indices;
1151     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1152 
1153     // Nothing to do in case of out-of-bounds access.
1154     generateInBoundsCheck(
1155         b, xferOp, iv, dim,
1156         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1157           auto val =
1158               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
1159           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
1160         });
1161     b.create<scf::YieldOp>(loc);
1162   }
1163 
1164   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1165     return Value();
1166   }
1167 };
1168 
1169 /// Return true if the last dimension of the MemRefType has unit stride.
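/// E.g. (illustrative): a memref<?x?xf32> with the default layout qualifies;
/// a memref whose innermost stride is dynamic or greater than 1 does not.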
1170 static bool isLastMemrefDimUnitStride(MemRefType type) {
1171   int64_t offset;
1172   SmallVector<int64_t, 4> strides;
1173   auto successStrides = getStridesAndOffset(type, strides, offset);
1174   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1175 }
1176 
1177 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1178 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1179 /// vector load/stores due to non-unit strides or broadcasts:
1180 ///
1181 /// * Transfer dimension is not the last memref dimension
1182 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1183 /// * Memref has a layout map with non-unit stride on the last dimension
1184 ///
1185 /// This pattern generates IR as follows:
1186 ///
1187 /// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1189 ///    depending on OpTy.
1190 ///
1191 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1192 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1193 ///       to ConvertVectorToLLVM.
1194 ///
1195 /// E.g.:
1196 /// ```
1197 /// vector.transfer_write %vec, %A[%a, %b]
1198 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1199 ///    : vector<9xf32>, memref<?x?xf32>
1200 /// ```
1201 /// Is rewritten to approximately the following pseudo-IR:
1202 /// ```
1203 /// for i = 0 to 9 {
1204 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1205 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1206 /// }
1207 /// ```
1208 template <typename OpTy>
1209 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1210   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1211 
1212   LogicalResult matchAndRewrite(OpTy xferOp,
1213                                 PatternRewriter &rewriter) const override {
1214     // TODO: support 0-d corner case.
1215     if (xferOp.getTransferRank() == 0)
1216       return failure();
1217     auto map = xferOp.permutation_map();
1218     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1219 
1220     if (!memRefType)
1221       return failure();
1222     if (xferOp.getVectorType().getRank() != 1)
1223       return failure();
1224     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1225       return failure(); // Handled by ConvertVectorToLLVM
1226 
1227     // Loop bounds, step, state...
1228     Location loc = xferOp.getLoc();
1229     auto vecType = xferOp.getVectorType();
1230     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1231     auto ub =
1232         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1233     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1234     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1235 
1236     // Generate for loop.
1237     rewriter.replaceOpWithNewOp<scf::ForOp>(
1238         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1239         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1240           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1241         });
1242 
1243     return success();
1244   }
1245 };
1246 
1247 } // namespace lowering_1_d
1248 } // namespace
1249 
1250 namespace mlir {
1251 
1252 void populateVectorToSCFConversionPatterns(
1253     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1254   if (options.unroll) {
1255     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1256                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1257         patterns.getContext(), options);
1258   } else {
1259     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1260                  lowering_n_d::PrepareTransferWriteConversion,
1261                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1262                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1263         patterns.getContext(), options);
1264   }
1265 
1266   if (options.targetRank == 1) {
1267     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1268                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1269         patterns.getContext(), options);
1270   }
1271 }
1272 
1273 } // namespace mlir
1274 
1275 namespace {
1276 
1277 struct ConvertVectorToSCFPass
1278     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1279   ConvertVectorToSCFPass() = default;
1280   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1281     this->fullUnroll = options.unroll;
1282     this->targetRank = options.targetRank;
1283     this->lowerPermutationMaps = options.lowerPermutationMaps;
1284     this->lowerTensors = options.lowerTensors;
1285   }
1286 
1287   void runOnFunction() override {
1288     VectorTransferToSCFOptions options;
1289     options.unroll = fullUnroll;
1290     options.targetRank = targetRank;
1291     options.lowerPermutationMaps = lowerPermutationMaps;
1292     options.lowerTensors = lowerTensors;
1293 
1294     // Lower permutation maps first.
1295     if (lowerPermutationMaps) {
1296       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1297       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1298           lowerTransferPatterns);
1299       (void)applyPatternsAndFoldGreedily(getFunction(),
1300                                          std::move(lowerTransferPatterns));
1301     }
1302 
1303     RewritePatternSet patterns(getFunction().getContext());
1304     populateVectorToSCFConversionPatterns(patterns, options);
1305     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
1306   }
1307 };
1308 
1309 } // namespace
1310 
1311 std::unique_ptr<Pass>
1312 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1313   return std::make_unique<ConvertVectorToSCFPass>(options);
1314 }
1315