//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
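///
/// E.g. (illustrative): for a permutation map (d0, d1, d2) -> (d2, d1), the
/// unpacked dim is 2. If the first map result is the constant 0 (a broadcast),
/// None is returned.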
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
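///
/// E.g. (illustrative): (d0, d1, d2) -> (d1, d0) becomes (d0, d1, d2) -> (d0).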
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///              `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

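/// Generate an scf.yield that yields `value` if `hasRetVal` is true, and an
/// empty scf.yield otherwise.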
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
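///
/// Otherwise, the check is generated roughly as follows (sketch; a 1-D mask
/// of type vector<5xi1> is assumed):
/// ```
/// %i = index_cast %iv : index to i32
/// %cond = vector.extractelement %mask[%i : i32] : vector<5xi1>
/// ```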
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 =
      b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<AndOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
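/// E.g. (illustrative), an `in_bounds` array `[false, true, true]` becomes
/// `[true, true]`.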
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
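///
/// A sketch of the buffers created for a masked transfer op (types are
/// illustrative):
/// ```
/// %data = memref.alloca() : memref<vector<5x4xf32>>
/// %maskAlloca = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %maskAlloca[] : memref<vector<5xi1>>
/// %maskBuffer = memref.load %maskAlloca[] : memref<vector<5xi1>>
/// ```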
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
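  ///
  /// E.g. (sketch, mirroring the TransferReadOp example above):
  /// ```
  /// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///     : vector<3xf32>, memref<?x?x?xf32>
  /// ```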
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

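/// Check whether a transfer op is eligible for preparation/progressive
/// lowering: it must not already carry the pass label, its vector rank must
/// exceed the target rank, tensor sources are only handled if `lowerTensors`
/// is set, and the element type must not change between the vector and the
/// source.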
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<ConstantIndexOp>(0);
    auto ub = locB.create<ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
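///
/// E.g. (illustrative): for a mask of type vector<5x4xi1> and i = 2, the new
/// transfer op receives the mask `vector.extract %mask[2] : vector<5x4xi1>`,
/// which has type vector<4xi1>.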
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
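///
/// E.g. (illustrative): for indices [%a, %b] and a permutation map
/// (d0, d1) -> (d0), the computed memref indices are [%a + iv, %b] and the
/// returned dimension is 0.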
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
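/// E.g. (illustrative), a contiguous memref<10x20xf32> has unit stride in its
/// last dimension, while a memref whose strided layout gives the last
/// dimension a stride of 2 does not.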
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a pattern
///       to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
    auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

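// A typical use of this helper (sketch, mirroring ConvertVectorToSCFPass
// below): populate the patterns and apply them greedily to a function.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   populateVectorToSCFConversionPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));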
void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}