1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 using vector::TransferReadOp;
31 using vector::TransferWriteOp;
32 
33 namespace {
34 
35 /// Attribute name used for labeling transfer ops during progressive lowering.
36 static const char kPassLabel[] = "__vector_to_scf_lowering__";
37 
38 /// Patterns that inherit from this struct have access to
39 /// VectorTransferToSCFOptions.
40 template <typename OpTy>
41 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
42   explicit VectorToSCFPattern(MLIRContext *context,
43                               VectorTransferToSCFOptions opt)
44       : OpRewritePattern<OpTy>(context), options(opt) {}
45 
46   VectorTransferToSCFOptions options;
47 };
48 
49 /// Given a vector transfer op, calculate which dimension of the `source`
50 /// memref should be unpacked in the next application of TransferOpConversion.
51 /// A return value of None indicates a broadcast.
52 template <typename OpTy>
53 static Optional<int64_t> unpackedDim(OpTy xferOp) {
54   // TODO: support 0-d corner case.
55   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
56   auto map = xferOp.permutation_map();
57   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
58     return expr.getPosition();
59   }
60   assert(xferOp.isBroadcastDim(0) &&
61          "Expected AffineDimExpr or AffineConstantExpr");
62   return None;
63 }
64 
65 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
66 /// map is identical to the current permutation map, but the first result is
67 /// omitted.
68 template <typename OpTy>
69 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70   // TODO: support 0-d corner case.
71   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
72   auto map = xferOp.permutation_map();
73   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
74                         b.getContext());
75 }
76 
77 /// Calculate the indices for the new vector transfer op.
78 ///
79 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
80 ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
81 ///                                 ^^^^^^
82 ///              `iv` is the iteration variable of the (new) surrounding loop.
83 template <typename OpTy>
84 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
85                            SmallVector<Value, 8> &indices) {
86   typename OpTy::Adaptor adaptor(xferOp);
87   // Corresponding memref dim of the vector dim that is unpacked.
88   auto dim = unpackedDim(xferOp);
89   auto prevIndices = adaptor.indices();
90   indices.append(prevIndices.begin(), prevIndices.end());
91 
92   Location loc = xferOp.getLoc();
93   bool isBroadcast = !dim.hasValue();
94   if (!isBroadcast) {
95     AffineExpr d0, d1;
96     bindDims(xferOp.getContext(), d0, d1);
97     Value offset = adaptor.indices()[dim.getValue()];
98     indices[dim.getValue()] =
99         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
100   }
101 }
102 
103 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
104                             Value value) {
105   if (hasRetVal) {
106     assert(value && "Expected non-empty value");
107     b.create<scf::YieldOp>(loc, value);
108   } else {
109     b.create<scf::YieldOp>(loc);
110   }
111 }
112 
113 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
114 /// is set to true. No such check is generated under following circumstances:
115 /// * xferOp does not have a mask.
116 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
117 ///   computed and attached to the new transfer op in the pattern.)
118 /// * The to-be-unpacked dim of xferOp is a broadcast.
119 template <typename OpTy>
120 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
121   if (!xferOp.mask())
122     return Value();
123   if (xferOp.getMaskType().getRank() != 1)
124     return Value();
125   if (xferOp.isBroadcastDim(0))
126     return Value();
127 
128   Location loc = xferOp.getLoc();
129   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
130 }
131 
132 /// Helper function TransferOpConversion and TransferOp1dConversion.
133 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
134 /// specified dimension `dim` with the loop iteration variable `iv`.
135 /// E.g., when unpacking dimension 0 from:
136 /// ```
137 /// %vec = vector.transfer_read %A[%a, %b] %cst
138 ///     : vector<5x4xf32>, memref<?x?xf32>
139 /// ```
140 /// An if check similar to this will be generated inside the loop:
141 /// ```
142 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
143 /// if (%a + iv < %d) {
144 ///   (in-bounds case)
145 /// } else {
146 ///   (out-of-bounds case)
147 /// }
148 /// ```
149 ///
150 /// If the transfer is 1D and has a mask, this function generates a more complex
151 /// check also accounts for potentially masked out elements.
152 ///
153 /// This function variant returns the value returned by `inBoundsCase` or
154 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
155 /// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    // In-bounds iff memrefDim > baseIndex + iv (signed comparison).
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  // Conjoin with the bounds check (if any); a missing bounds check means the
  // mask bit alone decides.
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            // No out-of-bounds handler: yield nothing from the else branch.
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}
210 
211 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
212 /// a return value. Consequently, this function does not have a return value.
213 template <typename OpTy>
214 static void generateInBoundsCheck(
215     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
216     function_ref<void(OpBuilder &, Location)> inBoundsCase,
217     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
218   generateInBoundsCheck(
219       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
220       /*inBoundsCase=*/
221       [&](OpBuilder &b, Location loc) {
222         inBoundsCase(b, loc);
223         return Value();
224       },
225       /*outOfBoundsCase=*/
226       [&](OpBuilder &b, Location loc) {
227         if (outOfBoundsCase)
228           outOfBoundsCase(b, loc);
229         return Value();
230       });
231 }
232 
233 /// Given an ArrayAttr, return a copy where the first element is dropped.
234 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
235   if (!attr)
236     return attr;
237   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
238 }
239 
240 /// Add the pass label to a vector transfer op if its rank is not the target
241 /// rank.
242 template <typename OpTy>
243 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
244                                 unsigned targetRank) {
245   if (newXferOp.getVectorType().getRank() > targetRank)
246     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
247 }
248 
249 /// Return true if this transfer op operates on a source tensor.
250 template <typename OpTy>
251 static bool isTensorOp(OpTy xferOp) {
252   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
253     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
254       // TransferWriteOps on tensors have a result.
255       assert(xferOp->getNumResults() > 0);
256     }
257     return true;
258   }
259   return false;
260 }
261 
262 namespace lowering_n_d {
263 
264 /// Helper data structure for data and mask buffers.
/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  // Buffer holding the transfer op's vector data.
  Value dataBuffer;
  // Buffer holding the transfer op's mask; null if the op has no mask.
  Value maskBuffer;
};
269 
270 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
/// Return the closest ancestor of `op` that carries the
/// AutomaticAllocationScope trait (e.g., a function body), asserting that
/// one exists. Buffers are alloca'ed at the start of this scope.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
277 
278 /// Allocate temporary buffers for data (vector) and mask (if present).
/// Allocate temporary buffers for data (vector) and mask (if present).
/// The allocas are placed at the start of the enclosing allocation scope;
/// the mask store/load pair is emitted right before the transfer op itself.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  // Hoist the allocas to the top of the allocation scope.
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  // 0-d memref wrapping the full vector value.
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    // 0-d memref wrapping the mask vector.
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    // Store/load the mask next to the transfer op so the load dominates all
    // later uses and can be rediscovered via getMaskBuffer().
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}
302 
303 /// Given a MemRefType with VectorType element type, unpack one dimension from
304 /// the VectorType into the MemRefType.
305 ///
306 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
307 static MemRefType unpackOneDim(MemRefType type) {
308   auto vectorType = type.getElementType().dyn_cast<VectorType>();
309   auto memrefShape = type.getShape();
310   SmallVector<int64_t, 8> newMemrefShape;
311   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
312   newMemrefShape.push_back(vectorType.getDimSize(0));
313   return MemRefType::get(newMemrefShape,
314                          VectorType::get(vectorType.getShape().drop_front(),
315                                          vectorType.getElementType()));
316 }
317 
318 /// Given a transfer op, find the memref from which the mask is loaded. This
319 /// is similar to Strategy<TransferWriteOp>::getBuffer.
/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
/// Relies on the invariant established by allocBuffers: the mask operand is
/// always produced by a memref.load from the mask buffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}
327 
/// Codegen strategy, depending on the operation. Only specializations for
/// TransferReadOp and TransferWriteOp exist (defined below); instantiating
/// the primary template is a compile-time error.
template <typename OpTy>
struct Strategy;
331 
332 /// Code strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    // Buffer store indices: old store indices plus the loop iv.
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    // Memref read indices: old indices with the unpacked dim offset by iv.
    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    // Mask is left unset (Value()) here; TransferOpConversion fills it in.
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    // Splat the padding value to the (N-1)-d vector type and store it.
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    // Reads on memrefs carry no loop state; nothing to return.
    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    // Erase the store first: it is the only user of the transfer result.
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  /// TransferReadOps never carry loop state.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
444 
445 /// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    // Buffer load indices: old load indices plus the loop iv.
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    // Memref write indices: old indices with the unpacked dim offset by iv.
    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    // On tensors, write to the tensor threaded through the loop state.
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    // Mask is left unset (Value()) here; TransferOpConversion fills it in.
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  /// Writes out of bounds are skipped; on tensors, pass the current tensor
  /// value through unchanged.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      // The loop yields the updated tensor; replace the op's result with it.
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  /// On tensors, the state is the tensor being written to.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};
527 
528 template <typename OpTy>
529 LogicalResult checkPrepareXferOp(OpTy xferOp,
530                                  VectorTransferToSCFOptions options) {
531   if (xferOp->hasAttr(kPassLabel))
532     return failure();
533   if (xferOp.getVectorType().getRank() <= options.targetRank)
534     return failure();
535   if (isTensorOp(xferOp) && !options.lowerTensors)
536     return failure();
537   // Transfer ops that modify the element type are not supported atm.
538   if (xferOp.getVectorType().getElementType() !=
539       xferOp.getShapedType().getElementType())
540     return failure();
541   return success();
542 }
543 
544 /// Prepare a TransferReadOp for progressive lowering.
545 ///
546 /// 1. Allocate a temporary buffer.
547 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
548 /// 3. Store the result of the TransferReadOp into the temporary buffer.
549 /// 4. Load the result from the temporary buffer and replace all uses of the
550 ///    original TransferReadOp with this load.
551 ///
552 /// E.g.:
553 /// ```
554 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
555 ///     : vector<5x4xf32>, memref<?x?x?xf32>
556 /// ```
557 /// is rewritten to:
558 /// ```
559 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
560 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
561 ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
562 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
563 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
564 /// ```
565 ///
566 /// Note: A second temporary buffer may be allocated for the `mask` operand.
567 struct PrepareTransferReadConversion
568     : public VectorToSCFPattern<TransferReadOp> {
569   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
570 
571   LogicalResult matchAndRewrite(TransferReadOp xferOp,
572                                 PatternRewriter &rewriter) const override {
573     if (checkPrepareXferOp(xferOp, options).failed())
574       return failure();
575 
576     auto buffers = allocBuffers(rewriter, xferOp);
577     auto *newXfer = rewriter.clone(*xferOp.getOperation());
578     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
579     if (xferOp.mask()) {
580       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
581           buffers.maskBuffer);
582     }
583 
584     Location loc = xferOp.getLoc();
585     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
586                                      buffers.dataBuffer);
587     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
588 
589     return success();
590   }
591 };
592 
593 /// Prepare a TransferWriteOp for progressive lowering.
594 ///
595 /// 1. Allocate a temporary buffer.
596 /// 2. Store the vector into the buffer.
597 /// 3. Load the vector from the buffer again.
598 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
599 ///    marking it eligible for progressive lowering via TransferOpConversion.
600 ///
601 /// E.g.:
602 /// ```
603 /// vector.transfer_write %vec, %A[%a, %b, %c]
604 ///     : vector<5x4xf32>, memref<?x?x?xf32>
605 /// ```
606 /// is rewritten to:
607 /// ```
608 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
609 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
610 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
611 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
612 ///     : vector<5x4xf32>, memref<?x?x?xf32>
613 /// ```
614 ///
615 /// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    // Round-trip the written vector through the data buffer so that
    // TransferOpConversion can later rediscover the buffer via the load.
    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    // In-place update: swap in the loaded vector and label the op.
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      // Same round-trip for the mask (buffer set up by allocBuffers).
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};
642 
643 /// Progressive lowering of vector transfer ops: Unpack one dimension.
644 ///
645 /// 1. Unpack one dimension from the current buffer type and cast the buffer
646 ///    to that new type. E.g.:
647 ///    ```
648 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
649 ///    vector.transfer_write %vec ...
650 ///    ```
651 ///    The following cast is generated:
652 ///    ```
653 ///    %casted = vector.type_cast %0
654 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
655 ///    ```
656 /// 2. Generate a for loop and rewrite the transfer op according to the
657 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
658 ///    out-of-bounds, generate an if-check and handle both cases separately.
659 /// 3. Clean up according to the corresponding Strategy<OpTy>.
660 ///
661 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
662 /// source (as opposed to a memref source), then each iteration of the generated
663 /// scf.for loop yields the new tensor value. E.g.:
664 /// ```
665 /// %result = scf.for i = 0 to 5 {
666 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
667 ///   %1 = vector.transfer_write %0, %source[...]
668 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
669 ///   scf.yield %1 : tensor<5x4x3xf32>
670 /// }
671 /// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // Only ops labeled by the prepare patterns are handled here.
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step. The loop iterates over the unpacked (innermost
    // buffer) dimension of the casted data type.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op (the mask bit is checked by
                // generateInBoundsCheck instead).
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                // Thread the tensor value (if any) through the loop.
                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    // Erase (or replace) the original op, depending on the strategy.
    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
777 
778 } // namespace lowering_n_d
779 
780 namespace lowering_n_d_unrolled {
781 
782 /// If the original transfer op has a mask, compute the mask of the new transfer
783 /// op (for the current iteration `i`) and assign it.
784 template <typename OpTy>
785 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
786                             int64_t i) {
787   if (!xferOp.mask())
788     return;
789 
790   if (xferOp.isBroadcastDim(0)) {
791     // To-be-unpacked dimension is a broadcast, which does not have a
792     // corresponding mask dimension. Mask attribute remains unchanged.
793     newXferOp.maskMutable().assign(xferOp.mask());
794     return;
795   }
796 
797   if (xferOp.getMaskType().getRank() > 1) {
798     // Unpack one dimension of the mask.
799     OpBuilder::InsertionGuard guard(b);
800     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
801 
802     llvm::SmallVector<int64_t, 1> indices({i});
803     Location loc = xferOp.getLoc();
804     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
805     newXferOp.maskMutable().assign(newMask);
806   }
807 
808   // If we end up here: The mask of the old transfer op is 1D and the unpacked
809   // dim is not a broadcast, so no mask is needed on the new transfer op.
810   // `generateInBoundsCheck` will have evaluated the mask already.
811 }
812 
813 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
814 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
815 /// memref buffer is allocated and the SCF loop is fully unrolled.
816 ///
817 /// ```
818 /// E.g.:
819 /// ```
820 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
821 ///     : memref<?x?x?xf32>, vector<5x4xf32>
822 /// ```
823 /// is rewritten to IR such as (simplified):
824 /// ```
825 /// %v_init = splat %padding : vector<5x4xf32>
826 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
827 ///     : memref<?x?x?xf32>, vector<4xf32>
828 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
829 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
830 ///     : memref<?x?x?xf32>, vector<4xf32>
831 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
832 /// ...
833 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
834 ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
836 /// ```
837 ///
838 /// Note: As an optimization, if the result of the original TransferReadOp
839 /// was directly inserted into another vector, no new %v_init vector is created.
840 /// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted. If the xfer op's single user is a vector::InsertOp, reuse
  /// that op's destination vector; otherwise create a fresh vector filled
  /// with the padding value.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation. Otherwise return a null op
  /// handle (which converts to false).
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    // Zero users, multiple users, or a non-InsertOp user.
    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, append that operation's position indices to `indices`.
  /// Otherwise leave `indices` unchanged.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      // NOTE(review): `dyn_cast` assumes each position attribute is an
      // IntegerAttr; a different attribute kind would dereference null here.
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Stop unrolling once the configured target rank has been reached.
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    // Transfers on tensors are only lowered when explicitly requested.
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    // Each new transfer op reads a vector of one rank less (leading dim
    // dropped).
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops. `vec` accumulates the
    // inserted results across iterations.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};
