1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/Utils.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/Dialect/Vector/VectorUtils.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 using vector::TransferReadOp;
31 using vector::TransferWriteOp;
32 
33 namespace {
34 
35 /// Attribute name used for labeling transfer ops during progressive lowering.
36 static const char kPassLabel[] = "__vector_to_scf_lowering__";
37 
38 /// Patterns that inherit from this struct have access to
39 /// VectorTransferToSCFOptions.
40 template <typename OpTy>
41 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
42   explicit VectorToSCFPattern(MLIRContext *context,
43                               VectorTransferToSCFOptions opt)
44       : OpRewritePattern<OpTy>(context), options(opt) {}
45 
46   VectorTransferToSCFOptions options;
47 };
48 
49 /// Given a vector transfer op, calculate which dimension of the `source`
50 /// memref should be unpacked in the next application of TransferOpConversion.
51 /// A return value of None indicates a broadcast.
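/// E.g. (illustrative): with permutation_map = affine_map<(d0, d1, d2) ->
/// (d2, d1)>, the first result is `d2`, so memref dimension 2 is unpacked
/// next. If the first result is a constant 0 (a broadcast), None is returned.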
52 template <typename OpTy>
53 static Optional<int64_t> unpackedDim(OpTy xferOp) {
54   auto map = xferOp.permutation_map();
55   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
56     return expr.getPosition();
57   }
58   assert(xferOp.isBroadcastDim(0) &&
59          "Expected AffineDimExpr or AffineConstantExpr");
60   return None;
61 }
62 
63 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
64 /// map is identical to the current permutation map, but the first result is
65 /// omitted.
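/// E.g. (illustrative):
///   affine_map<(d0, d1, d2) -> (d2, d1)> becomes
///   affine_map<(d0, d1, d2) -> (d1)>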
66 template <typename OpTy>
67 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
68   auto map = xferOp.permutation_map();
69   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
70                         b.getContext());
71 }
72 
73 /// Calculate the indices for the new vector transfer op.
74 ///
75 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
77 ///                                 ^^^^^^
78 ///              `iv` is the iteration variable of the (new) surrounding loop.
79 template <typename OpTy>
80 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
81                            SmallVector<Value, 8> &indices) {
82   typename OpTy::Adaptor adaptor(xferOp);
83   // Corresponding memref dim of the vector dim that is unpacked.
84   auto dim = unpackedDim(xferOp);
85   auto prevIndices = adaptor.indices();
86   indices.append(prevIndices.begin(), prevIndices.end());
87 
88   Location loc = xferOp.getLoc();
89   bool isBroadcast = !dim.hasValue();
90   if (!isBroadcast) {
91     AffineExpr d0, d1;
92     bindDims(xferOp.getContext(), d0, d1);
93     Value offset = adaptor.indices()[dim.getValue()];
94     indices[dim.getValue()] =
95         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
96   }
97 }
98 
99 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
100                             Value value) {
101   if (hasRetVal) {
102     b.create<scf::YieldOp>(loc, value);
103   } else {
104     b.create<scf::YieldOp>(loc);
105   }
106 }
107 
108 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
110 /// * xferOp does not have a mask.
111 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
112 ///   computed and attached to the new transfer op in the pattern.)
113 /// * The to-be-unpacked dim of xferOp is a broadcast.
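///
/// The generated check resembles the following (illustrative sketch, assuming
/// a 1-D mask of type vector<5xi1>):
/// ```
/// %pos = index_cast %iv : index to i32
/// %is_active = vector.extractelement %mask[%pos : i32] : vector<5xi1>
/// ```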
114 template <typename OpTy>
115 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
116   if (!xferOp.mask())
117     return Value();
118   if (xferOp.getMaskType().getRank() != 1)
119     return Value();
120   if (xferOp.isBroadcastDim(0))
121     return Value();
122 
123   Location loc = xferOp.getLoc();
124   Value ivI32 =
125       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
126   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
127 }
128 
/// Helper function for TransferOpConversion and TransferOp1dConversion.
130 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
131 /// specified dimension `dim` with the loop iteration variable `iv`.
132 /// E.g., when unpacking dimension 0 from:
133 /// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
136 /// ```
137 /// An if check similar to this will be generated inside the loop:
138 /// ```
139 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
140 /// if (%a + iv < %d) {
141 ///   (in-bounds case)
142 /// } else {
143 ///   (out-of-bounds case)
144 /// }
145 /// ```
146 ///
147 /// If the transfer is 1D and has a mask, this function generates a more complex
/// check that also accounts for potentially masked-out elements.
149 ///
150 /// This function variant returns the value returned by `inBoundsCase` or
151 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
152 /// `resultTypes`.
153 template <typename OpTy>
154 static Value generateInBoundsCheck(
155     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
156     TypeRange resultTypes,
157     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
158     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
159   bool hasRetVal = !resultTypes.empty();
160   Value cond; // Condition to be built...
161 
162   // Condition check 1: Access in-bounds?
163   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
164   Location loc = xferOp.getLoc();
165   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
166   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
167     Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
168     AffineExpr d0, d1;
169     bindDims(xferOp.getContext(), d0, d1);
170     Value base = xferOp.indices()[dim.getValue()];
171     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
172     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
173   }
174 
175   // Condition check 2: Masked in?
176   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
177     if (cond)
178       cond = lb.create<AndOp>(cond, maskCond);
179     else
180       cond = maskCond;
181   }
182 
183   // If the condition is non-empty, generate an SCF::IfOp.
184   if (cond) {
185     auto check = lb.create<scf::IfOp>(
186         resultTypes, cond,
187         /*thenBuilder=*/
188         [&](OpBuilder &b, Location loc) {
189           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
190         },
191         /*elseBuilder=*/
192         [&](OpBuilder &b, Location loc) {
193           if (outOfBoundsCase) {
194             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
195           } else {
196             b.create<scf::YieldOp>(loc);
197           }
198         });
199 
200     return hasRetVal ? check.getResult(0) : Value();
201   }
202 
203   // Condition is empty, no need for an SCF::IfOp.
204   return inBoundsCase(b, loc);
205 }
206 
207 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
208 /// a return value. Consequently, this function does not have a return value.
209 template <typename OpTy>
210 static void generateInBoundsCheck(
211     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
212     function_ref<void(OpBuilder &, Location)> inBoundsCase,
213     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
214   generateInBoundsCheck(
215       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
216       /*inBoundsCase=*/
217       [&](OpBuilder &b, Location loc) {
218         inBoundsCase(b, loc);
219         return Value();
220       },
221       /*outOfBoundsCase=*/
222       [&](OpBuilder &b, Location loc) {
223         if (outOfBoundsCase)
224           outOfBoundsCase(b, loc);
225         return Value();
226       });
227 }
228 
229 /// Given an ArrayAttr, return a copy where the first element is dropped.
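/// E.g. (illustrative): an `in_bounds` array [true, false, true] becomes
/// [false, true].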
230 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
231   if (!attr)
232     return attr;
233   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
234 }
235 
/// Add the pass label to a vector transfer op if its vector rank is greater
/// than the target rank.
238 template <typename OpTy>
239 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
240                                 unsigned targetRank) {
241   if (newXferOp.getVectorType().getRank() > targetRank)
242     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
243 }
244 
245 namespace lowering_n_d {
246 
247 /// Helper data structure for data and mask buffers.
248 struct BufferAllocs {
249   Value dataBuffer;
250   Value maskBuffer;
251 };
252 
253 /// Allocate temporary buffers for data (vector) and mask (if present).
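///
/// The generated IR resembles the following (illustrative sketch for a masked
/// transfer op with vector type vector<5x4xf32> and mask type
/// vector<5x4xi1>):
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %mask_buf[] : memref<vector<5x4xi1>>
/// %mask_val = memref.load %mask_buf[] : memref<vector<5x4xi1>>
/// ```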
254 /// TODO: Parallelism and threadlocal considerations.
255 template <typename OpTy>
256 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
257   Location loc = xferOp.getLoc();
258   OpBuilder::InsertionGuard guard(b);
259   Operation *scope =
260       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
261   assert(scope && "Expected op to be inside automatic allocation scope");
262   b.setInsertionPointToStart(&scope->getRegion(0).front());
263 
264   BufferAllocs result;
265   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
266   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
267 
268   if (xferOp.mask()) {
269     auto maskType = MemRefType::get({}, xferOp.mask().getType());
270     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
271     b.setInsertionPoint(xferOp);
272     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
273     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
274   }
275 
276   return result;
277 }
278 
279 /// Given a MemRefType with VectorType element type, unpack one dimension from
280 /// the VectorType into the MemRefType.
281 ///
282 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
283 static MemRefType unpackOneDim(MemRefType type) {
284   auto vectorType = type.getElementType().dyn_cast<VectorType>();
285   auto memrefShape = type.getShape();
286   SmallVector<int64_t, 8> newMemrefShape;
287   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
288   newMemrefShape.push_back(vectorType.getDimSize(0));
289   return MemRefType::get(newMemrefShape,
290                          VectorType::get(vectorType.getShape().drop_front(),
291                                          vectorType.getElementType()));
292 }
293 
294 /// Given a transfer op, find the memref from which the mask is loaded. This
295 /// is similar to Strategy<TransferWriteOp>::getBuffer.
296 template <typename OpTy>
297 static Value getMaskBuffer(OpTy xferOp) {
298   assert(xferOp.mask() && "Expected that transfer op has mask");
299   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
300   assert(loadOp && "Expected transfer op mask produced by LoadOp");
301   return loadOp.getMemRef();
302 }
303 
304 /// Codegen strategy, depending on the operation.
305 template <typename OpTy>
306 struct Strategy;
307 
/// Codegen strategy for vector TransferReadOp.
309 template <>
310 struct Strategy<TransferReadOp> {
311   /// Find the StoreOp that is used for writing the current TransferReadOp's
312   /// result to the temporary buffer allocation.
313   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
314     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
315     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
316     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
317     return storeOp;
318   }
319 
320   /// Find the temporary buffer allocation. All labeled TransferReadOps are
321   /// used like this, where %buf is either the buffer allocation or a type cast
322   /// of the buffer allocation:
323   /// ```
324   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
325   /// memref.store %vec, %buf[...] ...
326   /// ```
327   static Value getBuffer(TransferReadOp xferOp) {
328     return getStoreOp(xferOp).getMemRef();
329   }
330 
331   /// Retrieve the indices of the current StoreOp that stores into the buffer.
332   static void getBufferIndices(TransferReadOp xferOp,
333                                SmallVector<Value, 8> &indices) {
334     auto storeOp = getStoreOp(xferOp);
335     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
336     indices.append(prevIndices.begin(), prevIndices.end());
337   }
338 
339   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
340   /// accesses on the to-be-unpacked dimension.
341   ///
342   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
343   ///    variable `iv`.
344   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
345   ///
346   /// E.g.:
347   /// ```
348   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
349   ///     : memref<?x?x?xf32>, vector<4x3xf32>
350   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
351   /// ```
352   /// Is rewritten to:
353   /// ```
354   /// %casted = vector.type_cast %buf
355   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
356   /// for %j = 0 to 4 {
357   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
358   ///       : memref<?x?x?xf32>, vector<3xf32>
359   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
360   /// }
361   /// ```
362   ///
363   /// Note: The loop and type cast are generated in TransferOpConversion.
364   ///       The original TransferReadOp and store op are deleted in `cleanup`.
365   /// Note: The `mask` operand is set in TransferOpConversion.
366   static TransferReadOp rewriteOp(OpBuilder &b,
367                                   VectorTransferToSCFOptions options,
368                                   TransferReadOp xferOp, Value buffer,
369                                   Value iv) {
370     SmallVector<Value, 8> storeIndices;
371     getBufferIndices(xferOp, storeIndices);
372     storeIndices.push_back(iv);
373 
374     SmallVector<Value, 8> xferIndices;
375     getXferIndices(b, xferOp, iv, xferIndices);
376 
377     Location loc = xferOp.getLoc();
378     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
379     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
380     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
381     auto newXferOp = b.create<vector::TransferReadOp>(
382         loc, vecType, xferOp.source(), xferIndices,
383         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
384         Value(), inBoundsAttr);
385 
386     maybeApplyPassLabel(b, newXferOp, options.targetRank);
387 
388     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
389     return newXferOp;
390   }
391 
392   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
393   /// padding value to the temporary buffer.
394   static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
395                                    Value buffer, Value iv) {
396     SmallVector<Value, 8> storeIndices;
397     getBufferIndices(xferOp, storeIndices);
398     storeIndices.push_back(iv);
399 
400     Location loc = xferOp.getLoc();
401     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
402     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
403     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
404     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
405   }
406 
407   /// Cleanup after rewriting the op.
408   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
409     rewriter.eraseOp(getStoreOp(xferOp));
410     rewriter.eraseOp(xferOp);
411   }
412 };
413 
414 /// Codegen strategy for vector TransferWriteOp.
415 template <>
416 struct Strategy<TransferWriteOp> {
417   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
418   /// used like this, where %buf is either the buffer allocation or a type cast
419   /// of the buffer allocation:
420   /// ```
421   /// %vec = memref.load %buf[...] ...
422   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
423   /// ```
424   static Value getBuffer(TransferWriteOp xferOp) {
425     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
426     assert(loadOp && "Expected transfer op vector produced by LoadOp");
427     return loadOp.getMemRef();
428   }
429 
430   /// Retrieve the indices of the current LoadOp that loads from the buffer.
431   static void getBufferIndices(TransferWriteOp xferOp,
432                                SmallVector<Value, 8> &indices) {
433     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
434     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
435     indices.append(prevIndices.begin(), prevIndices.end());
436   }
437 
438   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
439   /// accesses on the to-be-unpacked dimension.
440   ///
441   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
442   ///    using the loop iteration variable `iv`.
443   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
444   ///    to memory.
445   ///
446   /// Note: For more details, see comments on Strategy<TransferReadOp>.
447   static TransferWriteOp rewriteOp(OpBuilder &b,
448                                    VectorTransferToSCFOptions options,
449                                    TransferWriteOp xferOp, Value buffer,
450                                    Value iv) {
451     SmallVector<Value, 8> loadIndices;
452     getBufferIndices(xferOp, loadIndices);
453     loadIndices.push_back(iv);
454 
455     SmallVector<Value, 8> xferIndices;
456     getXferIndices(b, xferOp, iv, xferIndices);
457 
458     Location loc = xferOp.getLoc();
459     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
460     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
461     auto newXferOp = b.create<vector::TransferWriteOp>(
462         loc, Type(), vec, xferOp.source(), xferIndices,
463         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
464         inBoundsAttr);
465 
466     maybeApplyPassLabel(b, newXferOp, options.targetRank);
467 
468     return newXferOp;
469   }
470 
471   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
472   static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
473                                    Value buffer, Value iv) {}
474 
475   /// Cleanup after rewriting the op.
476   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
477     rewriter.eraseOp(xferOp);
478   }
479 };
480 
481 template <typename OpTy>
482 LogicalResult checkPrepareXferOp(OpTy xferOp,
483                                  VectorTransferToSCFOptions options) {
484   if (xferOp->hasAttr(kPassLabel))
485     return failure();
486   if (xferOp.getVectorType().getRank() <= options.targetRank)
487     return failure();
488   if (xferOp.getShapedType().template isa<RankedTensorType>())
489     return failure();
490   // Transfer ops that modify the element type are not supported atm.
491   if (xferOp.getVectorType().getElementType() !=
492       xferOp.getShapedType().getElementType())
493     return failure();
494   return success();
495 }
496 
497 /// Prepare a TransferReadOp for progressive lowering.
498 ///
499 /// 1. Allocate a temporary buffer.
500 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
501 /// 3. Store the result of the TransferReadOp into the temporary buffer.
502 /// 4. Load the result from the temporary buffer and replace all uses of the
503 ///    original TransferReadOp with this load.
504 ///
505 /// E.g.:
506 /// ```
507 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
509 /// ```
510 /// is rewritten to:
511 /// ```
512 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
513 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
515 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
516 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
517 /// ```
518 ///
519 /// Note: A second temporary buffer may be allocated for the `mask` operand.
520 struct PrepareTransferReadConversion
521     : public VectorToSCFPattern<TransferReadOp> {
522   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
523 
524   LogicalResult matchAndRewrite(TransferReadOp xferOp,
525                                 PatternRewriter &rewriter) const override {
526     if (checkPrepareXferOp(xferOp, options).failed())
527       return failure();
528 
529     auto buffers = allocBuffers(rewriter, xferOp);
530     auto *newXfer = rewriter.clone(*xferOp.getOperation());
531     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
532     if (xferOp.mask()) {
533       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
534           buffers.maskBuffer);
535     }
536 
537     Location loc = xferOp.getLoc();
538     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
539                                      buffers.dataBuffer);
540     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
541 
542     return success();
543   }
544 };
545 
546 /// Prepare a TransferWriteOp for progressive lowering.
547 ///
548 /// 1. Allocate a temporary buffer.
549 /// 2. Store the vector into the buffer.
550 /// 3. Load the vector from the buffer again.
551 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
552 ///    marking it eligible for progressive lowering via TransferOpConversion.
553 ///
554 /// E.g.:
555 /// ```
556 /// vector.transfer_write %vec, %A[%a, %b, %c]
557 ///     : vector<5x4xf32>, memref<?x?x?xf32>
558 /// ```
559 /// is rewritten to:
560 /// ```
561 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
562 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
563 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
564 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
565 ///     : vector<5x4xf32>, memref<?x?x?xf32>
566 /// ```
567 ///
568 /// Note: A second temporary buffer may be allocated for the `mask` operand.
569 struct PrepareTransferWriteConversion
570     : public VectorToSCFPattern<TransferWriteOp> {
571   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
572 
573   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
574                                 PatternRewriter &rewriter) const override {
575     if (checkPrepareXferOp(xferOp, options).failed())
576       return failure();
577 
578     Location loc = xferOp.getLoc();
579     auto buffers = allocBuffers(rewriter, xferOp);
580     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
581     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
582     rewriter.updateRootInPlace(xferOp, [&]() {
583       xferOp.vectorMutable().assign(loadedVec);
584       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
585     });
586 
587     if (xferOp.mask()) {
588       rewriter.updateRootInPlace(
589           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
590     }
591 
592     return success();
593   }
594 };
595 
596 /// Progressive lowering of vector transfer ops: Unpack one dimension.
597 ///
598 /// 1. Unpack one dimension from the current buffer type and cast the buffer
599 ///    to that new type. E.g.:
600 ///    ```
601 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
602 ///    vector.transfer_write %vec ...
603 ///    ```
604 ///    The following cast is generated:
605 ///    ```
606 ///    %casted = vector.type_cast %0
607 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
608 ///    ```
609 /// 2. Generate a for loop and rewrite the transfer op according to the
610 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
611 ///    out-of-bounds, generate an if-check and handle both cases separately.
612 /// 3. Clean up according to the corresponding Strategy<OpTy>.
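///
/// E.g., for a TransferReadOp on a buffer of type memref<5xvector<4x3xf32>>,
/// the generated structure resembles (illustrative sketch):
/// ```
/// %casted = vector.type_cast %buf
///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
/// scf.for %j = %c0 to %c4 step %c1 {
///   (in-bounds check; rewritten (N-1)-D transfer op, see Strategy<OpTy>)
/// }
/// ```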
613 template <typename OpTy>
614 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
615   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
616 
617   LogicalResult matchAndRewrite(OpTy xferOp,
618                                 PatternRewriter &rewriter) const override {
619     if (!xferOp->hasAttr(kPassLabel))
620       return failure();
621 
622     // Find and cast data buffer. How the buffer can be found depends on OpTy.
623     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
624     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
625     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
626     auto castedDataType = unpackOneDim(dataBufferType);
627     auto castedDataBuffer =
628         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
629 
630     // If the xferOp has a mask: Find and cast mask buffer.
631     Value castedMaskBuffer;
632     if (xferOp.mask()) {
633       auto maskBuffer = getMaskBuffer(xferOp);
634       auto maskBufferType =
635           maskBuffer.getType().template dyn_cast<MemRefType>();
636       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
637         // Do not unpack a dimension of the mask, if:
638         // * To-be-unpacked transfer op dimension is a broadcast.
639         // * Mask is 1D, i.e., the mask cannot be further unpacked.
640         //   (That means that all remaining dimensions of the transfer op must
641         //   be broadcasted.)
642         castedMaskBuffer = maskBuffer;
643       } else {
644         auto castedMaskType = unpackOneDim(maskBufferType);
645         castedMaskBuffer =
646             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
647       }
648     }
649 
650     // Loop bounds and step.
651     auto lb = locB.create<ConstantIndexOp>(0);
652     auto ub = locB.create<ConstantIndexOp>(
653         castedDataType.getDimSize(castedDataType.getRank() - 1));
654     auto step = locB.create<ConstantIndexOp>(1);
655 
656     // Generate for loop.
657     locB.create<scf::ForOp>(
658         lb, ub, step, ValueRange(),
659         [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
660           generateInBoundsCheck(
661               b, xferOp, iv, unpackedDim(xferOp),
662               /*inBoundsCase=*/
663               [&](OpBuilder &b, Location loc) {
664                 // Create new transfer op.
665                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
666                     b, this->options, xferOp, castedDataBuffer, iv);
667 
                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
673                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
674                                       xferOp.getMaskType().getRank() > 1)) {
675                   OpBuilder::InsertionGuard guard(b);
676                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
677 
678                   SmallVector<Value, 8> loadIndices;
679                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
680                   // In case of broadcast: Use same indices to load from memref
681                   // as before.
682                   if (!xferOp.isBroadcastDim(0))
683                     loadIndices.push_back(iv);
684 
685                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
686                                                        loadIndices);
687                   rewriter.updateRootInPlace(
688                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
689                 }
690               },
691               /*outOfBoundsCase=*/
692               [&](OpBuilder &b, Location /*loc*/) {
693                 Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
694                                                      castedDataBuffer, iv);
695               });
696           b.create<scf::YieldOp>(loc);
697         });
698 
699     Strategy<OpTy>::cleanup(rewriter, xferOp);
700     return success();
701   }
702 };
703 
704 } // namespace lowering_n_d
705 
706 namespace lowering_n_d_unrolled {
707 
708 /// If the original transfer op has a mask, compute the mask of the new transfer
709 /// op (for the current iteration `i`) and assign it.
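/// E.g. (illustrative): for a mask of type vector<5x4xi1> and i = 2, the mask
/// of the new transfer op is computed as:
/// ```
/// %new_mask = vector.extract %mask[2] : vector<5x4xi1>
/// ```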
710 template <typename OpTy>
711 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
712                             int64_t i) {
713   if (!xferOp.mask())
714     return;
715 
716   if (xferOp.isBroadcastDim(0)) {
717     // To-be-unpacked dimension is a broadcast, which does not have a
718     // corresponding mask dimension. Mask attribute remains unchanged.
719     newXferOp.maskMutable().assign(xferOp.mask());
720     return;
721   }
722 
723   if (xferOp.getMaskType().getRank() > 1) {
724     // Unpack one dimension of the mask.
725     OpBuilder::InsertionGuard guard(b);
726     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
727 
728     llvm::SmallVector<int64_t, 1> indices({i});
729     Location loc = xferOp.getLoc();
730     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
731     newXferOp.maskMutable().assign(newMask);
732   }
733 
734   // If we end up here: The mask of the old transfer op is 1D and the unpacked
735   // dim is not a broadcast, so no mask is needed on the new transfer op.
736   // `generateInBoundsCheck` will have evaluated the mask already.
737 }
738 
739 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
740 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
741 /// memref buffer is allocated and the SCF loop is fully unrolled.
742 ///
/// E.g.:
/// ```
746 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
747 ///     : memref<?x?x?xf32>, vector<5x4xf32>
748 /// ```
749 /// is rewritten to IR such as (simplified):
750 /// ```
751 /// %v_init = splat %padding : vector<5x4xf32>
752 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
753 ///     : memref<?x?x?xf32>, vector<4xf32>
754 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
755 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
756 ///     : memref<?x?x?xf32>, vector<4xf32>
757 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
758 /// ...
759 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
760 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
762 /// ```
763 ///
764 /// Note: As an optimization, if the result of the original TransferReadOp
765 /// was directly inserted into another vector, no new %v_init vector is created.
766 /// Instead, the new TransferReadOp results are inserted into that vector.
767 struct UnrollTransferReadConversion
768     : public VectorToSCFPattern<TransferReadOp> {
769   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
770 
771   /// Return the vector into which the newly created TransferReadOp results
772   /// are inserted.
773   Value getResultVector(TransferReadOp xferOp,
774                         PatternRewriter &rewriter) const {
775     if (auto insertOp = getInsertOp(xferOp))
776       return insertOp.dest();
777     Location loc = xferOp.getLoc();
778     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
779                                     xferOp.padding());
780   }
781 
782   /// If the result of the TransferReadOp has exactly one user, which is a
783   /// vector::InsertOp, return that operation.
784   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
785     if (xferOp->hasOneUse()) {
786       Operation *xferOpUser = *xferOp->getUsers().begin();
787       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
788         return insertOp;
789     }
790 
791     return vector::InsertOp();
792   }
793 
794   /// If the result of the TransferReadOp has exactly one user, which is a
795   /// vector::InsertOp, return that operation's indices.
796   void getInsertionIndices(TransferReadOp xferOp,
797                            SmallVector<int64_t, 8> &indices) const {
798     if (auto insertOp = getInsertOp(xferOp)) {
799       llvm::for_each(insertOp.position(), [&](Attribute attr) {
800         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
801       });
802     }
803   }
804 
805   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
806   /// accesses, and broadcasts and transposes in permutation maps.
807   LogicalResult matchAndRewrite(TransferReadOp xferOp,
808                                 PatternRewriter &rewriter) const override {
809     if (xferOp.getVectorType().getRank() <= options.targetRank)
810       return failure();
811     if (xferOp.getShapedType().template isa<RankedTensorType>())
812       return failure();
813     // Transfer ops that modify the element type are not supported atm.
814     if (xferOp.getVectorType().getElementType() !=
815         xferOp.getShapedType().getElementType())
816       return failure();
817 
818     auto insertOp = getInsertOp(xferOp);
819     auto vec = getResultVector(xferOp, rewriter);
820     auto vecType = vec.getType().dyn_cast<VectorType>();
821     auto xferVecType = xferOp.getVectorType();
822     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
823                                           xferVecType.getElementType());
824     int64_t dimSize = xferVecType.getShape()[0];
825 
826     // Generate fully unrolled loop of transfer ops.
827     Location loc = xferOp.getLoc();
828     for (int64_t i = 0; i < dimSize; ++i) {
829       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
830 
831       vec = generateInBoundsCheck(
832           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
833           /*inBoundsCase=*/
834           [&](OpBuilder &b, Location loc) {
835             // Indices for the new transfer op.
836             SmallVector<Value, 8> xferIndices;
837             getXferIndices(b, xferOp, iv, xferIndices);
838 
839             // Indices for the new vector.insert op.
840             SmallVector<int64_t, 8> insertionIndices;
841             getInsertionIndices(xferOp, insertionIndices);
842             insertionIndices.push_back(i);
843 
844             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
845             auto newXferOp = b.create<vector::TransferReadOp>(
846                 loc, newXferVecType, xferOp.source(), xferIndices,
847                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
848                 xferOp.padding(), Value(), inBoundsAttr);
849             maybeAssignMask(b, xferOp, newXferOp, i);
850             return b.create<vector::InsertOp>(loc, newXferOp, vec,
851                                               insertionIndices);
852           },
853           /*outOfBoundsCase=*/
854           [&](OpBuilder &b, Location loc) {
            // Keep the original (unmodified) vector.
856             return vec;
857           });
858     }
859 
860     if (insertOp) {
861       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
862       rewriter.replaceOp(insertOp, vec);
863       rewriter.eraseOp(xferOp);
864     } else {
865       rewriter.replaceOp(xferOp, vec);
866     }
867 
868     return success();
869   }
870 };
871 
872 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
873 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
874 /// memref buffer is allocated and the SCF loop is fully unrolled.
875 ///
/// E.g.:
/// ```
879 /// vector.transfer_write %vec, %A[%a, %b, %c]
880 ///     : vector<5x4xf32>, memref<?x?x?xf32>
881 /// ```
882 /// is rewritten to IR such as (simplified):
883 /// ```
884 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
885 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
886 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
887 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
888 /// ...
889 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
890 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
891 /// ```
892 ///
893 /// Note: As an optimization, if the vector of the original TransferWriteOp
894 /// was directly extracted from another vector via an ExtractOp `a`, extract
895 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
896 /// doing so, `a` may become dead, and the number of ExtractOps generated during
897 /// recursive application of this pattern will be minimal.
898 struct UnrollTransferWriteConversion
899     : public VectorToSCFPattern<TransferWriteOp> {
900   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
901 
  /// Return the vector from which newly generated ExtractOps will extract.
903   Value getDataVector(TransferWriteOp xferOp) const {
904     if (auto extractOp = getExtractOp(xferOp))
905       return extractOp.vector();
906     return xferOp.vector();
907   }
908 
909   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
910   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
911     if (auto *op = xferOp.vector().getDefiningOp())
912       return dyn_cast<vector::ExtractOp>(op);
913     return vector::ExtractOp();
914   }
915 
916   /// If the input of the given TransferWriteOp is an ExtractOp, return its
917   /// indices.
918   void getExtractionIndices(TransferWriteOp xferOp,
919                             SmallVector<int64_t, 8> &indices) const {
920     if (auto extractOp = getExtractOp(xferOp)) {
921       llvm::for_each(extractOp.position(), [&](Attribute attr) {
922         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
923       });
924     }
925   }
926 
927   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
928   /// accesses, and broadcasts and transposes in permutation maps.
929   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
930                                 PatternRewriter &rewriter) const override {
931     if (xferOp.getVectorType().getRank() <= options.targetRank)
932       return failure();
933     if (xferOp.getShapedType().template isa<RankedTensorType>())
934       return failure();
935     // Transfer ops that modify the element type are not supported atm.
936     if (xferOp.getVectorType().getElementType() !=
937         xferOp.getShapedType().getElementType())
938       return failure();
939 
940     auto vec = getDataVector(xferOp);
941     auto xferVecType = xferOp.getVectorType();
942     int64_t dimSize = xferVecType.getShape()[0];
943 
944     // Generate fully unrolled loop of transfer ops.
945     Location loc = xferOp.getLoc();
946     for (int64_t i = 0; i < dimSize; ++i) {
947       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
948 
949       generateInBoundsCheck(
950           rewriter, xferOp, iv, unpackedDim(xferOp),
951           /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
952             // Indices for the new transfer op.
953             SmallVector<Value, 8> xferIndices;
954             getXferIndices(b, xferOp, iv, xferIndices);
955 
956             // Indices for the new vector.extract op.
957             SmallVector<int64_t, 8> extractionIndices;
958             getExtractionIndices(xferOp, extractionIndices);
959             extractionIndices.push_back(i);
960 
961             auto extracted =
962                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
963             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
964 
965             auto newXferOp = b.create<vector::TransferWriteOp>(
966                 loc, Type(), extracted, xferOp.source(), xferIndices,
967                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
968                 inBoundsAttr);
969 
970             maybeAssignMask(b, xferOp, newXferOp, i);
971           });
972     }
973 
974     rewriter.eraseOp(xferOp);
975     return success();
976   }
977 };
978 
979 } // namespace lowering_n_d_unrolled
980 
981 namespace lowering_1_d {
982 
983 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
984 /// part of TransferOp1dConversion. Return the memref dimension on which
985 /// the transfer is operating. A return value of None indicates a broadcast.
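/// E.g. (illustrative): for indices [%a, %b] and permutation_map =
/// affine_map<(d0, d1) -> (d0)>, the computed memref indices are
/// [%a + iv, %b] (materialized as an affine.apply) and the returned dimension
/// is 0.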
986 template <typename OpTy>
987 static Optional<int64_t>
988 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
989                    SmallVector<Value, 8> &memrefIndices) {
990   auto indices = xferOp.indices();
991   auto map = xferOp.permutation_map();
992 
993   memrefIndices.append(indices.begin(), indices.end());
994   assert(map.getNumResults() == 1 &&
995          "Expected 1 permutation map result for 1D transfer");
996   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
997     Location loc = xferOp.getLoc();
998     auto dim = expr.getPosition();
999     AffineExpr d0, d1;
1000     bindDims(xferOp.getContext(), d0, d1);
1001     Value offset = memrefIndices[dim];
1002     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1003     return dim;
1004   }
1005 
1006   assert(xferOp.isBroadcastDim(0) &&
1007          "Expected AffineDimExpr or AffineConstantExpr");
1008   return None;
1009 }
1010 
1011 /// Codegen strategy for TransferOp1dConversion, depending on the
1012 /// operation.
1013 template <typename OpTy>
1014 struct Strategy1d;
1015 
1016 /// Codegen strategy for TransferReadOp.
1017 template <>
1018 struct Strategy1d<TransferReadOp> {
1019   static void generateForLoopBody(OpBuilder &b, Location loc,
1020                                   TransferReadOp xferOp, Value iv,
1021                                   ValueRange loopState) {
1022     SmallVector<Value, 8> indices;
1023     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1024     Value ivI32 =
1025         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1026     auto vec = loopState[0];
1027 
1028     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1029     // padding value).
1030     auto nextVec = generateInBoundsCheck(
1031         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1032         /*inBoundsCase=*/
1033         [&](OpBuilder &b, Location loc) {
1034           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1035           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
1036         },
1037         /*outOfBoundsCase=*/
1038         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1039     b.create<scf::YieldOp>(loc, nextVec);
1040   }
1041 
1042   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
1044     Location loc = xferOp.getLoc();
1045     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
1046   }
1047 };
1048 
1049 /// Codegen strategy for TransferWriteOp.
1050 template <>
1051 struct Strategy1d<TransferWriteOp> {
1052   static void generateForLoopBody(OpBuilder &b, Location loc,
1053                                   TransferWriteOp xferOp, Value iv,
1054                                   ValueRange /*loopState*/) {
1055     SmallVector<Value, 8> indices;
1056     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1057     Value ivI32 =
1058         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1059 
1060     // Nothing to do in case of out-of-bounds access.
1061     generateInBoundsCheck(
1062         b, xferOp, iv, dim,
1063         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1064           auto val =
1065               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
1066           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
1067         });
1068     b.create<scf::YieldOp>(loc);
1069   }
1070 
1071   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1072     return Value();
1073   }
1074 };
1075 
1076 /// Return true if the last dimension of the MemRefType has unit stride.
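/// E.g. (illustrative): memref<4x8xf32> has unit stride on its last
/// dimension, whereas memref<4x8xf32, affine_map<(d0, d1) -> (d0*16 + d1*2)>>
/// (stride 2 on the last dimension) does not.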
1077 static bool isLastMemrefDimUnitStride(MemRefType type) {
1078   int64_t offset;
1079   SmallVector<int64_t, 4> strides;
1080   auto successStrides = getStridesAndOffset(type, strides, offset);
1081   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1082 }
1083 
1084 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1085 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1086 /// vector load/stores due to non-unit strides or broadcasts:
1087 ///
1088 /// * Transfer dimension is not the last memref dimension
1089 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1090 /// * Memref has a layout map with non-unit stride on the last dimension
1091 ///
1092 /// This pattern generates IR as follows:
1093 ///
1094 /// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1096 ///    depending on OpTy.
1097 ///
1098 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1099 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1100 ///       to ConvertVectorToLLVM.
1101 ///
1102 /// E.g.:
1103 /// ```
1104 /// vector.transfer_write %vec, %A[%a, %b]
1105 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1106 ///    : vector<9xf32>, memref<?x?xf32>
1107 /// ```
1108 /// Is rewritten to approximately the following pseudo-IR:
1109 /// ```
1110 /// for i = 0 to 9 {
1111 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1112 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1113 /// }
1114 /// ```
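///
/// A transfer_read is lowered analogously, except that scalars are loaded and
/// inserted into the loop-carried vector (illustrative pseudo-IR):
/// ```
/// %v_init = splat %padding : vector<9xf32>
/// %vec = for i = 0 to 9 with (%v = %v_init) {
///   %t = memref.load %arg0[%a + i, %b] : memref<?x?xf32>
///   %v_next = vector.insertelement %t, %v[i] : vector<9xf32>
///   yield %v_next
/// }
/// ```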
1115 template <typename OpTy>
1116 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1117   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1118 
1119   LogicalResult matchAndRewrite(OpTy xferOp,
1120                                 PatternRewriter &rewriter) const override {
1121     auto map = xferOp.permutation_map();
1122     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1123 
1124     if (!memRefType)
1125       return failure();
1126     if (xferOp.getVectorType().getRank() != 1)
1127       return failure();
1128     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1129       return failure(); // Handled by ConvertVectorToLLVM
1130 
1131     // Loop bounds, step, state...
1132     Location loc = xferOp.getLoc();
1133     auto vecType = xferOp.getVectorType();
1134     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
1135     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
1136     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
1137     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1138 
1139     // Generate for loop.
1140     rewriter.replaceOpWithNewOp<scf::ForOp>(
1141         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1142         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1143           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1144         });
1145 
1146     return success();
1147   }
1148 };
1149 
1150 } // namespace lowering_1_d
1151 } // namespace
1152 
1153 namespace mlir {
1154 
1155 void populateVectorToSCFConversionPatterns(
1156     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1157   if (options.unroll) {
1158     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1159                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1160         patterns.getContext(), options);
1161   } else {
1162     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1163                  lowering_n_d::PrepareTransferWriteConversion,
1164                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1165                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1166         patterns.getContext(), options);
1167   }
1168 
1169   if (options.targetRank == 1) {
1170     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1171                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1172         patterns.getContext(), options);
1173   }
1174 }
1175 
1176 } // namespace mlir
1177 
1178 namespace {
1179 
1180 struct ConvertVectorToSCFPass
1181     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1182   ConvertVectorToSCFPass() = default;
1183   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1184     this->fullUnroll = options.unroll;
1185     this->targetRank = options.targetRank;
1186     this->lowerPermutationMaps = options.lowerPermutationMaps;
1187   }
1188 
1189   void runOnFunction() override {
1190     VectorTransferToSCFOptions options;
1191     options.unroll = fullUnroll;
1192     options.targetRank = targetRank;
1193     options.lowerPermutationMaps = lowerPermutationMaps;
1194 
1195     // Lower permutation maps first.
1196     if (lowerPermutationMaps) {
1197       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1198       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1199           lowerTransferPatterns);
1200       (void)applyPatternsAndFoldGreedily(getFunction(),
1201                                          std::move(lowerTransferPatterns));
1202     }
1203 
1204     RewritePatternSet patterns(getFunction().getContext());
1205     populateVectorToSCFConversionPatterns(patterns, options);
1206     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
1207   }
1208 };
1209 
1210 } // namespace
1211 
1212 std::unique_ptr<Pass>
1213 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1214   return std::make_unique<ConvertVectorToSCFPass>(options);
1215 }
1216