1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/Dialect/Vector/VectorUtils.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 using vector::TransferReadOp;
31 using vector::TransferWriteOp;
32 
33 namespace {
34 
35 /// Attribute name used for labeling transfer ops during progressive lowering.
36 static const char kPassLabel[] = "__vector_to_scf_lowering__";
37 
38 /// Patterns that inherit from this struct have access to
39 /// VectorTransferToSCFOptions.
40 template <typename OpTy>
41 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
42   explicit VectorToSCFPattern(MLIRContext *context,
43                               VectorTransferToSCFOptions opt)
44       : OpRewritePattern<OpTy>(context), options(opt) {}
45 
46   VectorTransferToSCFOptions options;
47 };
48 
49 /// Given a vector transfer op, calculate which dimension of the `source`
50 /// memref should be unpacked in the next application of TransferOpConversion.
51 /// A return value of None indicates a broadcast.
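///
/// E.g. (illustrative): for a transfer op with permutation map
/// (d0, d1, d2) -> (d1, d2), the first map result is d1, so memref dimension
/// 1 is unpacked next. If the first result were the constant 0 (a broadcast),
/// None would be returned.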
52 template <typename OpTy>
53 static Optional<int64_t> unpackedDim(OpTy xferOp) {
54   auto map = xferOp.permutation_map();
55   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
56     return expr.getPosition();
57   }
58   assert(xferOp.isBroadcastDim(0) &&
59          "Expected AffineDimExpr or AffineConstantExpr");
60   return None;
61 }
62 
63 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
64 /// map is identical to the current permutation map, but the first result is
65 /// omitted.
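///
/// E.g. (illustrative): (d0, d1, d2) -> (d1, d2) becomes
/// (d0, d1, d2) -> (d2); the number of dims is unchanged.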
66 template <typename OpTy>
67 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
68   auto map = xferOp.permutation_map();
69   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
70                         b.getContext());
71 }
72 
73 /// Calculate the indices for the new vector transfer op.
74 ///
75 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
77 ///                                 ^^^^^^
78 ///              `iv` is the iteration variable of the (new) surrounding loop.
79 template <typename OpTy>
80 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
81                            SmallVector<Value, 8> &indices) {
82   typename OpTy::Adaptor adaptor(xferOp);
83   // Corresponding memref dim of the vector dim that is unpacked.
84   auto dim = unpackedDim(xferOp);
85   auto prevIndices = adaptor.indices();
86   indices.append(prevIndices.begin(), prevIndices.end());
87 
88   Location loc = xferOp.getLoc();
89   bool isBroadcast = !dim.hasValue();
90   if (!isBroadcast) {
91     AffineExpr d0, d1;
92     bindDims(xferOp.getContext(), d0, d1);
93     Value offset = adaptor.indices()[dim.getValue()];
94     indices[dim.getValue()] =
95         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
96   }
97 }
98 
99 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
100                             Value value) {
101   if (hasRetVal) {
102     assert(value && "Expected non-empty value");
103     b.create<scf::YieldOp>(loc, value);
104   } else {
105     b.create<scf::YieldOp>(loc);
106   }
107 }
108 
/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
111 /// * xferOp does not have a mask.
112 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
113 ///   computed and attached to the new transfer op in the pattern.)
114 /// * The to-be-unpacked dim of xferOp is a broadcast.
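///
/// Otherwise, IR roughly equivalent to the following sketch is generated
/// (assuming a 1D mask of type vector<4xi1>):
/// ```
/// %ivI32 = index_cast %iv : index to i32
/// %cond = vector.extractelement %mask[%ivI32 : i32] : vector<4xi1>
/// ```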
115 template <typename OpTy>
116 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
117   if (!xferOp.mask())
118     return Value();
119   if (xferOp.getMaskType().getRank() != 1)
120     return Value();
121   if (xferOp.isBroadcastDim(0))
122     return Value();
123 
124   Location loc = xferOp.getLoc();
125   Value ivI32 =
126       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
127   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
128 }
129 
/// Helper function for TransferOpConversion and TransferOp1dConversion.
131 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
132 /// specified dimension `dim` with the loop iteration variable `iv`.
133 /// E.g., when unpacking dimension 0 from:
134 /// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
137 /// ```
138 /// An if check similar to this will be generated inside the loop:
139 /// ```
140 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
141 /// if (%a + iv < %d) {
142 ///   (in-bounds case)
143 /// } else {
144 ///   (out-of-bounds case)
145 /// }
146 /// ```
147 ///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
150 ///
151 /// This function variant returns the value returned by `inBoundsCase` or
152 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
153 /// `resultTypes`.
154 template <typename OpTy>
155 static Value generateInBoundsCheck(
156     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
157     TypeRange resultTypes,
158     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
159     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
160   bool hasRetVal = !resultTypes.empty();
161   Value cond; // Condition to be built...
162 
163   // Condition check 1: Access in-bounds?
164   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
165   Location loc = xferOp.getLoc();
166   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
167   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
168     Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
169     AffineExpr d0, d1;
170     bindDims(xferOp.getContext(), d0, d1);
171     Value base = xferOp.indices()[dim.getValue()];
172     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
173     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
174   }
175 
176   // Condition check 2: Masked in?
177   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
178     if (cond)
179       cond = lb.create<AndOp>(cond, maskCond);
180     else
181       cond = maskCond;
182   }
183 
184   // If the condition is non-empty, generate an SCF::IfOp.
185   if (cond) {
186     auto check = lb.create<scf::IfOp>(
187         resultTypes, cond,
188         /*thenBuilder=*/
189         [&](OpBuilder &b, Location loc) {
190           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
191         },
192         /*elseBuilder=*/
193         [&](OpBuilder &b, Location loc) {
194           if (outOfBoundsCase) {
195             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
196           } else {
197             b.create<scf::YieldOp>(loc);
198           }
199         });
200 
201     return hasRetVal ? check.getResult(0) : Value();
202   }
203 
204   // Condition is empty, no need for an SCF::IfOp.
205   return inBoundsCase(b, loc);
206 }
207 
208 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
209 /// a return value. Consequently, this function does not have a return value.
210 template <typename OpTy>
211 static void generateInBoundsCheck(
212     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
213     function_ref<void(OpBuilder &, Location)> inBoundsCase,
214     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
215   generateInBoundsCheck(
216       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
217       /*inBoundsCase=*/
218       [&](OpBuilder &b, Location loc) {
219         inBoundsCase(b, loc);
220         return Value();
221       },
222       /*outOfBoundsCase=*/
223       [&](OpBuilder &b, Location loc) {
224         if (outOfBoundsCase)
225           outOfBoundsCase(b, loc);
226         return Value();
227       });
228 }
229 
230 /// Given an ArrayAttr, return a copy where the first element is dropped.
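/// E.g. (illustrative): an in_bounds attribute [true, false] becomes [false].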
231 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
232   if (!attr)
233     return attr;
234   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
235 }
236 
/// Add the pass label to a vector transfer op if its vector rank is still
/// greater than the target rank.
239 template <typename OpTy>
240 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
241                                 unsigned targetRank) {
242   if (newXferOp.getVectorType().getRank() > targetRank)
243     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
244 }
245 
246 /// Return true if this transfer op operates on a source tensor.
247 template <typename OpTy>
248 static bool isTensorOp(OpTy xferOp) {
249   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
250     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
251       // TransferWriteOps on tensors have a result.
252       assert(xferOp->getNumResults() > 0);
253     }
254     return true;
255   }
256   return false;
257 }
258 
259 namespace lowering_n_d {
260 
261 /// Helper data structure for data and mask buffers.
262 struct BufferAllocs {
263   Value dataBuffer;
264   Value maskBuffer;
265 };
266 
267 /// Allocate temporary buffers for data (vector) and mask (if present).
268 /// TODO: Parallelism and threadlocal considerations.
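///
/// E.g. (sketch), for a transfer op with vector type vector<4x3xf32> and mask
/// type vector<4xi1>, IR similar to the following is generated (the allocas
/// are hoisted to the start of the enclosing allocation scope):
/// ```
/// %dataBuf = memref.alloca() : memref<vector<4x3xf32>>
/// %maskBuf = memref.alloca() : memref<vector<4xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<4xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<4xi1>>
/// ```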
269 template <typename OpTy>
270 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
271   Location loc = xferOp.getLoc();
272   OpBuilder::InsertionGuard guard(b);
273   Operation *scope =
274       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
275   assert(scope && "Expected op to be inside automatic allocation scope");
276   b.setInsertionPointToStart(&scope->getRegion(0).front());
277 
278   BufferAllocs result;
279   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
280   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
281 
282   if (xferOp.mask()) {
283     auto maskType = MemRefType::get({}, xferOp.mask().getType());
284     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
285     b.setInsertionPoint(xferOp);
286     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
287     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
288   }
289 
290   return result;
291 }
292 
293 /// Given a MemRefType with VectorType element type, unpack one dimension from
294 /// the VectorType into the MemRefType.
295 ///
296 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
297 static MemRefType unpackOneDim(MemRefType type) {
298   auto vectorType = type.getElementType().dyn_cast<VectorType>();
299   auto memrefShape = type.getShape();
300   SmallVector<int64_t, 8> newMemrefShape;
301   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
302   newMemrefShape.push_back(vectorType.getDimSize(0));
303   return MemRefType::get(newMemrefShape,
304                          VectorType::get(vectorType.getShape().drop_front(),
305                                          vectorType.getElementType()));
306 }
307 
308 /// Given a transfer op, find the memref from which the mask is loaded. This
309 /// is similar to Strategy<TransferWriteOp>::getBuffer.
310 template <typename OpTy>
311 static Value getMaskBuffer(OpTy xferOp) {
312   assert(xferOp.mask() && "Expected that transfer op has mask");
313   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
314   assert(loadOp && "Expected transfer op mask produced by LoadOp");
315   return loadOp.getMemRef();
316 }
317 
318 /// Codegen strategy, depending on the operation.
319 template <typename OpTy>
320 struct Strategy;
321 
/// Codegen strategy for vector TransferReadOp.
323 template <>
324 struct Strategy<TransferReadOp> {
325   /// Find the StoreOp that is used for writing the current TransferReadOp's
326   /// result to the temporary buffer allocation.
327   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
328     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
329     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
330     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
331     return storeOp;
332   }
333 
334   /// Find the temporary buffer allocation. All labeled TransferReadOps are
335   /// used like this, where %buf is either the buffer allocation or a type cast
336   /// of the buffer allocation:
337   /// ```
338   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
339   /// memref.store %vec, %buf[...] ...
340   /// ```
341   static Value getBuffer(TransferReadOp xferOp) {
342     return getStoreOp(xferOp).getMemRef();
343   }
344 
345   /// Retrieve the indices of the current StoreOp that stores into the buffer.
346   static void getBufferIndices(TransferReadOp xferOp,
347                                SmallVector<Value, 8> &indices) {
348     auto storeOp = getStoreOp(xferOp);
349     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
350     indices.append(prevIndices.begin(), prevIndices.end());
351   }
352 
353   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
354   /// accesses on the to-be-unpacked dimension.
355   ///
356   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
357   ///    variable `iv`.
358   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
359   ///
360   /// E.g.:
361   /// ```
362   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
363   ///     : memref<?x?x?xf32>, vector<4x3xf32>
364   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
365   /// ```
366   /// Is rewritten to:
367   /// ```
368   /// %casted = vector.type_cast %buf
369   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
370   /// for %j = 0 to 4 {
371   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
372   ///       : memref<?x?x?xf32>, vector<3xf32>
373   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
374   /// }
375   /// ```
376   ///
377   /// Note: The loop and type cast are generated in TransferOpConversion.
378   ///       The original TransferReadOp and store op are deleted in `cleanup`.
379   /// Note: The `mask` operand is set in TransferOpConversion.
380   static TransferReadOp rewriteOp(OpBuilder &b,
381                                   VectorTransferToSCFOptions options,
382                                   TransferReadOp xferOp, Value buffer, Value iv,
383                                   ValueRange /*loopState*/) {
384     SmallVector<Value, 8> storeIndices;
385     getBufferIndices(xferOp, storeIndices);
386     storeIndices.push_back(iv);
387 
388     SmallVector<Value, 8> xferIndices;
389     getXferIndices(b, xferOp, iv, xferIndices);
390 
391     Location loc = xferOp.getLoc();
392     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
393     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
394     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
395     auto newXferOp = b.create<vector::TransferReadOp>(
396         loc, vecType, xferOp.source(), xferIndices,
397         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
398         Value(), inBoundsAttr);
399 
400     maybeApplyPassLabel(b, newXferOp, options.targetRank);
401 
402     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
403     return newXferOp;
404   }
405 
406   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
407   /// padding value to the temporary buffer.
408   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
409                                     Value buffer, Value iv,
410                                     ValueRange /*loopState*/) {
411     SmallVector<Value, 8> storeIndices;
412     getBufferIndices(xferOp, storeIndices);
413     storeIndices.push_back(iv);
414 
415     Location loc = xferOp.getLoc();
416     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
417     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
418     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
419     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
420 
421     return Value();
422   }
423 
424   /// Cleanup after rewriting the op.
425   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
426                       scf::ForOp /*forOp*/) {
427     rewriter.eraseOp(getStoreOp(xferOp));
428     rewriter.eraseOp(xferOp);
429   }
430 
431   /// Return the initial loop state for the generated scf.for loop.
432   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
433 };
434 
435 /// Codegen strategy for vector TransferWriteOp.
436 template <>
437 struct Strategy<TransferWriteOp> {
438   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
439   /// used like this, where %buf is either the buffer allocation or a type cast
440   /// of the buffer allocation:
441   /// ```
442   /// %vec = memref.load %buf[...] ...
443   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
444   /// ```
445   static Value getBuffer(TransferWriteOp xferOp) {
446     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
447     assert(loadOp && "Expected transfer op vector produced by LoadOp");
448     return loadOp.getMemRef();
449   }
450 
451   /// Retrieve the indices of the current LoadOp that loads from the buffer.
452   static void getBufferIndices(TransferWriteOp xferOp,
453                                SmallVector<Value, 8> &indices) {
454     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
455     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
456     indices.append(prevIndices.begin(), prevIndices.end());
457   }
458 
459   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
460   /// accesses on the to-be-unpacked dimension.
461   ///
462   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
463   ///    using the loop iteration variable `iv`.
464   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
465   ///    to memory.
466   ///
467   /// Note: For more details, see comments on Strategy<TransferReadOp>.
468   static TransferWriteOp rewriteOp(OpBuilder &b,
469                                    VectorTransferToSCFOptions options,
470                                    TransferWriteOp xferOp, Value buffer,
471                                    Value iv, ValueRange loopState) {
472     SmallVector<Value, 8> loadIndices;
473     getBufferIndices(xferOp, loadIndices);
474     loadIndices.push_back(iv);
475 
476     SmallVector<Value, 8> xferIndices;
477     getXferIndices(b, xferOp, iv, xferIndices);
478 
479     Location loc = xferOp.getLoc();
480     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
481     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
482     auto source = loopState.empty() ? xferOp.source() : loopState[0];
483     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
484     auto newXferOp = b.create<vector::TransferWriteOp>(
485         loc, type, vec, source, xferIndices,
486         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
487         inBoundsAttr);
488 
489     maybeApplyPassLabel(b, newXferOp, options.targetRank);
490 
491     return newXferOp;
492   }
493 
494   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
495   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
496                                     Value buffer, Value iv,
497                                     ValueRange loopState) {
498     return isTensorOp(xferOp) ? loopState[0] : Value();
499   }
500 
501   /// Cleanup after rewriting the op.
502   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
503                       scf::ForOp forOp) {
504     if (isTensorOp(xferOp)) {
505       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
506       rewriter.replaceOp(xferOp, forOp->getResult(0));
507     } else {
508       rewriter.eraseOp(xferOp);
509     }
510   }
511 
512   /// Return the initial loop state for the generated scf.for loop.
513   static Value initialLoopState(TransferWriteOp xferOp) {
514     return isTensorOp(xferOp) ? xferOp.source() : Value();
515   }
516 };
517 
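/// Check whether the given transfer op should be prepared for progressive
/// lowering: it must not carry the pass label yet, its vector rank must
/// exceed the target rank, tensor ops are handled only if `lowerTensors` is
/// set, and transfers that change the element type are not supported.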
518 template <typename OpTy>
519 LogicalResult checkPrepareXferOp(OpTy xferOp,
520                                  VectorTransferToSCFOptions options) {
521   if (xferOp->hasAttr(kPassLabel))
522     return failure();
523   if (xferOp.getVectorType().getRank() <= options.targetRank)
524     return failure();
525   if (isTensorOp(xferOp) && !options.lowerTensors)
526     return failure();
527   // Transfer ops that modify the element type are not supported atm.
528   if (xferOp.getVectorType().getElementType() !=
529       xferOp.getShapedType().getElementType())
530     return failure();
531   return success();
532 }
533 
534 /// Prepare a TransferReadOp for progressive lowering.
535 ///
536 /// 1. Allocate a temporary buffer.
537 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
538 /// 3. Store the result of the TransferReadOp into the temporary buffer.
539 /// 4. Load the result from the temporary buffer and replace all uses of the
540 ///    original TransferReadOp with this load.
541 ///
542 /// E.g.:
543 /// ```
544 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
546 /// ```
547 /// is rewritten to:
548 /// ```
549 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
550 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
552 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
553 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
554 /// ```
555 ///
556 /// Note: A second temporary buffer may be allocated for the `mask` operand.
557 struct PrepareTransferReadConversion
558     : public VectorToSCFPattern<TransferReadOp> {
559   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
560 
561   LogicalResult matchAndRewrite(TransferReadOp xferOp,
562                                 PatternRewriter &rewriter) const override {
563     if (checkPrepareXferOp(xferOp, options).failed())
564       return failure();
565 
566     auto buffers = allocBuffers(rewriter, xferOp);
567     auto *newXfer = rewriter.clone(*xferOp.getOperation());
568     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
569     if (xferOp.mask()) {
570       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
571           buffers.maskBuffer);
572     }
573 
574     Location loc = xferOp.getLoc();
575     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
576                                      buffers.dataBuffer);
577     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
578 
579     return success();
580   }
581 };
582 
583 /// Prepare a TransferWriteOp for progressive lowering.
584 ///
585 /// 1. Allocate a temporary buffer.
586 /// 2. Store the vector into the buffer.
587 /// 3. Load the vector from the buffer again.
588 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
589 ///    marking it eligible for progressive lowering via TransferOpConversion.
590 ///
591 /// E.g.:
592 /// ```
593 /// vector.transfer_write %vec, %A[%a, %b, %c]
594 ///     : vector<5x4xf32>, memref<?x?x?xf32>
595 /// ```
596 /// is rewritten to:
597 /// ```
598 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
599 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
600 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
601 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
602 ///     : vector<5x4xf32>, memref<?x?x?xf32>
603 /// ```
604 ///
605 /// Note: A second temporary buffer may be allocated for the `mask` operand.
606 struct PrepareTransferWriteConversion
607     : public VectorToSCFPattern<TransferWriteOp> {
608   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
609 
610   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
611                                 PatternRewriter &rewriter) const override {
612     if (checkPrepareXferOp(xferOp, options).failed())
613       return failure();
614 
615     Location loc = xferOp.getLoc();
616     auto buffers = allocBuffers(rewriter, xferOp);
617     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
618     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
619     rewriter.updateRootInPlace(xferOp, [&]() {
620       xferOp.vectorMutable().assign(loadedVec);
621       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
622     });
623 
624     if (xferOp.mask()) {
625       rewriter.updateRootInPlace(
626           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
627     }
628 
629     return success();
630   }
631 };
632 
633 /// Progressive lowering of vector transfer ops: Unpack one dimension.
634 ///
635 /// 1. Unpack one dimension from the current buffer type and cast the buffer
636 ///    to that new type. E.g.:
637 ///    ```
638 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
639 ///    vector.transfer_write %vec ...
640 ///    ```
641 ///    The following cast is generated:
642 ///    ```
643 ///    %casted = vector.type_cast %0
644 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
645 ///    ```
646 /// 2. Generate a for loop and rewrite the transfer op according to the
647 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
648 ///    out-of-bounds, generate an if-check and handle both cases separately.
649 /// 3. Clean up according to the corresponding Strategy<OpTy>.
650 ///
651 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
652 /// source (as opposed to a memref source), then each iteration of the generated
653 /// scf.for loop yields the new tensor value. E.g.:
654 /// ```
655 /// %result = scf.for i = 0 to 5 {
656 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
657 ///   %1 = vector.transfer_write %0, %source[...]
658 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
659 ///   scf.yield %1 : tensor<5x4x3xf32>
660 /// }
661 /// ```
662 template <typename OpTy>
663 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
664   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
665 
666   void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
669     this->setHasBoundedRewriteRecursion();
670   }
671 
672   LogicalResult matchAndRewrite(OpTy xferOp,
673                                 PatternRewriter &rewriter) const override {
674     if (!xferOp->hasAttr(kPassLabel))
675       return failure();
676 
677     // Find and cast data buffer. How the buffer can be found depends on OpTy.
678     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
679     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
680     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
681     auto castedDataType = unpackOneDim(dataBufferType);
682     auto castedDataBuffer =
683         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
684 
685     // If the xferOp has a mask: Find and cast mask buffer.
686     Value castedMaskBuffer;
687     if (xferOp.mask()) {
688       auto maskBuffer = getMaskBuffer(xferOp);
689       auto maskBufferType =
690           maskBuffer.getType().template dyn_cast<MemRefType>();
691       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
692         // Do not unpack a dimension of the mask, if:
693         // * To-be-unpacked transfer op dimension is a broadcast.
694         // * Mask is 1D, i.e., the mask cannot be further unpacked.
695         //   (That means that all remaining dimensions of the transfer op must
696         //   be broadcasted.)
697         castedMaskBuffer = maskBuffer;
698       } else {
699         auto castedMaskType = unpackOneDim(maskBufferType);
700         castedMaskBuffer =
701             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
702       }
703     }
704 
705     // Loop bounds and step.
706     auto lb = locB.create<ConstantIndexOp>(0);
707     auto ub = locB.create<ConstantIndexOp>(
708         castedDataType.getDimSize(castedDataType.getRank() - 1));
709     auto step = locB.create<ConstantIndexOp>(1);
710     // TransferWriteOps that operate on tensors return the modified tensor and
711     // require a loop state.
712     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
713 
714     // Generate for loop.
715     auto result = locB.create<scf::ForOp>(
716         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
717         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
718           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
719 
720           auto result = generateInBoundsCheck(
721               b, xferOp, iv, unpackedDim(xferOp),
722               stateType ? TypeRange(stateType) : TypeRange(),
723               /*inBoundsCase=*/
724               [&](OpBuilder &b, Location loc) {
725                 // Create new transfer op.
726                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
727                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
728 
729                 // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
734                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
735                                       xferOp.getMaskType().getRank() > 1)) {
736                   OpBuilder::InsertionGuard guard(b);
737                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
738 
739                   SmallVector<Value, 8> loadIndices;
740                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
741                   // In case of broadcast: Use same indices to load from memref
742                   // as before.
743                   if (!xferOp.isBroadcastDim(0))
744                     loadIndices.push_back(iv);
745 
746                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
747                                                        loadIndices);
748                   rewriter.updateRootInPlace(
749                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
750                 }
751 
752                 return loopState.empty() ? Value() : newXfer->getResult(0);
753               },
754               /*outOfBoundsCase=*/
755               [&](OpBuilder &b, Location /*loc*/) {
756                 return Strategy<OpTy>::handleOutOfBoundsDim(
757                     b, xferOp, castedDataBuffer, iv, loopState);
758               });
759 
760           maybeYieldValue(b, loc, !loopState.empty(), result);
761         });
762 
763     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
764     return success();
765   }
766 };
767 
768 } // namespace lowering_n_d
769 
770 namespace lowering_n_d_unrolled {
771 
772 /// If the original transfer op has a mask, compute the mask of the new transfer
773 /// op (for the current iteration `i`) and assign it.
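///
/// E.g. (illustrative): for a mask of type vector<5x4xi1> and `i` = 2, the
/// new transfer op is assigned the mask
/// `vector.extract %mask[2] : vector<5x4xi1>`, which has type vector<4xi1>.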
774 template <typename OpTy>
775 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
776                             int64_t i) {
777   if (!xferOp.mask())
778     return;
779 
780   if (xferOp.isBroadcastDim(0)) {
781     // To-be-unpacked dimension is a broadcast, which does not have a
782     // corresponding mask dimension. Mask attribute remains unchanged.
783     newXferOp.maskMutable().assign(xferOp.mask());
784     return;
785   }
786 
787   if (xferOp.getMaskType().getRank() > 1) {
788     // Unpack one dimension of the mask.
789     OpBuilder::InsertionGuard guard(b);
790     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
791 
792     llvm::SmallVector<int64_t, 1> indices({i});
793     Location loc = xferOp.getLoc();
794     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
795     newXferOp.maskMutable().assign(newMask);
796   }
797 
798   // If we end up here: The mask of the old transfer op is 1D and the unpacked
799   // dim is not a broadcast, so no mask is needed on the new transfer op.
800   // `generateInBoundsCheck` will have evaluated the mask already.
801 }
802 
803 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
804 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
805 /// memref buffer is allocated and the SCF loop is fully unrolled.
806 ///
/// E.g.:
/// ```
810 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
811 ///     : memref<?x?x?xf32>, vector<5x4xf32>
812 /// ```
813 /// is rewritten to IR such as (simplified):
814 /// ```
815 /// %v_init = splat %padding : vector<5x4xf32>
816 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
817 ///     : memref<?x?x?xf32>, vector<4xf32>
818 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
819 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
820 ///     : memref<?x?x?xf32>, vector<4xf32>
821 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
822 /// ...
823 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
824 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
826 /// ```
827 ///
828 /// Note: As an optimization, if the result of the original TransferReadOp
829 /// was directly inserted into another vector, no new %v_init vector is created.
830 /// Instead, the new TransferReadOp results are inserted into that vector.
831 struct UnrollTransferReadConversion
832     : public VectorToSCFPattern<TransferReadOp> {
833   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
834 
835   void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
838     setHasBoundedRewriteRecursion();
839   }
840 
841   /// Return the vector into which the newly created TransferReadOp results
842   /// are inserted.
843   Value getResultVector(TransferReadOp xferOp,
844                         PatternRewriter &rewriter) const {
845     if (auto insertOp = getInsertOp(xferOp))
846       return insertOp.dest();
847     Location loc = xferOp.getLoc();
848     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
849                                     xferOp.padding());
850   }
851 
852   /// If the result of the TransferReadOp has exactly one user, which is a
853   /// vector::InsertOp, return that operation.
854   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
855     if (xferOp->hasOneUse()) {
856       Operation *xferOpUser = *xferOp->getUsers().begin();
857       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
858         return insertOp;
859     }
860 
861     return vector::InsertOp();
862   }
863 
864   /// If the result of the TransferReadOp has exactly one user, which is a
865   /// vector::InsertOp, return that operation's indices.
866   void getInsertionIndices(TransferReadOp xferOp,
867                            SmallVector<int64_t, 8> &indices) const {
868     if (auto insertOp = getInsertOp(xferOp)) {
869       llvm::for_each(insertOp.position(), [&](Attribute attr) {
870         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
871       });
872     }
873   }
874 
875   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
876   /// accesses, and broadcasts and transposes in permutation maps.
877   LogicalResult matchAndRewrite(TransferReadOp xferOp,
878                                 PatternRewriter &rewriter) const override {
879     if (xferOp.getVectorType().getRank() <= options.targetRank)
880       return failure();
881     if (isTensorOp(xferOp) && !options.lowerTensors)
882       return failure();
883     // Transfer ops that modify the element type are not supported atm.
884     if (xferOp.getVectorType().getElementType() !=
885         xferOp.getShapedType().getElementType())
886       return failure();
887 
888     auto insertOp = getInsertOp(xferOp);
889     auto vec = getResultVector(xferOp, rewriter);
890     auto vecType = vec.getType().dyn_cast<VectorType>();
891     auto xferVecType = xferOp.getVectorType();
892     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
893                                           xferVecType.getElementType());
894     int64_t dimSize = xferVecType.getShape()[0];
895 
896     // Generate fully unrolled loop of transfer ops.
897     Location loc = xferOp.getLoc();
898     for (int64_t i = 0; i < dimSize; ++i) {
899       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
900 
901       vec = generateInBoundsCheck(
902           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
903           /*inBoundsCase=*/
904           [&](OpBuilder &b, Location loc) {
905             // Indices for the new transfer op.
906             SmallVector<Value, 8> xferIndices;
907             getXferIndices(b, xferOp, iv, xferIndices);
908 
909             // Indices for the new vector.insert op.
910             SmallVector<int64_t, 8> insertionIndices;
911             getInsertionIndices(xferOp, insertionIndices);
912             insertionIndices.push_back(i);
913 
914             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
915             auto newXferOp = b.create<vector::TransferReadOp>(
916                 loc, newXferVecType, xferOp.source(), xferIndices,
917                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
918                 xferOp.padding(), Value(), inBoundsAttr);
919             maybeAssignMask(b, xferOp, newXferOp, i);
920             return b.create<vector::InsertOp>(loc, newXferOp, vec,
921                                               insertionIndices);
922           },
923           /*outOfBoundsCase=*/
924           [&](OpBuilder &b, Location loc) {
925             // Loop through original (unmodified) vector.
926             return vec;
927           });
928     }
929 
930     if (insertOp) {
931       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
932       rewriter.replaceOp(insertOp, vec);
933       rewriter.eraseOp(xferOp);
934     } else {
935       rewriter.replaceOp(xferOp, vec);
936     }
937 
938     return success();
939   }
940 };
941 
942 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
943 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
944 /// memref buffer is allocated and the SCF loop is fully unrolled.
945 ///
/// E.g.:
/// ```
949 /// vector.transfer_write %vec, %A[%a, %b, %c]
950 ///     : vector<5x4xf32>, memref<?x?x?xf32>
951 /// ```
952 /// is rewritten to IR such as (simplified):
953 /// ```
954 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
955 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
956 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
957 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
958 /// ...
959 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
960 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
961 /// ```
962 ///
963 /// Note: As an optimization, if the vector of the original TransferWriteOp
964 /// was directly extracted from another vector via an ExtractOp `a`, extract
965 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
966 /// doing so, `a` may become dead, and the number of ExtractOps generated during
967 /// recursive application of this pattern will be minimal.
968 struct UnrollTransferWriteConversion
969     : public VectorToSCFPattern<TransferWriteOp> {
970   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
971 
972   void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
975     setHasBoundedRewriteRecursion();
976   }
977 
  /// Return the vector from which newly generated ExtractOps will extract.
979   Value getDataVector(TransferWriteOp xferOp) const {
980     if (auto extractOp = getExtractOp(xferOp))
981       return extractOp.vector();
982     return xferOp.vector();
983   }
984 
985   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
986   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
987     if (auto *op = xferOp.vector().getDefiningOp())
988       return dyn_cast<vector::ExtractOp>(op);
989     return vector::ExtractOp();
990   }
991 
992   /// If the input of the given TransferWriteOp is an ExtractOp, return its
993   /// indices.
994   void getExtractionIndices(TransferWriteOp xferOp,
995                             SmallVector<int64_t, 8> &indices) const {
996     if (auto extractOp = getExtractOp(xferOp)) {
997       llvm::for_each(extractOp.position(), [&](Attribute attr) {
998         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
999       });
1000     }
1001   }
1002 
1003   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1004   /// accesses, and broadcasts and transposes in permutation maps.
1005   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1006                                 PatternRewriter &rewriter) const override {
1007     if (xferOp.getVectorType().getRank() <= options.targetRank)
1008       return failure();
1009     if (isTensorOp(xferOp) && !options.lowerTensors)
1010       return failure();
1011     // Transfer ops that modify the element type are not supported atm.
1012     if (xferOp.getVectorType().getElementType() !=
1013         xferOp.getShapedType().getElementType())
1014       return failure();
1015 
1016     auto vec = getDataVector(xferOp);
1017     auto xferVecType = xferOp.getVectorType();
1018     int64_t dimSize = xferVecType.getShape()[0];
1019     auto source = xferOp.source(); // memref or tensor to be written to.
1020     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1021 
1022     // Generate fully unrolled loop of transfer ops.
1023     Location loc = xferOp.getLoc();
1024     for (int64_t i = 0; i < dimSize; ++i) {
1025       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
1026 
1027       auto updatedSource = generateInBoundsCheck(
1028           rewriter, xferOp, iv, unpackedDim(xferOp),
1029           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1030           /*inBoundsCase=*/
1031           [&](OpBuilder &b, Location loc) {
1032             // Indices for the new transfer op.
1033             SmallVector<Value, 8> xferIndices;
1034             getXferIndices(b, xferOp, iv, xferIndices);
1035 
1036             // Indices for the new vector.extract op.
1037             SmallVector<int64_t, 8> extractionIndices;
1038             getExtractionIndices(xferOp, extractionIndices);
1039             extractionIndices.push_back(i);
1040 
1041             auto extracted =
1042                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1043             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
1044             auto newXferOp = b.create<vector::TransferWriteOp>(
1045                 loc, sourceType, extracted, source, xferIndices,
1046                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1047                 inBoundsAttr);
1048 
1049             maybeAssignMask(b, xferOp, newXferOp, i);
1050 
1051             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1052           },
1053           /*outOfBoundsCase=*/
1054           [&](OpBuilder &b, Location loc) {
1055             return isTensorOp(xferOp) ? source : Value();
1056           });
1057 
1058       if (isTensorOp(xferOp))
1059         source = updatedSource;
1060     }
1061 
1062     if (isTensorOp(xferOp))
1063       rewriter.replaceOp(xferOp, source);
1064     else
1065       rewriter.eraseOp(xferOp);
1066 
1067     return success();
1068   }
1069 };
1070 
1071 } // namespace lowering_n_d_unrolled
1072 
1073 namespace lowering_1_d {
1074 
1075 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1076 /// part of TransferOp1dConversion. Return the memref dimension on which
1077 /// the transfer is operating. A return value of None indicates a broadcast.
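///
/// E.g. (illustrative): for a transfer op with indices [%a, %b] and
/// permutation map (d0, d1) -> (d0), `memrefIndices` becomes [%a + iv, %b]
/// and dimension 0 is returned.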
1078 template <typename OpTy>
1079 static Optional<int64_t>
1080 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1081                    SmallVector<Value, 8> &memrefIndices) {
1082   auto indices = xferOp.indices();
1083   auto map = xferOp.permutation_map();
1084 
1085   memrefIndices.append(indices.begin(), indices.end());
1086   assert(map.getNumResults() == 1 &&
1087          "Expected 1 permutation map result for 1D transfer");
1088   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1089     Location loc = xferOp.getLoc();
1090     auto dim = expr.getPosition();
1091     AffineExpr d0, d1;
1092     bindDims(xferOp.getContext(), d0, d1);
1093     Value offset = memrefIndices[dim];
1094     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1095     return dim;
1096   }
1097 
1098   assert(xferOp.isBroadcastDim(0) &&
1099          "Expected AffineDimExpr or AffineConstantExpr");
1100   return None;
1101 }
1102 
1103 /// Codegen strategy for TransferOp1dConversion, depending on the
1104 /// operation.
1105 template <typename OpTy>
1106 struct Strategy1d;
1107 
1108 /// Codegen strategy for TransferReadOp.
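///
/// E.g. (sketch), for an in-bounds 1D read without broadcast, each loop
/// iteration generates IR roughly equivalent to:
/// ```
/// %val = memref.load %A[%base + iv, %b] : memref<?x?xf32>
/// %nextVec = vector.insertelement %val, %vec[%ivI32 : i32] : vector<9xf32>
/// scf.yield %nextVec : vector<9xf32>
/// ```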
1109 template <>
1110 struct Strategy1d<TransferReadOp> {
1111   static void generateForLoopBody(OpBuilder &b, Location loc,
1112                                   TransferReadOp xferOp, Value iv,
1113                                   ValueRange loopState) {
1114     SmallVector<Value, 8> indices;
1115     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1116     Value ivI32 =
1117         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1118     auto vec = loopState[0];
1119 
    // In case of out-of-bounds access, leave `vec` as is (it was initialized
    // with the padding value).
1122     auto nextVec = generateInBoundsCheck(
1123         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1124         /*inBoundsCase=*/
1125         [&](OpBuilder &b, Location loc) {
1126           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1127           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
1128         },
1129         /*outOfBoundsCase=*/
1130         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1131     b.create<scf::YieldOp>(loc, nextVec);
1132   }
1133 
1134   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
1136     Location loc = xferOp.getLoc();
1137     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
1138   }
1139 };
1140 
1141 /// Codegen strategy for TransferWriteOp.
1142 template <>
1143 struct Strategy1d<TransferWriteOp> {
1144   static void generateForLoopBody(OpBuilder &b, Location loc,
1145                                   TransferWriteOp xferOp, Value iv,
1146                                   ValueRange /*loopState*/) {
1147     SmallVector<Value, 8> indices;
1148     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1149     Value ivI32 =
1150         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1151 
1152     // Nothing to do in case of out-of-bounds access.
1153     generateInBoundsCheck(
1154         b, xferOp, iv, dim,
1155         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1156           auto val =
1157               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
1158           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
1159         });
1160     b.create<scf::YieldOp>(loc);
1161   }
1162 
1163   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1164     return Value();
1165   }
1166 };
1167 
1168 /// Return true if the last dimension of the MemRefType has unit stride.
1169 static bool isLastMemrefDimUnitStride(MemRefType type) {
1170   int64_t offset;
1171   SmallVector<int64_t, 4> strides;
1172   auto successStrides = getStridesAndOffset(type, strides, offset);
1173   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1174 }
1175 
1176 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1177 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1178 /// vector load/stores due to non-unit strides or broadcasts:
1179 ///
1180 /// * Transfer dimension is not the last memref dimension
1181 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1182 /// * Memref has a layout map with non-unit stride on the last dimension
1183 ///
1184 /// This pattern generates IR as follows:
1185 ///
1186 /// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1188 ///    depending on OpTy.
1189 ///
1190 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1191 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1192 ///       to ConvertVectorToLLVM.
1193 ///
1194 /// E.g.:
1195 /// ```
1196 /// vector.transfer_write %vec, %A[%a, %b]
1197 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1198 ///    : vector<9xf32>, memref<?x?xf32>
1199 /// ```
1200 /// Is rewritten to approximately the following pseudo-IR:
1201 /// ```
1202 /// for i = 0 to 9 {
1203 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1204 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1205 /// }
1206 /// ```
1207 template <typename OpTy>
1208 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1209   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1210 
1211   LogicalResult matchAndRewrite(OpTy xferOp,
1212                                 PatternRewriter &rewriter) const override {
1213     auto map = xferOp.permutation_map();
1214     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1215 
1216     if (!memRefType)
1217       return failure();
1218     if (xferOp.getVectorType().getRank() != 1)
1219       return failure();
1220     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1221       return failure(); // Handled by ConvertVectorToLLVM
1222 
1223     // Loop bounds, step, state...
1224     Location loc = xferOp.getLoc();
1225     auto vecType = xferOp.getVectorType();
1226     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
1227     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
1228     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
1229     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1230 
1231     // Generate for loop.
1232     rewriter.replaceOpWithNewOp<scf::ForOp>(
1233         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1234         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1235           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1236         });
1237 
1238     return success();
1239   }
1240 };
1241 
1242 } // namespace lowering_1_d
1243 } // namespace
1244 
1245 namespace mlir {
1246 
1247 void populateVectorToSCFConversionPatterns(
1248     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1249   if (options.unroll) {
1250     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1251                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1252         patterns.getContext(), options);
1253   } else {
1254     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1255                  lowering_n_d::PrepareTransferWriteConversion,
1256                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1257                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1258         patterns.getContext(), options);
1259   }
1260 
1261   if (options.targetRank == 1) {
1262     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1263                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1264         patterns.getContext(), options);
1265   }
1266 }
1267 
1268 } // namespace mlir
1269 
1270 namespace {
1271 
1272 struct ConvertVectorToSCFPass
1273     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1274   ConvertVectorToSCFPass() = default;
1275   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1276     this->fullUnroll = options.unroll;
1277     this->targetRank = options.targetRank;
1278     this->lowerPermutationMaps = options.lowerPermutationMaps;
1279     this->lowerTensors = options.lowerTensors;
1280   }
1281 
1282   void runOnFunction() override {
1283     VectorTransferToSCFOptions options;
1284     options.unroll = fullUnroll;
1285     options.targetRank = targetRank;
1286     options.lowerPermutationMaps = lowerPermutationMaps;
1287     options.lowerTensors = lowerTensors;
1288 
1289     // Lower permutation maps first.
1290     if (lowerPermutationMaps) {
1291       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1292       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1293           lowerTransferPatterns);
1294       (void)applyPatternsAndFoldGreedily(getFunction(),
1295                                          std::move(lowerTransferPatterns));
1296     }
1297 
1298     RewritePatternSet patterns(getFunction().getContext());
1299     populateVectorToSCFConversionPatterns(patterns, options);
1300     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
1301   }
1302 };
1303 
1304 } // namespace
1305 
1306 std::unique_ptr<Pass>
1307 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1308   return std::make_unique<ConvertVectorToSCFPass>(options);
1309 }
1310