951 
952 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
953 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
954 /// memref buffer is allocated and the SCF loop is fully unrolled.
955 ///
956 /// ```
957 /// E.g.:
958 /// ```
959 /// vector.transfer_write %vec, %A[%a, %b, %c]
960 ///     : vector<5x4xf32>, memref<?x?x?xf32>
961 /// ```
962 /// is rewritten to IR such as (simplified):
963 /// ```
964 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
965 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
966 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
967 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
968 /// ...
969 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
970 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
971 /// ```
972 ///
973 /// Note: As an optimization, if the vector of the original TransferWriteOp
974 /// was directly extracted from another vector via an ExtractOp `a`, extract
975 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
976 /// doing so, `a` may become dead, and the number of ExtractOps generated during
977 /// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  /// If the written vector was itself produced by an ExtractOp, extract from
  /// that op's input instead, so the intermediate ExtractOp may become dead.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  /// Otherwise return a null op handle (which converts to false).
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, append its
  /// position indices to `indices`. Otherwise leave `indices` unchanged.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      // NOTE(review): `dyn_cast` assumes each position attribute is an
      // IntegerAttr; a different attribute kind would dereference null here.
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Stop unrolling once the configured target rank has been reached.
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    // Transfers on tensors are only lowered when explicitly requested.
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    // For tensor writes, each new transfer op returns the updated tensor,
    // which is threaded through the unrolled iterations below.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Out-of-bounds: the source is left untouched for this slice.
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
1080 
1081 } // namespace lowering_n_d_unrolled
1082 
1083 namespace lowering_1_d {
1084 
1085 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1086 /// part of TransferOp1dConversion. Return the memref dimension on which
1087 /// the transfer is operating. A return value of None indicates a broadcast.
1088 template <typename OpTy>
1089 static Optional<int64_t>
1090 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1091                    SmallVector<Value, 8> &memrefIndices) {
1092   auto indices = xferOp.indices();
1093   auto map = xferOp.permutation_map();
1094   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1095 
1096   memrefIndices.append(indices.begin(), indices.end());
1097   assert(map.getNumResults() == 1 &&
1098          "Expected 1 permutation map result for 1D transfer");
1099   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1100     Location loc = xferOp.getLoc();
1101     auto dim = expr.getPosition();
1102     AffineExpr d0, d1;
1103     bindDims(xferOp.getContext(), d0, d1);
1104     Value offset = memrefIndices[dim];
1105     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1106     return dim;
1107   }
1108 
1109   assert(xferOp.isBroadcastDim(0) &&
1110          "Expected AffineDimExpr or AffineConstantExpr");
1111   return None;
1112 }
1113 
1114 /// Codegen strategy for TransferOp1dConversion, depending on the
1115 /// operation.
1116 template <typename OpTy>
1117 struct Strategy1d;
1118 
1119 /// Codegen strategy for TransferReadOp.
1120 template <>
1121 struct Strategy1d<TransferReadOp> {
1122   static void generateForLoopBody(OpBuilder &b, Location loc,
1123                                   TransferReadOp xferOp, Value iv,
1124                                   ValueRange loopState) {
1125     SmallVector<Value, 8> indices;
1126     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1127     auto vec = loopState[0];
1128 
1129     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1130     // padding value).
1131     auto nextVec = generateInBoundsCheck(
1132         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1133         /*inBoundsCase=*/
1134         [&](OpBuilder &b, Location loc) {
1135           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1136           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1137         },
1138         /*outOfBoundsCase=*/
1139         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1140     b.create<scf::YieldOp>(loc, nextVec);
1141   }
1142 
1143   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1144     // Inititalize vector with padding value.
1145     Location loc = xferOp.getLoc();
1146     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1147                                      xferOp.padding());
1148   }
1149 };
1150 
1151 /// Codegen strategy for TransferWriteOp.
1152 template <>
1153 struct Strategy1d<TransferWriteOp> {
1154   static void generateForLoopBody(OpBuilder &b, Location loc,
1155                                   TransferWriteOp xferOp, Value iv,
1156                                   ValueRange /*loopState*/) {
1157     SmallVector<Value, 8> indices;
1158     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1159 
1160     // Nothing to do in case of out-of-bounds access.
1161     generateInBoundsCheck(
1162         b, xferOp, iv, dim,
1163         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1164           auto val =
1165               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
1166           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
1167         });
1168     b.create<scf::YieldOp>(loc);
1169   }
1170 
1171   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1172     return Value();
1173   }
1174 };
1175 
1176 /// Return true if the last dimension of the MemRefType has unit stride.
1177 static bool isLastMemrefDimUnitStride(MemRefType type) {
1178   int64_t offset;
1179   SmallVector<int64_t, 4> strides;
1180   auto successStrides = getStridesAndOffset(type, strides, offset);
1181   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1182 }
1183 
1184 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1185 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1186 /// vector load/stores due to non-unit strides or broadcasts:
1187 ///
1188 /// * Transfer dimension is not the last memref dimension
1189 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1190 /// * Memref has a layout map with non-unit stride on the last dimension
1191 ///
1192 /// This pattern generates IR as follows:
1193 ///
1194 /// 1. Generate a for loop iterating over each vector element.
1195 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
1196 ///    depending on OpTy.
1197 ///
1198 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1199 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1200 ///       to ConvertVectorToLLVM.
1201 ///
1202 /// E.g.:
1203 /// ```
1204 /// vector.transfer_write %vec, %A[%a, %b]
1205 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1206 ///    : vector<9xf32>, memref<?x?xf32>
1207 /// ```
1208 /// Is rewritten to approximately the following pseudo-IR:
1209 /// ```
1210 /// for i = 0 to 9 {
1211 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1212 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1213 /// }
1214 /// ```
1215 template <typename OpTy>
1216 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1217   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1218 
1219   LogicalResult matchAndRewrite(OpTy xferOp,
1220                                 PatternRewriter &rewriter) const override {
1221     // TODO: support 0-d corner case.
1222     if (xferOp.getTransferRank() == 0)
1223       return failure();
1224     auto map = xferOp.permutation_map();
1225     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1226 
1227     if (!memRefType)
1228       return failure();
1229     if (xferOp.getVectorType().getRank() != 1)
1230       return failure();
1231     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1232       return failure(); // Handled by ConvertVectorToLLVM
1233 
1234     // Loop bounds, step, state...
1235     Location loc = xferOp.getLoc();
1236     auto vecType = xferOp.getVectorType();
1237     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1238     auto ub =
1239         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1240     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1241     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1242 
1243     // Generate for loop.
1244     rewriter.replaceOpWithNewOp<scf::ForOp>(
1245         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1246         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1247           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1248         });
1249 
1250     return success();
1251   }
1252 };
1253 
1254 } // namespace lowering_1_d
1255 } // namespace
1256 
1257 namespace mlir {
1258 
1259 void populateVectorToSCFConversionPatterns(
1260     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1261   if (options.unroll) {
1262     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1263                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1264         patterns.getContext(), options);
1265   } else {
1266     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1267                  lowering_n_d::PrepareTransferWriteConversion,
1268                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1269                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1270         patterns.getContext(), options);
1271   }
1272 
1273   if (options.targetRank == 1) {
1274     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1275                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1276         patterns.getContext(), options);
1277   }
1278 }
1279 
1280 } // namespace mlir
1281 
1282 namespace {
1283 
1284 struct ConvertVectorToSCFPass
1285     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1286   ConvertVectorToSCFPass() = default;
1287   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1288     this->fullUnroll = options.unroll;
1289     this->targetRank = options.targetRank;
1290     this->lowerPermutationMaps = options.lowerPermutationMaps;
1291     this->lowerTensors = options.lowerTensors;
1292   }
1293 
1294   void runOnOperation() override {
1295     VectorTransferToSCFOptions options;
1296     options.unroll = fullUnroll;
1297     options.targetRank = targetRank;
1298     options.lowerPermutationMaps = lowerPermutationMaps;
1299     options.lowerTensors = lowerTensors;
1300 
1301     // Lower permutation maps first.
1302     if (lowerPermutationMaps) {
1303       RewritePatternSet lowerTransferPatterns(getOperation().getContext());
1304       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1305           lowerTransferPatterns);
1306       (void)applyPatternsAndFoldGreedily(getOperation(),
1307                                          std::move(lowerTransferPatterns));
1308     }
1309 
1310     RewritePatternSet patterns(getOperation().getContext());
1311     populateVectorToSCFConversionPatterns(patterns, options);
1312     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1313   }
1314 };
1315 
1316 } // namespace
1317 
/// Create a pass that lowers vector transfer ops to SCF, configured with the
/// given options.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
1322