//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
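///
/// E.g., for a (hypothetical) transfer op with permutation map
/// (d0, d1, d2) -> (d2, d1), the unpacked dim is 2: the memref dimension
/// that is addressed by the major (first) vector dimension.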
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
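///
/// E.g., a (hypothetical) permutation map (d0, d1, d2) -> (d2, d1) becomes
/// (d0, d1, d2) -> (d1); the number of dims is preserved.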
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///              `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

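/// Generate an scf.yield op. If `hasRetVal` is set, yield `value`; otherwise,
/// generate an empty yield.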
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
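///
/// E.g., for a (hypothetical) 1D mask of type vector<16xi1>, the generated
/// check is a single vector.extractelement of the iv-th mask element.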
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
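/// E.g., a (hypothetical) `in_bounds` array [true, false, true] becomes
/// [false, true].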
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

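// Progressive lowering of (N > 1)-D vector transfer ops into loops over
// (N-1)-D transfer ops, staging the data through stack-allocated buffers.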
namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
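/// E.g., for a (hypothetical) transfer op with vector type vector<5x4xf32>
/// and mask type vector<5xi1>, buffers similar to the following are created
/// at the top of the enclosing allocation scope:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// ```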
/// TODO: Parallelism and threadlocal considerations.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

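/// Check whether the given transfer op is eligible for preparation: bail out
/// if the op is already labeled, if its vector rank does not exceed the
/// target rank, if it operates on a tensor but tensor lowering is disabled,
/// or if it changes the element type.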
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
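///
/// E.g., for a (hypothetical) mask of type vector<5x4xi1> and iteration
/// i = 2, the new mask is the vector<4xi1> produced by
/// `vector.extract %mask[2]`.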
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
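///
/// E.g., for a (hypothetical) 1D transfer op with permutation map
/// (d0, d1) -> (d0) and indices [%a, %b], the memref indices become
/// [%a + iv, %b] and the returned dimension is 0.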
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
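/// E.g., a memref<2x3xf32> with the default (identity) layout has strides
/// [3, 1] and thus a unit-stride last dimension.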
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a pattern
///       to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

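/// Populate `patterns` with progressive vector-to-SCF lowering patterns:
/// either the fully unrolled variant or the loop-based variant that stages
/// data through buffers, depending on `options.unroll`. Patterns that lower
/// remaining 1-D transfers to scalar loads/stores are added only when the
/// target rank is 1.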
void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}